import csv
import logging
from pathlib import Path
from typing import List

import fasttext
import numpy as np
import pandas as pd
from align.config.app_config import LanguageCode
from align.exploration.format_to_fastchat import format_lima_to_fastchat
from align.exploration.language_detection import _detect_language, get_fasttext_model
from align.preprocess.near_duplicates import get_near_duplicates
from align.utils import load_jsonl_as_dataframe, save_as_json


def create_multi_lingual_stackexchange_val_sets(
    dataset_path: Path,
    train_dataset_path: Path,
    max_num_samples_per_source: int,
):
    """Build an English StackExchange validation set that does not overlap with the training data.

    Every ``*.jsonl`` file directly inside ``dataset_path`` is read. Its source name is the part of the
    file stem before the first ``_``. For each file, the first answer of every question is kept. Answers
    containing unwanted phrases or with an out-of-range length are removed, and only rows whose title and
    answer are both detected as English (``"en"``) remain. The combined rows are then checked for near
    duplicates against the ``stackexchange`` part of the training data. The ``max_num_samples_per_source``
    highest-scored rows per source are written to ``dataset_path / "val_dataset" / "val_dataset_EN.csv"``.

    Args:
        dataset_path: Directory containing the per-source StackExchange ``*.jsonl`` dumps.
        train_dataset_path: JSONL training data with ``source`` and ``conversations`` columns.
        max_num_samples_per_source: Maximum number of validation rows kept per source.

    Raises:
        ValueError: If ``dataset_path`` contains no ``*.jsonl`` files.
    """
    lang_detect_model = get_fasttext_model()
    train_data = load_jsonl_as_dataframe(data_file_path=train_dataset_path)
    # One flat string per training conversation, used as near-duplicate queries.
    train_data_entries = train_data[train_data.source == "stackexchange"].conversations.apply(lambda x: " ".join(x))

    frames = []
    # Sorted: glob order depends on the filesystem; sorting keeps runs reproducible.
    for data_file_path in sorted(dataset_path.glob("*.jsonl")):
        data = pd.read_json(path_or_buf=str(data_file_path), orient="records", lines=True)
        data["source"] = data_file_path.stem.split("_")[0]
        data = _get_best_answer(data)
        data = _filter_by_phrases(data)
        data = _filter_by_length(data)
        data = _filter_by_consistent_language(data=data, lang_detect_model=lang_detect_model)
        frames.append(data)
    if not frames:
        raise ValueError(f"No *.jsonl files found in {dataset_path}")

    # ignore_index=True: every file restarts its index at 0. Without a reset, label-based row removal
    # downstream would also delete rows of other files that merely share an index label.
    dataset = pd.concat(frames, ignore_index=True)
    dataset = _filter_by_train_data(data=dataset, train_data_entries=train_data_entries)
    dataset = (
        dataset.sort_values(["source", "score"], ascending=False).groupby("source").head(max_num_samples_per_source)
    )
    logging.info(f"Got {len(dataset)} val samples.")

    # Fixed column order for the CSV; missing columns would be filled with NaN by reindex.
    dataset = dataset.reindex(["source", "title", "answer", "body", "score", "tags", "language"], axis=1)
    out_file_path = dataset_path / "val_dataset" / "val_dataset_EN.csv"
    out_file_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_csv(path_or_buf=out_file_path, index=False, quoting=csv.QUOTE_ALL)


def _get_best_answer(data: pd.DataFrame) -> pd.DataFrame:
    """Replace the ``answers`` list column by an ``answer`` column with the first answer's text.

    The first element of each ``answers`` list is taken as the best answer (presumably the export lists
    answers best-first — TODO confirm). ``data`` is modified in place and returned for chaining.
    """
    first_answer_texts = [answer_list[0]["text"] for answer_list in data["answers"]]
    data["answer"] = first_answer_texts
    del data["answers"]
    return data


def _filter_by_phrases(data: pd.DataFrame) -> pd.DataFrame:
    """Drop rows whose ``answer`` contains phrases that make it unsuitable as a standalone answer.

    The excluded phrases cover first-person wording, references back to the site or to earlier
    discussion, citations of external resources, and reliance on images, figures or diagrams.
    ``"I "`` is matched case-sensitively (the uppercase pronoun only). All other phrases are matched
    against the lower-cased answer. Missing (NaN) answers count as not matching.

    Returns:
        A new, filtered DataFrame; ``data`` is not modified.
    """
    case_insensitive_phrases = (
        " my ",
        "as mentioned",
        "stack exchange",
        "referenced",
        "resources",
        "sources",
        "papers",
        "image",
        "figure",
        "diagram",
    )
    answers = data.answer
    lowered_answers = answers.str.lower()  # lower-case once instead of once per phrase
    is_excluded = answers.str.contains("I ", regex=False, na=False).to_numpy(dtype=bool)
    for phrase in case_insensitive_phrases:
        is_excluded = is_excluded | lowered_answers.str.contains(phrase, regex=False, na=False).to_numpy(dtype=bool)
    # Positional mask instead of a label-based drop(): stays correct even if index labels repeat.
    return data[~is_excluded].copy()


def _filter_by_length(data: pd.DataFrame, min_length: int = 1200, max_length: int = 4096) -> pd.DataFrame:
    """Keep rows whose combined ``title`` + ``answer`` length lies in ``[min_length, max_length]``.

    The length counts characters of the two fields concatenated without a separator. Rows whose
    combined length cannot be computed (NaN title or answer) are kept, because both range comparisons
    evaluate to False for them.

    Args:
        data: Rows with ``title`` and ``answer`` text columns.
        min_length: Smallest accepted combined length (inclusive).
        max_length: Largest accepted combined length (inclusive).

    Returns:
        A new, filtered DataFrame; ``data`` is not modified.
    """
    qa_length = (data.title + data.answer).str.len()
    is_out_of_range = (qa_length > max_length) | (qa_length < min_length)
    # Positional mask instead of a label-based drop(): stays correct even if index labels repeat.
    return data[~is_out_of_range.to_numpy(dtype=bool)].copy()


def _filter_by_consistent_language(
    data: pd.DataFrame, lang_detect_model: fasttext.FastText, target_language: str = "en"
) -> pd.DataFrame:
    """Keep rows whose title and answer are detected as the same language, equal to ``target_language``.

    Newlines are replaced by spaces before detection. NOTE(review): presumably because the fastText
    predictor does not accept multi-line input — confirm. The result gains a ``language`` column holding
    the detected language.

    Args:
        data: Rows with ``title`` and ``answer`` text columns.
        lang_detect_model: Model passed through to ``_detect_language``.
        target_language: Language label to keep (compared against ``_detect_language`` output).

    Returns:
        A new, filtered DataFrame; ``data`` is not modified.
    """
    data = data.copy()  # the helper columns below must not leak into the caller's frame
    data["title_language"] = _detect_language(
        texts=(text.replace("\n", " ") for text in data.title),
        lang_detect_model=lang_detect_model,
    )
    data["answer_language"] = _detect_language(
        texts=(text.replace("\n", " ") for text in data.answer),
        lang_detect_model=lang_detect_model,
    )
    keep = (data.answer_language == data.title_language) & (data.title_language == target_language)
    # Positional mask instead of a label-based drop(): stays correct even if index labels repeat.
    data = data[keep.to_numpy(dtype=bool)].copy()
    data["language"] = data.title_language
    return data.drop(columns=["title_language", "answer_language"])


def _filter_by_train_data(data: pd.DataFrame, train_data_entries: List[str]) -> pd.DataFrame:
    """Remove validation rows that are near duplicates of training entries.

    Each row is checked in two renderings, ``title + " " + answer`` and ``body + " " + answer``. A row
    is removed when either rendering is flagged by ``_get_is_near_duplicate``. The number of removed
    rows per ``source`` is logged.

    Args:
        data: Validation rows with ``title``, ``body``, ``answer`` and ``source`` columns.
        train_data_entries: Training texts to compare against (the caller in this module passes a
            pandas Series of strings).

    Returns:
        A new DataFrame without the near-duplicate rows; ``data`` is not modified.

    Raises:
        ValueError: If the number of duplicate flags differs from the number of rows in ``data``.
    """
    title_answer_texts = (data.title + " " + data.answer).to_list()
    body_answer_texts = (data.body + " " + data.answer).to_list()
    flags = _get_is_near_duplicate(
        val_data=title_answer_texts, train_data_entries=train_data_entries
    ) | _get_is_near_duplicate(val_data=body_answer_texts, train_data_entries=train_data_entries)
    # NOTE(review): flags are matched to rows by position, which assumes _get_is_near_duplicate
    # returns them in val_data order — confirm.
    is_near_duplicate = np.asarray(flags.to_list(), dtype=bool)
    if len(is_near_duplicate) != len(data):
        raise ValueError(
            f"Got {len(is_near_duplicate)} near-duplicate flags for {len(data)} validation rows."
        )

    per_source = data.assign(is_near_duplicate=is_near_duplicate).groupby("source")["is_near_duplicate"].sum()
    logging.info(f"Num duplicates found per source:\n{per_source}")
    # Positional mask instead of a label-based drop(): a drop by label would also remove unrelated
    # rows that share an index label with a duplicate (e.g. after concatenating per-file frames).
    return data[~is_near_duplicate].copy()


def _get_is_near_duplicate(val_data: List[str], train_data_entries: List[str]) -> pd.Series:
    """Flag which ``val_data`` texts have a near duplicate among ``train_data_entries``.

    ``get_near_duplicates`` is called with the validation texts as documents (label prefix ``"val"``)
    and the training texts as queries (label prefix ``"train"``). Rows of its result whose index label
    contains ``"val"`` followed by a digit are kept. A row is flagged when its values sum to more than zero.

    Returns:
        Boolean Series indexed by the selected validation labels.
    """
    logging.info("Detect near duplicates...")
    near_duplicates = get_near_duplicates(
        docs=val_data, docs_label_prefix="val", queries=train_data_entries, queries_label_prefix="train"
    )
    # Raw string: "\d" in a plain literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    # NOTE(review): callers use the result positionally, so the row order is assumed to follow
    # val_data — confirm get_near_duplicates preserves it.
    val_rows = near_duplicates[near_duplicates.index.str.contains(r"val\d")]
    return val_rows.sum(axis=1) > 0


def create_multi_lingual_lima_datasets_for_fastchat(
    dataset_path: Path,
    languages_per_dataset: List[List[LanguageCode]],
    column_prefix: str,
):
    """Write LIMA-style data in FastChat format, one set of output files per language combination.

    The input JSONL at ``dataset_path`` holds one text column per language, named
    ``f"{column_prefix}_{code}"``. The column named exactly ``column_prefix`` is renamed with the
    ``LanguageCode.EN`` suffix, i.e. it is treated as the English version.

    For every entry of ``languages_per_dataset``, the per-language copies of all examples are
    concatenated. They are written to
    ``dataset_path.parent / "fastchat_dataset" / f"{stem}_fastchat_format_{codes}.jsonl"``, where
    ``codes`` joins the language code values with ``_``. If the entry has more than one language, a
    second file with suffix ``_sampled`` is also written. It contains every example exactly once, in one
    of the languages, with languages assigned round-robin so they get (almost) equal shares. That
    assignment uses the global NumPy RNG; seed ``np.random`` for reproducible output.

    Args:
        dataset_path: JSONL file with the multi-language LIMA data.
        languages_per_dataset: Language combinations; one set of output files per combination.
        column_prefix: Name of the text column (without language suffix) that FastChat should see.

    Raises:
        AssertionError: If a requested language column is missing from the input.
    """
    lang_key = "language"
    data = load_jsonl_as_dataframe(data_file_path=dataset_path)
    data.rename(columns={column_prefix: f"{column_prefix}_{LanguageCode.EN.value}"}, inplace=True)
    out_dir = dataset_path.parent / "fastchat_dataset"
    for language_codes in languages_per_dataset:
        # The per-language frames keep the original example index, so every label occurs once per language.
        dataset = pd.concat(
            [
                _select_language_columns(
                    data=data, column_prefix=column_prefix, language_code=language_code, lang_key=lang_key
                )
                for language_code in language_codes
            ]
        )

        dataset_variants = []
        if len(language_codes) > 1:
            sampled = _sample_one_language_per_example(
                dataset=dataset, example_index=data.index, num_languages=len(language_codes)
            )
            logging.info(sampled[lang_key].value_counts())
            dataset_variants.append((sampled, "_sampled"))
        dataset_variants.append((dataset, ""))

        codes_part = "_".join(language_code.value for language_code in language_codes)
        for dataset_variant, full_data_postfix in dataset_variants:
            fastchat_entries = format_lima_to_fastchat(
                data=dataset_variant.to_dict(orient="records"),
                end_conversation_with_bot_response=True,
                extra_fileds_to_include=[lang_key],
            )
            out_file_path = out_dir / f"{dataset_path.stem}_fastchat_format_{codes_part}{full_data_postfix}.jsonl"
            out_file_path.parent.mkdir(parents=True, exist_ok=True)
            save_as_json(entries=fastchat_entries, out_path=out_file_path)


def _select_language_columns(
    data: pd.DataFrame, column_prefix: str, language_code: LanguageCode, lang_key: str
) -> pd.DataFrame:
    """Return a copy of ``data`` with only ``language_code``'s text column, renamed to ``column_prefix``.

    Text columns of all other ``LanguageCode`` members are dropped. A ``lang_key`` column with the code
    value and an ``id`` column of the form ``id_{index}_{code}`` are added.
    """
    column_name = f"{column_prefix}_{language_code.value}"
    assert column_name in data.columns, f"Column name {column_name} not in data columns: {data.columns}"
    column_names_to_exclude = [
        f"{column_prefix}_{lang_code.value}" for lang_code in LanguageCode if lang_code != language_code
    ]
    selected = data.loc[:, ~data.columns.isin(column_names_to_exclude)].copy()
    # FastChat accepts only a single conversation attribute, so the language suffix is removed.
    selected.rename(columns={column_name: column_prefix}, inplace=True)
    selected[lang_key] = language_code.value
    selected["id"] = [f"id_{example_idx}_{language_code.value}" for example_idx in selected.index]
    return selected


def _sample_one_language_per_example(
    dataset: pd.DataFrame, example_index: pd.Index, num_languages: int
) -> pd.DataFrame:
    """Pick one language version of every example, assigning languages round-robin.

    ``dataset`` must be the concatenation of the per-language frames in language order, each keeping
    the original example index. Then the k-th row carrying a given label is the k-th language's
    version. The example at position ``i`` of ``example_index`` gets language slot
    ``(i + offset) % num_languages``. ``offset`` is drawn from the global NumPy RNG so that, when the
    example count is not a multiple of ``num_languages``, the extra examples do not always go to the
    first languages. Sampling each example independently would not give equal shares with this few
    examples.

    Returns:
        One row per example, ordered by index label.
    """
    offset = np.random.randint(0, num_languages)
    chosen_slot = pd.Series((np.arange(len(example_index)) + offset) % num_languages, index=example_index)
    # cumcount numbers the rows sharing a label 0, 1, ... in frame order, i.e. by language position.
    slot_of_row = dataset.groupby(level=0).cumcount().to_numpy()
    is_chosen = slot_of_row == chosen_slot.reindex(dataset.index).to_numpy()
    return dataset[is_chosen].sort_index(kind="stable")
