import copy
import re
from pathlib import Path
from typing import List

import deepl
import pandas as pd

from align.config.app_config import DatasetName, LanguageCode
from align.utils import load_jsonl_as_dataframe, save_as_json


def translate_MT_bench_all_languages(
    dataset_name: DatasetName, deepl_api_key: str, target_lang_codes: List[LanguageCode], dataset_path: Path
):
    """Translate the MT-bench question file into every requested language.

    For each target language the ``turns`` column and then the ``reference``
    column are translated; each translation is added as a new column by
    ``translate_single_language``. The combined frame is written once via
    ``_save_translations`` and then once per language via
    ``_save_translated_questions``.

    Args:
        dataset_name: Must be ``DatasetName.mt_bench``.
        deepl_api_key: API key handed to the DeepL client.
        target_lang_codes: Languages to translate into.
        dataset_path: Path of the JSONL file to load.

    Raises:
        AssertionError: If ``dataset_name`` is not ``DatasetName.mt_bench``.
    """
    assert dataset_name == DatasetName.mt_bench, "Only mt_bench supported for translation!"
    translated = load_jsonl_as_dataframe(data_file_path=dataset_path)

    # Same order as before: per language, first "turns", then "reference".
    for lang in target_lang_codes:
        for column in ("turns", "reference"):
            translated = translate_single_language(
                data=translated,
                text_column_name=column,
                deepl_api_key=deepl_api_key,
                target_lang_code=lang,
            )

    _save_translations(data=translated, dataset_path=dataset_path, target_lang_codes=target_lang_codes)
    _save_translated_questions(data=translated, dataset_path=dataset_path, target_lang_codes=target_lang_codes)


def translate_MT_bench_ref_answers_languages(
    dataset_name: DatasetName, deepl_api_key: str, target_lang_codes: List[LanguageCode], dataset_path: Path
):
    """Translate the MT-bench reference answers into every requested language.

    For each target language, the ``turns`` of every choice of every entry are
    translated, the entry's ``question_id`` gets a ``_<language>`` suffix, and
    the result is written to
    ``<dataset_path three levels up>/mt_bench_<language>/<parent dir name>/<file name>``.

    Args:
        dataset_name: Must be ``DatasetName.mt_bench_ref_answers``.
        deepl_api_key: API key handed to the DeepL client.
        target_lang_codes: Languages to translate into.
        dataset_path: Path of the JSONL file to load.

    Raises:
        AssertionError: If ``dataset_name`` is not
            ``DatasetName.mt_bench_ref_answers``.
    """
    assert dataset_name == DatasetName.mt_bench_ref_answers, "Only mt_bench_ref_answers supported for translation!"
    data = load_jsonl_as_dataframe(data_file_path=dataset_path)
    translator = deepl.Translator(deepl_api_key)

    # The source records are never modified (each one is deep-copied below),
    # so they only need to be extracted once for all languages.
    entries = data.to_dict(orient="records")

    for target_lang_code in target_lang_codes:
        new_entries = []
        for entry in entries:
            new_entry = copy.deepcopy(entry)
            for choice in new_entry["choices"]:
                choice["turns"] = translate_texts(
                    translator=translator, texts=choice["turns"], target_lang_code=target_lang_code.value
                )
            # Suffix the id exactly once per entry. Doing this inside the
            # choices loop would append the suffix once per choice.
            new_entry["question_id"] = f"{entry['question_id']}_{target_lang_code}"
            new_entries.append(new_entry)

        out_file_path = (
            dataset_path.parent.parent.parent
            / f"mt_bench_{target_lang_code}"
            / dataset_path.parent.name
            / dataset_path.name
        )
        out_file_path.parent.mkdir(parents=True, exist_ok=True)
        save_as_json(entries=new_entries, out_path=out_file_path)


def translate_MTbench_judge_prompts(
    dataset_name: DatasetName, deepl_api_key: str, target_lang_codes: List[LanguageCode], dataset_path: Path
):
    """Translate the MT-bench judge prompts into every requested language.

    Each ``prompt_template`` is split at its ``{placeholder}`` markers so that
    only the text between them is sent to DeepL; the placeholders are then
    re-inserted unchanged. The ``system_prompt`` column is translated as a
    whole. One file per language is written to
    ``<dataset dir>/mt_bench_<language>/<file name>``.

    Args:
        dataset_name: Must be ``DatasetName.mt_bench_judge_prompts``.
        deepl_api_key: API key handed to the DeepL client.
        target_lang_codes: Languages to translate into.
        dataset_path: Path of the JSONL file to load.

    Raises:
        AssertionError: If ``dataset_name`` is not
            ``DatasetName.mt_bench_judge_prompts``.
    """
    assert dataset_name == DatasetName.mt_bench_judge_prompts, "Only mt_bench_judge_prompts supported for translation!"
    data = load_jsonl_as_dataframe(data_file_path=dataset_path)
    translator = deepl.Translator(deepl_api_key)

    # Source texts are the same for every language; extract them once.
    prompt_templates = data.prompt_template.tolist()
    source_system_prompts = data.system_prompt.tolist()

    for target_lang_code in target_lang_codes:
        new_data = data.copy(deep=True)

        translated_templates = []
        for prompt_template in prompt_templates:
            placeholders = re.findall(r"{(.*?)}", prompt_template)
            # re.split yields exactly len(placeholders) + 1 text segments.
            split_template = re.split(r"{.*?}", prompt_template)
            translated_split_template = [
                # further splitting needed, otherwise incomplete translation by DeepL
                "\n\n###".join(
                    translate_texts(
                        translator=translator, texts=segment.split("\n\n###"), target_lang_code=target_lang_code.value
                    )
                )
                for segment in split_template
            ]

            # Interleave translated segments with the untouched placeholders;
            # zip stops at the last placeholder, so the trailing segment is
            # appended separately.
            reconstructed_template = (
                "".join(text + f"{{{name}}}" for text, name in zip(translated_split_template, placeholders))
                + translated_split_template[-1]
            )
            translated_templates.append(reconstructed_template)

        # Assign whole columns instead of `new_data.prompt_template.iloc[idx] = ...`:
        # that chained assignment triggers pandas warnings and does not write
        # back to the frame when Copy-on-Write is active.
        new_data["prompt_template"] = translated_templates
        new_data["system_prompt"] = translate_texts(
            translator=translator, texts=source_system_prompts, target_lang_code=target_lang_code.value
        )

        out_file_path = dataset_path.parent / f"mt_bench_{target_lang_code}" / dataset_path.name
        # Create the per-language directory, as the other writers in this module do.
        out_file_path.parent.mkdir(parents=True, exist_ok=True)
        save_as_json(entries=new_data.to_dict(orient="records"), out_path=out_file_path)


def _save_translated_questions(data: pd.DataFrame, target_lang_codes: List[LanguageCode], dataset_path: Path):
    """Write one ``question.jsonl`` per target language.

    The untranslated ``turns`` / ``reference`` columns are treated as English
    and renamed to ``<column>_<EN value>``. For every target language, the
    ``<column>_<language value>`` columns are renamed back to ``turns`` /
    ``reference``, the columns of all other ``LanguageCode`` members are
    dropped, ``question_id`` gets a ``_<language>`` suffix, and the result is
    saved to ``<dataset_path three levels up>/mt_bench_<language>/question.jsonl``.

    The passed-in frame is not modified.

    Args:
        data: Frame holding the original and translated columns.
        target_lang_codes: Languages to export.
        dataset_path: Path of the source dataset file; only used to derive the
            output location.
    """
    column_prefixes = ("turns", "reference")

    # Rename on a copy, once, rather than in place on every loop iteration;
    # the caller's frame stays untouched.
    data = data.rename(columns={prefix: f"{prefix}_{LanguageCode.EN.value}" for prefix in column_prefixes})

    for language_code in target_lang_codes:
        # Translated columns are created as "<column>_<code.value>", so the
        # lookup must use `.value` too. Formatting the enum member itself only
        # matches when the enum happens to format as its value.
        column_names_to_exclude = [
            f"{prefix}_{lang_code.value}"
            for prefix in column_prefixes
            for lang_code in LanguageCode
            if lang_code != language_code
        ]
        new_data: pd.DataFrame = data.loc[:, ~data.columns.isin(column_names_to_exclude)].copy()
        new_data.rename(
            columns={f"{prefix}_{language_code.value}": prefix for prefix in column_prefixes}, inplace=True
        )
        # The id suffix and directory name keep the member's own formatting,
        # as in the other writers in this module.
        new_data["question_id"] = new_data["question_id"].astype(str) + f"_{language_code}"

        out_file_path = dataset_path.parent.parent.parent / f"mt_bench_{language_code}" / "question.jsonl"
        out_file_path.parent.mkdir(parents=True, exist_ok=True)
        save_as_json(entries=new_data.to_dict(orient="records"), out_path=out_file_path)


def translate_LIMA_all_languages(
    dataset_name: DatasetName, deepl_api_key: str, target_lang_codes: List[LanguageCode], dataset_path: Path
):
    """Translate the LIMA ``conversations`` column into every requested language.

    Each language adds one translated column to the loaded frame; the combined
    frame is then written out via ``_save_translations``.

    Args:
        dataset_name: Must be ``DatasetName.lima``.
        deepl_api_key: API key handed to the DeepL client.
        target_lang_codes: Languages to translate into.
        dataset_path: Path of the JSONL file to load.

    Raises:
        AssertionError: If ``dataset_name`` is not ``DatasetName.lima``.
    """
    assert dataset_name == DatasetName.lima, "Only LIMA supported for translation!"

    # Arguments shared by every per-language call.
    shared_kwargs = {"text_column_name": "conversations", "deepl_api_key": deepl_api_key}

    translated = load_jsonl_as_dataframe(data_file_path=dataset_path)
    for lang in target_lang_codes:
        translated = translate_single_language(data=translated, target_lang_code=lang, **shared_kwargs)

    _save_translations(data=translated, dataset_path=dataset_path, target_lang_codes=target_lang_codes)


def translate_single_language(
    data: pd.DataFrame, text_column_name: str, deepl_api_key: str, target_lang_code: LanguageCode
):
    """Add a translated copy of one column to ``data``.

    Every cell of ``text_column_name`` is translated and the results are
    stored in a new column named ``<text_column_name>_<target_lang_code.value>``.
    A string cell yields a translated string, a non-empty list cell yields a
    list of translated strings, and anything else (including an empty list)
    yields ``[]``.

    Args:
        data: Frame to translate; the new column is added to it in place.
        text_column_name: Name of the column to translate.
        deepl_api_key: API key handed to the DeepL client.
        target_lang_code: Language to translate into.

    Returns:
        The same ``data`` object with the extra column.

    Raises:
        AssertionError: If ``text_column_name`` is not a column of ``data``.
    """
    assert text_column_name in data.columns, f"Unknown column name {text_column_name} in \n {data.columns}"
    translator = deepl.Translator(deepl_api_key)
    translations = []
    for text in data[text_column_name]:
        if isinstance(text, str):
            # The DeepL client returns a single result object (not a list) for
            # a bare string, which cannot be iterated. Wrap the string in a
            # list and unwrap the single translation.
            translated_texts = translate_texts(
                translator=translator, texts=[text], target_lang_code=target_lang_code.value
            )[0]
        elif isinstance(text, list) and text:
            translated_texts = translate_texts(
                translator=translator, texts=text, target_lang_code=target_lang_code.value
            )
        else:
            # Not translatable, or an empty list: nothing to send to the API.
            translated_texts = []
        translations.append(translated_texts)
    data[f"{text_column_name}_{target_lang_code.value}"] = translations
    return data


def _save_translations(data: pd.DataFrame, dataset_path: Path, target_lang_codes: List[LanguageCode]):
    """Write the frame with all translations as JSONL and as CSV.

    Both files are placed next to ``dataset_path`` and named
    ``<stem>_EN_<code values joined by "_">`` with the ``.jsonl`` and ``.csv``
    suffix respectively.

    Args:
        data: Frame to persist.
        dataset_path: Path of the source dataset file; supplies directory and stem.
        target_lang_codes: Languages whose values are encoded in the file name.
    """
    language_suffix = "_".join(code.value for code in target_lang_codes)
    jsonl_path = dataset_path.parent / f"{dataset_path.stem}_EN_{language_suffix}.jsonl"
    csv_path = jsonl_path.with_suffix(".csv")

    records = data.to_dict(orient="records")
    save_as_json(entries=records, out_path=jsonl_path)
    data.to_csv(path_or_buf=csv_path, index=False)


def translate_texts(translator: deepl.Translator, texts: List[str], target_lang_code: str) -> List[str]:
    """Translate ``texts`` with DeepL and return the translated strings.

    Args:
        translator: DeepL client used for the request.
        texts: Strings to translate. A single bare string is accepted as well
            and is treated as a one-element list.
        target_lang_code: Target language code passed to DeepL as ``target_lang``.

    Returns:
        The translated strings, one per input text. Empty input gives ``[]``
        without contacting the API.
    """
    # For a bare string the DeepL client returns a single result object rather
    # than a list, which cannot be iterated; normalise to a list first.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        return []
    return [
        translation.text
        for translation in translator.translate_text(
            texts,
            target_lang=target_lang_code,
        )
    ]
