from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import pandas as pd

from align.utils import load_jsonl_as_dataframe


def _bactrianx_entry_to_fastchat(entry: Dict, language: str) -> Dict:
    """Convert one Bactrian-X record into a FastChat conversation entry.

    Args:
        entry: Record with the keys ``id``, ``instruction``, ``input`` and ``output``.
        language: Language tag appended to the id and stored under ``language``.

    Returns:
        A dict in FastChat format::

            {
                "id": "<id>_<language>",
                "language": "<language>",
                "conversations": [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}],
            }
    """
    instruction = entry["instruction"]
    extra_input = entry["input"]
    # NOTE(review): the prompt layout (space before the input, trailing blank lines when there is
    # no input) is kept exactly as it was, because it ends up in the generated dataset.
    if extra_input is not None and extra_input != "":
        human = f"{instruction}\n\n {extra_input}"
    else:
        human = f"{instruction}\n\n"
    return {
        "id": entry["id"] + f"_{language}",
        "language": language,
        "conversations": [{"from": "human", "value": human}, {"from": "gpt", "value": entry["output"]}],
    }


def format_bactrianx_to_fastchat_and_create_val(data_folder_path: Union[Path, str]):
    """Convert all ``bactrian*.jsonl`` files of a folder to FastChat format and split train/val.

    Every matching file is loaded, sorted by id and converted entry by entry. The language tag of
    a file is the part of its stem after the last underscore. All languages are shuffled with the
    same permutation, then the first 5% of the rows become the validation partition and the rest
    the training partition. Two extra datasets are derived from the per-language ones:

    * ``EN_DE_FR_IT_ES``: concatenation of all per-language partitions.
    * ``EN_DE_FR_IT_ES_sampled``: one row per shared row index, picked from the concatenation.

    Each partition is written to
    ``<data_folder_path>/in_fastchat_format/<partition>_bactrianx_<lang_postfix>.jsonl``.
    Summary statistics are printed to stdout.

    Args:
        data_folder_path: Folder containing the ``bactrian*.jsonl`` files.

    Raises:
        ValueError: If no ``bactrian*.jsonl`` file is found in ``data_folder_path``.
        AssertionError: If the files do not all contain the same set of ids.
    """
    data_folder_path = Path(data_folder_path)
    # glob order depends on the file system; sort it so runs process the files in the same order
    file_paths = sorted(data_folder_path.glob("bactrian*.jsonl"))
    if not file_paths:
        raise ValueError(f"No files matching 'bactrian*.jsonl' found in {data_folder_path}")
    file_languages = [data_file_path.stem.split("_")[-1] for data_file_path in file_paths]
    all_data = [sort_by_mixed_id(load_jsonl_as_dataframe(data_file_path)) for data_file_path in file_paths]

    # assert same IDs across languages
    unique_ids_across_datasets = set(pd.concat(all_data).id.unique().tolist())
    for data_file_path, data in zip(file_paths, all_data):
        assert unique_ids_across_datasets == set(
            data.id.unique().tolist()
        ), f"IDs in {data_file_path} differ from the IDs of the other languages"

    all_new_data = []
    for language, data in zip(file_languages, all_data):
        new_entries = [_bactrianx_entry_to_fastchat(entry, language) for entry in data.to_dict(orient="records")]
        all_new_data.append(pd.DataFrame(new_entries))

    # shuffle a single dataset index and reorder by shuffled index across languages
    reorder_index = all_new_data[0].index.tolist()
    np.random.shuffle(reorder_index)

    # split into train and val
    datasets = {}
    for language, new_data in zip(file_languages, all_new_data):
        # shuffle
        new_data = new_data.reindex(reorder_index)
        # split
        num_val = int(len(new_data) * 0.05)
        train = new_data[num_val:]
        val = new_data[:num_val]
        assert len(val) + len(train) == len(new_data)
        datasets[language] = {"val": val, "train": train}

    # Every row index occurs once per per-language dataset in the concatenation below, so this is
    # the size of each group that the sampling step draws from.
    num_languages = len(datasets)
    # NOTE(review): the combined dataset keeps this fixed name even if other languages were found.
    all_languages_id = "EN_DE_FR_IT_ES"

    # ENDEFRITES
    datasets[all_languages_id] = dict(
        val=pd.concat([partitions["val"] for partitions in datasets.values()]),
        train=pd.concat([partitions["train"] for partitions in datasets.values()]),
    )

    # ENDEFRITES_sampled
    all_languages_sampled_id = all_languages_id + "_sampled"
    datasets[all_languages_sampled_id] = {}
    # Row index of the last loaded (id-sorted) dataset; assumed to be integer-valued -- TODO confirm
    reference_index = pd.Series(all_data[-1].index)
    for partition in ["train", "val"]:
        dataset = datasets[all_languages_id][partition]
        print(dataset.index.value_counts())
        # one random offset per partition, turned into a within-group position for every row index
        offset = np.random.randint(0, num_languages)
        indices = ((reference_index + offset) % num_languages).to_list()
        # each group holds the rows sharing one index value; take one row of it per popped position
        dataset_sampled = dataset.groupby(dataset.index).apply(lambda d: d.iloc[indices.pop()])
        print(dataset_sampled.index.value_counts())
        datasets[all_languages_sampled_id][partition] = dataset_sampled

    stats = []
    for lang_postfix, dataset in datasets.items():
        for partition, partition_data in dataset.items():
            stats.append(
                {
                    "lang_postfix": lang_postfix,
                    "partition": partition,
                    "#samples": len(partition_data),
                    "#ids": len(partition_data.id.unique()),
                }
            )
    print(pd.DataFrame(stats))

    save_dir = data_folder_path / "in_fastchat_format"
    save_dir.mkdir(parents=True, exist_ok=True)
    for lang_postfix, dataset in datasets.items():
        for partition, partition_data in dataset.items():
            # sort and save
            sort_by_mixed_id(partition_data).to_json(
                save_dir / (f"{partition}_bactrianx_{lang_postfix}.jsonl"), lines=True, index=False, orient="records"
            )


def sort_by_mixed_id(partition_data: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``partition_data`` sorted by the two parts of its ``id`` column.

    Ids are split at ``-`` into a first part and a numeric part. Anything after the first ``_``
    in the numeric part is ignored. Rows are ordered by the first part, then by the numeric part
    compared as an integer. The input frame is left untouched.

    Args:
        partition_data: Frame with a string ``id`` column.

    Returns:
        A new, sorted frame with the same columns as the input. Columns named ``model`` or
        ``numeric_id`` are used as temporary sort keys and are not part of the result.

    Raises:
        ValueError: If splitting the ids at ``-`` does not yield exactly two parts, or if the
            numeric part cannot be converted to an integer.
    """
    if partition_data.empty:
        # nothing to sort; splitting an empty id column would not yield the two key columns
        return partition_data.copy()

    id_parts = partition_data.id.str.split("-", expand=True)
    if id_parts.shape[1] != 2:
        raise ValueError(
            f"Expected ids with exactly one '-' separator, but splitting yielded {id_parts.shape[1]} part(s)"
        )
    numeric_ids = id_parts[1].str.split("_", expand=True)[0].astype(int)

    # assign() works on a copy, so the caller's frame does not receive the helper columns;
    # plain arrays are passed so the assignment is positional even with a duplicated index
    keyed = partition_data.assign(model=id_parts[0].to_numpy(), numeric_id=numeric_ids.to_numpy())
    keyed = keyed.sort_values(by=["model", "numeric_id"])
    return keyed.drop(["model", "numeric_id"], axis=1)


def format_lima_to_fastchat(
    data: List[Dict],
    end_conversation_with_bot_response: bool,
    extra_fileds_to_include: Union[List[str], None] = None,
) -> List[Dict]:
    """Convert LIMA-style conversations into FastChat format.

    Every entry of ``data`` holds its utterances as a list under ``conversations``. Utterances at
    even positions are attributed to ``human``, those at odd positions to ``gpt``.

    Args:
        data: Entries with a ``conversations`` list and, optionally, an ``id``. Entries without
            an ``id`` get ``id_<position in data>``.
        end_conversation_with_bot_response: If true, a trailing human utterance is removed and
            entries that are left with fewer than two utterances are dropped.
        extra_fileds_to_include: Keys copied unchanged from each entry into the result. ``None``
            (the default) copies nothing.

    Returns:
        One dict per kept entry, with the keys ``id`` and ``conversations`` plus the requested
        extra fields.

    Raises:
        AssertionError: If a requested extra field is missing from the first entry.
        KeyError: If a requested extra field is missing from a later entry.
    """
    extra_fields = [] if extra_fileds_to_include is None else list(extra_fileds_to_include)
    if not data:
        # nothing to convert; also avoids indexing data[0] in the check below
        return []

    assert all(
        field in data[0] for field in extra_fields
    ), f"Not all extra fields ({extra_fields}) contained in data {data[0].keys()}"

    fastchat_conversations = []
    for position, entry in enumerate(data):
        turns = [
            {"from": ("human" if turn_idx % 2 == 0 else "gpt"), "value": utterance}
            for turn_idx, utterance in enumerate(entry["conversations"])
        ]

        if end_conversation_with_bot_response:
            if turns and turns[-1]["from"] == "human":
                turns.pop()
            # a conversation ending with the bot needs at least one human/gpt pair
            if len(turns) < 2:
                continue

        fastchat_entry = {"id": entry["id"] if "id" in entry else f"id_{position}", "conversations": turns}
        for field in extra_fields:
            fastchat_entry[field] = entry[field]
        fastchat_conversations.append(fastchat_entry)
    return fastchat_conversations
