from collections import defaultdict
from functools import wraps
from pathlib import Path

import pandas as pd

from align.config.app_config import AppConfig, DatasetName
from align.utils import (
    load_jsonl_as_dataframe,
)


def compute_datasets_stats(
    config: AppConfig,
    *,
    use_cache: bool = True,
    output_path: Path = Path("data/dataset_stats.md"),
):
    """Compute length statistics for every supported dataset and report them.

    For each dataset the statistics are computed (or loaded from the per-dataset
    CSV cache) and merged into one table with one column per dataset/split.
    The table is printed, printed again transposed as LaTeX, and written
    transposed as Markdown to *output_path*.

    Args:
        config: Application config; ``config.data.dataset_name_to_path`` maps
            each dataset name to its file or directory.
        use_cache: Reuse previously saved per-dataset statistics when present
            (forwarded as ``enable`` to ``load_if_result_available``).
        output_path: Destination of the Markdown table; missing parent
            directories are created.
    """
    stats = {}
    for dataset_name in [
        DatasetName.lima,
        DatasetName.oasst,
        DatasetName.bactrianx,
        DatasetName.xP3mt,
        DatasetName.hh_rlhf,
        DatasetName.nectar,
    ]:
        dataset_path = config.data.dataset_name_to_path[dataset_name]
        compute_fun = get_compute_fun(dataset_name)
        dataset_stats = load_if_result_available(dataset_path, enable=use_cache)(
            compute_fun
        )()
        # Each column of the per-dataset frame is one dataset/split; collect the
        # columns by name so they can be combined into a single table.
        for column_name, column in dataset_stats.items():
            stats[column_name] = column

    stats = pd.DataFrame.from_dict(stats)
    print(stats)
    print(stats.T.to_latex())
    output_path = Path(output_path)
    # to_markdown raises FileNotFoundError if the target directory is missing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    stats.T.to_markdown(output_path, floatfmt=".2f")


def get_save_file_path(dataset_path: Path):
    """Return the CSV path where statistics for *dataset_path* are cached.

    A directory dataset keeps its cache file inside itself; a file dataset
    keeps it next to the file. The file name is ``<stem>_mean_num_words.csv``.
    """
    if dataset_path.is_dir():
        parent_dir = dataset_path
    else:
        parent_dir = dataset_path.parent
    return parent_dir / (dataset_path.stem + "_mean_num_words.csv")


def get_compute_fun(dataset_name: DatasetName):
    match dataset_name:
        case DatasetName.lima:
            return compute_lima_stats
        case DatasetName.oasst:
            return compute_oasst_stats
        case DatasetName.bactrianx:
            return compute_bactrianx_stats
        case DatasetName.xP3mt:
            return compute_xP3mt_stats
        case DatasetName.hh_rlhf:
            return compute_hh_rlhf_stats
        case DatasetName.nectar:
            return compute_nectar_stats


def load_if_result_available(data_file_path: Path, enable: bool = True):
    """Build a decorator that caches a statistics function's result as CSV.

    The decorated function is called with ``data_file_path`` appended as its
    last positional argument. It must return a mapping accepted by
    ``pd.DataFrame.from_dict``. The result is written with ``save`` to the
    path given by ``get_save_file_path(data_file_path)``.

    Args:
        data_file_path: Dataset file or directory. It is passed to the wrapped
            function and determines where the cache file lives.
        enable: If True, reuse an existing cache file instead of recomputing.
            If False, always recompute and overwrite the cache file.

    Returns:
        A decorator. The wrapped function returns a ``pd.DataFrame``.
    """

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            stats_file_path = get_save_file_path(data_file_path)
            if enable and stats_file_path.exists():
                return pd.read_csv(stats_file_path, index_col=0)
            stats = pd.DataFrame.from_dict(function(*args, data_file_path, **kwargs))
            save(stats, stats_file_path)
            # save() rounds values before writing. Reload the file so a fresh
            # computation returns exactly what a later cache hit would, keeping
            # printed and exported tables identical across runs.
            return pd.read_csv(stats_file_path, index_col=0)

        return wrapper

    return decorator


def save(stats: pd.DataFrame, stats_file_path: Path):
    """Write *stats* to *stats_file_path* as CSV, rounded to two decimals.

    The index is written as the first column so the file can be reloaded with
    ``pd.read_csv(stats_file_path, index_col=0)``.
    """
    # Path objects are no longer context managers (deprecated since Python 3.9,
    # removed in 3.13), so pass the path straight to pandas.
    stats.round(2).to_csv(stats_file_path, index=True)


def get_mean_num_words(turns):
    """Return the mean number of words per text in *turns*.

    Words are split on runs of any whitespace (spaces, tabs, newlines), so
    an empty string counts as zero words. Splitting on a single space would
    ignore line breaks and count consecutive spaces as extra words.

    Returns:
        The average word count as a float, or NaN when *turns* is empty
        (instead of raising ZeroDivisionError).
    """
    num_texts = len(turns)
    if num_texts == 0:
        return float("nan")
    return sum(len(text.split()) for text in turns) / num_texts

def get_num_chars(turns):
    """Return the total number of characters across all texts in *turns*."""
    total = 0
    for text in turns:
        total += len(text)
    return total


def compute_xP3mt_stats(data_dir_path: Path):
    """Compute per-language request/response length statistics for xP3mt.

    Each ``*.jsonl`` file in *data_dir_path* is one language split. The
    language code is the last ``_``-separated token of the file stem.
    Requests come from the ``inputs`` column and responses from ``targets``.
    Missing values are replaced with empty strings.

    Args:
        data_dir_path: Directory containing the xP3mt JSONL files.

    Returns:
        Mapping ``"xP3mt-<lang>"`` -> dict with mean word counts and total
        character counts for requests and responses.
    """
    stats = {}
    # Path.glob yields entries in filesystem order. Sort them so the resulting
    # table columns are deterministic across machines and runs.
    for data_file_path in sorted(data_dir_path.glob("*.jsonl")):
        lang = data_file_path.stem.split("_")[-1]
        data = load_jsonl_as_dataframe(data_file_path).fillna("")
        requests = data.inputs.tolist()
        responses = data.targets.tolist()

        stats[f"xP3mt-{lang}"] = dict(
            mean_num_words_requests=get_mean_num_words(requests),
            mean_num_words_responses=get_mean_num_words(responses),
            num_chars_requests=get_num_chars(requests),
            num_chars_responses=get_num_chars(responses),
        )
    return stats


def compute_bactrianx_stats(data_dir_path: Path):
    """Compute per-language request/response length statistics for Bactrian-X.

    Each ``*.jsonl`` file in *data_dir_path* is one language split. The
    language code is the last ``_``-separated token of the file stem.
    A request is the ``instruction`` column followed by the ``input`` column,
    joined by a single space when ``input`` is non-empty. The response is the
    ``output`` column. Missing values are replaced with empty strings.

    Args:
        data_dir_path: Directory containing the Bactrian-X JSONL files.

    Returns:
        Mapping ``"bactrian-x-<lang>"`` -> dict with mean word counts and total
        character counts for requests and responses.
    """
    stats = {}
    # Sort so the output column order does not depend on filesystem ordering.
    for data_file_path in sorted(data_dir_path.glob("*.jsonl")):
        lang = data_file_path.stem.split("_")[-1]
        data = load_jsonl_as_dataframe(data_file_path).fillna("")
        # Join with a space. Plain concatenation glued the last instruction word
        # to the first input word, under-counting words in every request.
        requests = [
            f"{instruction} {extra}" if extra else instruction
            for instruction, extra in zip(
                data.instruction.astype(str), data.input.astype(str)
            )
        ]
        responses = data.output.tolist()

        stats[f"bactrian-x-{lang}"] = dict(
            mean_num_words_requests=get_mean_num_words(requests),
            mean_num_words_responses=get_mean_num_words(responses),
            num_chars_requests=get_num_chars(requests),
            num_chars_responses=get_num_chars(responses),
        )
    return stats


def compute_hh_rlhf_stats(data_file_path: Path):
    """Compute request/response length statistics for the HH-RLHF JSONL file.

    Requests come from the ``prompt`` column and responses from ``response``.
    Missing values are replaced with empty strings.

    Returns:
        ``{"hh_rlhf": {...}}`` with mean word counts and total character
        counts for requests and responses.
    """
    frame = load_jsonl_as_dataframe(data_file_path).fillna("")
    prompts = frame.prompt.tolist()
    answers = frame.response.tolist()

    summary = {
        "mean_num_words_requests": get_mean_num_words(prompts),
        "mean_num_words_responses": get_mean_num_words(answers),
        "num_chars_requests": get_num_chars(prompts),
        "num_chars_responses": get_num_chars(answers),
    }
    return {"hh_rlhf": summary}


def compute_lima_stats(data_file_path: Path):
    """Compute request/response length statistics for the LIMA JSONL file.

    Each row's ``conversations`` entry is a sequence of turns. Turns at even
    positions (0, 2, ...) count as requests and turns at odd positions count
    as responses.

    Returns:
        ``{"lima": {...}}`` with mean word counts and total character counts
        for requests and responses.
    """
    frame = load_jsonl_as_dataframe(data_file_path)
    request_turns = []
    response_turns = []
    for conversation in frame.conversations.tolist():
        request_turns.extend(conversation[0::2])
        response_turns.extend(conversation[1::2])

    summary = {
        "mean_num_words_requests": get_mean_num_words(request_turns),
        "mean_num_words_responses": get_mean_num_words(response_turns),
        "num_chars_requests": get_num_chars(request_turns),
        "num_chars_responses": get_num_chars(response_turns),
    }
    return {"lima": summary}


def compute_oasst_stats(data_file_path: Path):
    """Compute request/response length statistics for the OASST JSONL file.

    Each row's ``conversations`` entry is a sequence of turns. Even-indexed
    turns are treated as requests and odd-indexed turns as responses.

    Returns:
        ``{"oasst": {...}}`` with mean word counts and total character counts
        for requests and responses.
    """
    conversations = load_jsonl_as_dataframe(data_file_path).conversations.tolist()
    asked = [turn for conv in conversations for turn in conv[::2]]
    answered = [turn for conv in conversations for turn in conv[1::2]]

    return {
        "oasst": {
            "mean_num_words_requests": get_mean_num_words(asked),
            "mean_num_words_responses": get_mean_num_words(answered),
            "num_chars_requests": get_num_chars(asked),
            "num_chars_responses": get_num_chars(answered),
        }
    }

def compute_nectar_stats(data_file_path: Path):
    """Compute per-model request/response length statistics for Nectar.

    Request statistics come from the ``prompt`` column and are shared by every
    model. Responses are grouped by the ``model`` field of each item in a
    row's ``answers`` list, using that item's ``answer`` text.

    Returns:
        Mapping ``"nectar_<model>"`` -> dict with the shared request statistics
        plus that model's mean response word count and total response
        character count.
    """
    frame = load_jsonl_as_dataframe(data_file_path)
    prompts = list(frame.prompt.tolist())
    shared_request_stats = {
        "mean_num_words_requests": get_mean_num_words(prompts),
        "num_chars_requests": get_num_chars(prompts),
    }

    answers_by_model = defaultdict(list)
    for answer_list in frame.answers.tolist():
        for entry in answer_list:
            answers_by_model[entry["model"]].append(entry["answer"])

    return {
        f"nectar_{model}": {
            **shared_request_stats,
            "mean_num_words_responses": get_mean_num_words(model_answers),
            "num_chars_responses": get_num_chars(model_answers),
        }
        for model, model_answers in answers_by_model.items()
    }
    