import logging
from collections import Counter
from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Dict, Iterable, List

import fasttext
import pandas as pd
import wget
from huggingface_hub import snapshot_download
from tqdm import tqdm

from align.config.app_config import AppConfig, DatasetName
from align.utils import (
    PROJECT_ROOT,
    load_json,
    load_json_as_dataframe,
    load_jsonl_as_dataframe,
    save_as_json,
    stream_jsonl_store_err_lines,
)


def detect_language_all(config: AppConfig, languages: Iterable[str] = ("en", "de", "fr", "it", "es")):
    """Detect languages for every known dataset and print a markdown summary table.

    Runs :func:`detect_language` for each member of ``DatasetName`` and merges all
    returned ``{name: {language_code: count}}`` histograms into one table whose rows are
    datasets (or individual files) and whose columns are the requested languages.

    Args:
        config: Application config; ``config.data.dataset_name_to_path`` maps a dataset
            name value to its on-disk path.
        languages: Language codes (FastText labels without ``__label__``) to report.
            Defaults to the five languages this report has always shown.
    """
    result_per_dataset = {}
    for dataset_name in DatasetName:
        dataset_path = config.data.dataset_name_to_path[dataset_name.value]
        for result in detect_language(dataset_name=dataset_name, dataset_path=dataset_path):
            result_per_dataset.update(result)
    # Index = language codes, columns = dataset / file names.
    # ``reindex`` (unlike ``.loc[list]``) yields NaN rows for languages that were not
    # detected anywhere instead of raising a KeyError.
    data = pd.DataFrame.from_dict(result_per_dataset).reindex(list(languages))
    print(data.T.to_markdown())


def detect_language(dataset_name: DatasetName, dataset_path: Path):
    """Return language-detection histograms for a single dataset.

    Args:
        dataset_name: Dataset to analyse; selects the dataset-specific reader.
        dataset_path: File or directory holding the dataset.

    Returns:
        A list of ``{name: {language_code: count}}`` dicts. Readers that return a single
        histogram are wrapped as ``[{dataset_name.value: histogram}]``; directory readers
        (StackExchange, OIG) already return one entry per file. Datasets without a reader
        yield ``[]``.
    """
    explore_functions = {
        DatasetName.lima: _detect_langauge_lima,
        DatasetName.stackexchange: _detect_langauge_stackexchange,
        DatasetName.sharegpt: _detect_langauge_sharegpt,
        DatasetName.fastchat_oasst: _detect_langauge_oasst,
        DatasetName.oig: _detect_langauge_oig,
        DatasetName.dolly: _detect_langauge_dolly,
        DatasetName.hh_rlhf: _detect_langauge_hh_rlhf,
        DatasetName.alpaca: _detect_langauge_alpaca,
    }
    explore_function = explore_functions.get(dataset_name)
    if explore_function is None:
        return []

    # Load (and, on first use, download) the model only once we know it will be needed.
    lang_detect_model = get_fasttext_model()
    result = load_if_result_available(dataset_path)(explore_function)(lang_detect_model)
    if isinstance(result, dict):
        result = [{dataset_name.value: result}]
    return result


@lru_cache(maxsize=1)
def get_fasttext_model() -> fasttext.FastText:
    """Return the FastText language-identification model ``lid.176.bin``.

    The model file is downloaded into ``PROJECT_ROOT / "models"`` when it is not there
    yet, then loaded with ``fasttext.load_model``. The loaded model is cached for the
    lifetime of the process, so repeated callers (e.g. one call per dataset) share a
    single instance instead of re-reading the binary from disk each time.
    """
    model_filename = "lid.176.bin"
    model_dir_path = PROJECT_ROOT / "models"
    model_path = model_dir_path / model_filename

    if not model_path.exists():
        # exist_ok=True makes a separate "does the directory exist" check unnecessary.
        model_dir_path.mkdir(parents=True, exist_ok=True)
        model_url = f"https://dl.fbaipublicfiles.com/fasttext/supervised-models/{model_filename}"
        logging.info(f"Downloading model from {model_url}")
        wget.download(url=model_url, out=str(model_dir_path))
    return fasttext.load_model(str(model_path))


def load_if_result_available(data_file_path: Path, enable: bool = True):
    """Build a decorator that reuses a stored language-detection result when one exists.

    The decorated function must accept a ``data_file_path`` keyword argument; when the
    function has to run, the wrapper passes ``data_file_path`` to it, overriding any
    value the caller supplied.

    Args:
        data_file_path: Dataset file (or directory) the stored result belongs to.
        enable: When ``False`` any stored result is ignored and the function always runs.

    Returns:
        A decorator. Calling the wrapped function returns ``load_json`` of the stored
        result file if that path exists and is not a directory, otherwise the decorated
        function's own return value.
    """

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            stored_result_path = _get_lang_detect_file_path(data_file_path)
            reuse_stored = enable and stored_result_path.exists() and not stored_result_path.is_dir()
            if reuse_stored:
                return load_json(stored_result_path)
            kwargs["data_file_path"] = data_file_path
            return function(*args, **kwargs)

        return wrapper

    return decorator


def _detect_langauge_alpaca(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect the language of every Alpaca example and store the histogram.

    For each row, the ``text``, ``instruction`` and ``input`` fields are concatenated.
    Newlines and ``###`` markers are replaced by spaces (FastText's ``predict`` expects
    single-line input), and the result is classified.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: Alpaca JSON-lines file.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    df = load_jsonl_as_dataframe(data_file_path=data_file_path)
    texts = (
        (df.text.astype(str) + df.instruction.astype(str) + df.input.astype(str))
        .str.replace("\n", " ", regex=False)
        .str.replace("###", " ", regex=False)
        .tolist()
    )
    language_predictions = _detect_language(
        texts=texts,
        lang_detect_model=lang_detect_model,
    )
    logging.info("Alpaca:")
    return log_and_store_language_predictions(
        language_predictions=language_predictions, original_data_file_path=data_file_path
    )


def _detect_langauge_lima(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Classify each LIMA conversation as a whole and store the language histogram.

    All turns of a conversation are joined with spaces and newlines are flattened, so
    the model sees one line per conversation.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: LIMA JSON-lines file with a ``conversations`` column.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """

    def _as_single_line(conversation):
        return " ".join(conversation).replace("\n", " ")

    lima_data = load_jsonl_as_dataframe(data_file_path=data_file_path)
    predictions = _detect_language(
        texts=map(_as_single_line, lima_data.conversations),
        lang_detect_model=lang_detect_model,
    )
    logging.info("LIMA:")
    return log_and_store_language_predictions(
        language_predictions=predictions,
        original_data_file_path=data_file_path,
    )


def _annotate_lima_dataset(data_file_path: Path):
    """Annotate LIMA conversations with turn counts and FastText language labels.

    Adds three columns:

    * ``num_turns``: number of turns in the conversation.
    * ``lang_FastText``: language predicted for the whole (joined) conversation.
    * ``lang_FastText_sentence``: space-separated ``lang:count`` pairs over the
      per-turn predictions.

    The annotated table is written next to the input as ``<stem>_with_stats.jsonl``
    (via ``save_as_json``) and ``<stem>_with_stats.csv``.
    """
    model = get_fasttext_model()
    lima_data = load_jsonl_as_dataframe(data_file_path=data_file_path)

    def _per_turn_language_summary(turns):
        # Classify every turn separately, then summarise as "lang:count lang:count ...".
        turn_languages = _detect_language(
            texts=[turn.replace("\n", " ") for turn in turns],
            lang_detect_model=model,
        )
        return " ".join(f"{lang}:{count}" for lang, count in Counter(turn_languages).items())

    joined_conversations, turn_counts, turn_language_summaries = [], [], []
    for turns in lima_data.conversations:
        joined_conversations.append(" ".join(turns).replace("\n", " "))
        turn_counts.append(len(turns))
        turn_language_summaries.append(_per_turn_language_summary(turns))

    conversation_languages = _detect_language(
        texts=joined_conversations,
        lang_detect_model=model,
    )

    lima_data["num_turns"] = turn_counts
    lima_data["lang_FastText"] = conversation_languages
    lima_data["lang_FastText_sentence"] = turn_language_summaries

    out_file_path = data_file_path.parent / (data_file_path.stem + "_with_stats.jsonl")
    save_as_json(entries=lima_data.to_dict(orient="records"), out_path=out_file_path)
    lima_data.to_csv(path_or_buf=out_file_path.with_suffix(".csv"), index=False)


def log_and_store_language_predictions(
    language_predictions: List[str], original_data_file_path: Path
) -> Dict[Any, int]:
    """Count predicted languages, log the counts, and persist them as JSON.

    Args:
        language_predictions: One predicted language code per example.
        original_data_file_path: Dataset file the predictions came from; the JSON is
            written to the path ``_get_lang_detect_file_path`` derives from it.

    Returns:
        ``{language_code: count}`` mapping, in ``value_counts`` order.
    """
    histogram = pd.Series(language_predictions).value_counts()
    logging.info(f"\n{histogram}")
    target_path = _get_lang_detect_file_path(original_data_file_path)
    histogram.to_json(target_path)
    return histogram.to_dict()


def _get_lang_detect_file_path(original_data_file_path: Path):
    if not original_data_file_path.exists() or original_data_file_path.is_dir():
        return Path(".")  # default
    return original_data_file_path.parent / (original_data_file_path.stem + "_fasttext_lang_pred.json")


def _detect_language(texts: Iterable[str], lang_detect_model: fasttext.FastText) -> List[str]:
    """Return the top-1 predicted language code for each text, in input order.

    Args:
        texts: Strings to classify; callers pass newline-free text.
        lang_detect_model: Object exposing FastText's ``predict(text, k=1)`` API.

    Returns:
        Language codes with the leading ``__label__`` prefix stripped.
    """
    label_prefix = "__label__"
    codes: List[str] = []
    for text in tqdm(texts, desc="Detect language"):
        top_labels, _ = lang_detect_model.predict(text, k=1)
        codes.append(top_labels[0].rsplit(label_prefix, 1)[-1])
    return codes


def _detect_langauge_stackexchange(lang_detect_model: fasttext.FastText, data_file_path: Path) -> List[Dict]:
    """Detect languages separately for every StackExchange ``*.jsonl`` file.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: Directory containing the StackExchange JSON-lines files.

    Returns:
        One ``{file_stem: {language_code: count}}`` dict per file. A file's stored
        result is reused when available.
    """
    per_file_results = []
    for file_path in data_file_path.glob("*.jsonl"):
        explore_with_cache = load_if_result_available(file_path)(_explore_stackexchange_per_file)
        histogram = explore_with_cache(lang_detect_model=lang_detect_model, data_file_path=file_path)
        per_file_results.append({file_path.stem: histogram})
    return per_file_results


def _explore_stackexchange_per_file(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Classify the ``text`` field of every record in one StackExchange file.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: JSON-lines file with a ``text`` column.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    logging.info(f"Read data {data_file_path}...")
    records = pd.read_json(path_or_buf=str(data_file_path), orient="records", lines=True)
    predictions = _detect_language(
        texts=(record_text.replace("\n", " ") for record_text in records.text),
        lang_detect_model=lang_detect_model,
    )
    logging.info(f"{data_file_path}:")
    return log_and_store_language_predictions(
        language_predictions=predictions,
        original_data_file_path=data_file_path,
    )


def _detect_langauge_sharegpt(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect ShareGPT languages with FastText and log the shipped labels for comparison.

    The data's own ``language_code`` column (logged as "Polyglot Detector") is only
    logged. The FastText histogram is the one that is stored and returned.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: ShareGPT JSON file in FastChat conversation format.

    Returns:
        FastText ``{language_code: count}`` histogram.
    """
    language_predictions, sharegpt_data = _explore_fastchat(lang_detect_model, data_file_path)

    logging.info("ShareGPT Raw:")
    logging.info("-------- Detected by Polyglot Detector:")
    # Log only: storing this histogram would write the same result file that the FastText
    # histogram below overwrites immediately afterwards.
    logging.info(f"\n{pd.Series(sharegpt_data.language_code).value_counts()}")
    logging.info("-------- Detected by FastText:")
    return log_and_store_language_predictions(
        language_predictions=language_predictions, original_data_file_path=data_file_path
    )


def _explore_fastchat(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Run FastText over conversations stored in FastChat format.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: JSON file whose ``conversations`` column holds lists of turns,
            each a dict with a ``value`` entry.

    Returns:
        ``(language_predictions, data)``: one language code per conversation, plus the
        loaded DataFrame.
    """
    fastchat_format_data = load_json_as_dataframe(data_file_path=data_file_path)

    def _conversation_as_line(turns):
        # The newline replacement is needed for the FastText language detector.
        return " ".join(turn["value"] for turn in turns).replace("\n", " ")

    predictions = _detect_language(
        texts=(_conversation_as_line(turns) for turns in fastchat_format_data.conversations),
        lang_detect_model=lang_detect_model,
    )
    return predictions, fastchat_format_data


def _detect_langauge_oasst(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect languages of FastChat-formatted OASST conversations and store the histogram.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: OASST JSON file in FastChat conversation format.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    predictions = _explore_fastchat(lang_detect_model, data_file_path)[0]
    logging.info("-------- Detected by FastText:")
    return log_and_store_language_predictions(
        language_predictions=predictions,
        original_data_file_path=data_file_path,
    )


def _detect_langauge_oig(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect languages for every OIG ``*.jsonl`` shard, downloading the data if needed.

    If ``data_file_path`` contains no entries, the ``laion/OIG`` dataset snapshot is
    first downloaded into it from the Hugging Face Hub.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: Directory holding (or receiving) the OIG JSON-lines shards.

    Returns:
        One ``{shard_stem: {language_code: count}}`` dict per shard. A shard's stored
        result is reused when available.
    """
    directory_is_empty = next(iter(data_file_path.glob("*")), None) is None
    if directory_is_empty:
        snapshot_download(
            repo_id="laion/OIG",
            repo_type="dataset",
            local_dir=data_file_path,
            local_dir_use_symlinks=False,
        )

    per_shard_results = []
    for shard_path in data_file_path.glob("*.jsonl"):
        explore_with_cache = load_if_result_available(shard_path)(_explore_oig_per_file)
        histogram = explore_with_cache(lang_detect_model=lang_detect_model, data_file_path=shard_path)
        per_shard_results.append({shard_path.stem: histogram})
    return per_shard_results


def _explore_oig_per_file(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Classify the ``text`` field of every parsable line in one OIG shard.

    Lines are read through ``stream_jsonl_store_err_lines``. Newlines inside each text
    are replaced with spaces before prediction.

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: OIG JSON-lines shard.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    entries = stream_jsonl_store_err_lines(data_file_path)
    single_line_texts = [entry["text"].replace("\n", " ") for entry in entries]
    predictions = _detect_language(texts=single_line_texts, lang_detect_model=lang_detect_model)
    logging.info(f"{data_file_path}:")
    return log_and_store_language_predictions(
        language_predictions=predictions,
        original_data_file_path=data_file_path,
    )


def _detect_langauge_dolly(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect languages of Dolly examples (single-turn instruction + response).

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: Dolly JSON-lines file with ``instruction`` and ``response`` columns.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    data = load_jsonl_as_dataframe(data_file_path=data_file_path)
    # Fill missing values and cast to str so a NaN field cannot turn the concatenation
    # into a float (which would break ``.replace`` below). Separate the fields with a
    # space so the last instruction word and the first response word are not glued
    # into a single token.
    data["prompt"] = (
        data.instruction.fillna("").astype(str) + " " + data.response.fillna("").astype(str)
    )
    # not multi turn!
    language_predictions = _detect_language(
        texts=[x.replace("\n", " ") for x in data.prompt],
        lang_detect_model=lang_detect_model,
    )
    logging.info("Dolly:")
    return log_and_store_language_predictions(
        language_predictions=language_predictions, original_data_file_path=data_file_path
    )


def _detect_langauge_hh_rlhf(lang_detect_model: fasttext.FastText, data_file_path: Path):
    """Detect languages of HH-RLHF examples (prompt + response, treated as one text).

    Args:
        lang_detect_model: Loaded FastText language-identification model.
        data_file_path: HH-RLHF JSON-lines file with ``prompt`` and ``response`` columns.

    Returns:
        ``{language_code: count}`` histogram from ``log_and_store_language_predictions``.
    """
    data = load_jsonl_as_dataframe(data_file_path=data_file_path)
    # Fill missing values and cast to str so a NaN field cannot turn the concatenation
    # into a float (which would break ``.replace`` below). Separate the fields with a
    # space so the prompt's last word and the response's first word stay separate tokens.
    data["prompt"] = data.prompt.fillna("").astype(str) + " " + data.response.fillna("").astype(str)
    # not multi turn!
    language_predictions = _detect_language(
        texts=[x.replace("\n", " ") for x in data.prompt],
        lang_detect_model=lang_detect_model,
    )
    logging.info("HH-RLHF:")
    return log_and_store_language_predictions(
        language_predictions=language_predictions, original_data_file_path=data_file_path
    )
