from pathlib import Path
from typing import Dict, List
from matplotlib.transforms import Bbox

import numpy as np
import pandas as pd
import seaborn as sns
from fastchat.llm_judge.visualisation.render_win_tie_bar_chart import get_rename_map
from matplotlib import pyplot as plt


def get_lang(model_name):
    """Return the fine-tuning language tag encoded in ``model_name``.

    Names containing ``"pre-train"`` carry no language suffix: the exact name
    ``"24EU-1T-pre-train"`` maps to ``"24EU-pre-train"`` and every other such
    name maps to ``"pre-train"``.  For all remaining names the tag is the text
    after the last ``"-"`` (the whole name when it contains no ``"-"``).
    """
    if "pre-train" not in model_name:
        # rpartition returns the full string in the last slot when "-" is absent.
        return model_name.rpartition("-")[2]
    return "24EU-pre-train" if model_name == "24EU-1T-pre-train" else "pre-train"


def get_dataset(model_name):
    """Return the fine-tuning dataset label encoded in ``model_name``.

    * names containing ``"Mitral"`` map to ``"Mistral"``;
    * ``"24EU-1T-pre-train"`` maps to ``"24EU-1T"`` and any other name
      containing ``"pre-train"`` maps to ``"Mulima"``;
    * otherwise the label is everything before the last ``"-"`` (the whole
      name when it contains no ``"-"``).
    """
    # NOTE(review): the substring checked is spelled "Mitral" - confirm this
    # matches the model names actually present in the data.
    if "Mitral" in model_name:
        return "Mistral"
    if "pre-train" in model_name:
        return "24EU-1T" if model_name == "24EU-1T-pre-train" else "Mulima"
    prefix, separator, _ = model_name.rpartition("-")
    # Without a "-" rpartition leaves the prefix empty; fall back to the name.
    return prefix if separator else model_name


def read_data() -> pd.DataFrame:
    """Load the GPT-4 single-answer judgments of all MT-Bench-X languages.

    For each of EN, DE, FR, IT and ES the file
    ``../data/mt_bench_<LANG>/model_judgment/gpt-4_single.jsonl`` (relative to
    this script) is read as JSON lines and tagged with a ``bench_lang``
    column.  Missing files are reported on stdout and skipped.  Model names
    are then rewritten with the map from ``get_rename_map`` and scores below
    1 are raised to 1.

    Returns:
        The concatenated judgments; each file keeps its own row index.

    Raises:
        FileNotFoundError: if none of the judgment files exists.
    """
    data_root = Path(__file__).parent / ".." / "data"
    frames = []
    # get data in for all MT-Bench-X
    for lang in "EN DE FR IT ES".split():
        mt_bench_file_path = data_root / f"mt_bench_{lang}" / "model_judgment" / "gpt-4_single.jsonl"
        if not mt_bench_file_path.exists():
            print(f"Does not exist: {mt_bench_file_path.absolute()}")
            continue
        df = pd.read_json(mt_bench_file_path, lines=True, orient="records")
        df["bench_lang"] = lang
        frames.append(df)

    if not frames:
        # Fail with a clear message instead of an AttributeError on the
        # missing "score" column further down.
        raise FileNotFoundError(f"No MT-Bench-X judgment file found below {data_root.absolute()}")

    # A single concat avoids re-copying the accumulated rows per language.
    data = pd.concat(frames)

    # rename model names
    rename_map = get_rename_map()
    for k, v in rename_map.items():
        data.replace(k, v, inplace=True)
    # fix GPT-4 scoring of 0 (range should be between 1 and 10)
    data["score"] = data["score"].apply(lambda x: max(1, x))
    return data


def prep_data(data, do_turn_wise = True):
    """Aggregate raw judgments into mean scores per fine-tuning setup.

    The model name is split into ``fine_tune_lang`` and ``fine_tune_data``
    (via ``get_lang`` / ``get_dataset``), the score is averaged per model,
    benchmark language and - when ``do_turn_wise`` is true - turn, and only
    the ``24EU-1T`` Bactrian-X, Mulima and base rows are kept.  The last five
    remaining rows are relabeled as Bactrian-X and appended once more as
    Mulima.  Finally the ``24EU-1T-`` prefix is stripped and ``Mulima`` is
    renamed to ``Lima-X``.

    Args:
        data: Judgments with ``model``, ``score`` and ``bench_lang`` columns
            (plus ``turn`` when ``do_turn_wise`` is true).  Not modified.
        do_turn_wise: Keep ``turn`` as an additional grouping column.

    Returns:
        A frame with ``fine_tune_lang``, ``bench_lang``, ``fine_tune_data``,
        optionally ``turn``, and the mean ``score``.
    """
    turn = ["turn"] if do_turn_wise else []
    # Explicit copies throughout: the assignments below must not target a
    # slice of another frame (chained assignment).
    data = data[["model", "score", "bench_lang"] + turn].copy()
    data["fine_tune_lang"] = data["model"].apply(get_lang)
    data["fine_tune_data"] = data["model"].apply(get_dataset)
    group_keys = ["model", "fine_tune_lang", "bench_lang", "fine_tune_data"] + turn
    data = data.groupby(group_keys).score.mean().to_frame().reset_index()
    data = data.drop(columns="model")
    kept_datasets = ["24EU-1T-Bactrian-X", "24EU-1T-Mulima", "24EU-1T"]
    data = data[data["fine_tune_data"].isin(kept_datasets)].copy()

    # NOTE(review): the last five rows are presumably the pre-trained
    # baseline (one row per benchmark language), which is shown with both
    # datasets - confirm; with do_turn_wise=True that baseline may span more
    # than five rows.
    data.loc[data.index[-5:], "fine_tune_data"] = "24EU-1T-Bactrian-X"
    data = pd.concat([data, data.iloc[-5:]], axis=0, ignore_index=True)
    data.loc[data.index[-5:], "fine_tune_data"] = "24EU-1T-Mulima"

    data["fine_tune_data"] = data["fine_tune_data"].str.replace("24EU-1T-", "")
    data = data.reset_index(drop=True)
    # Rows still labeled "24EU-1T" (no dataset suffix) are dropped here.
    data = data[data["fine_tune_data"].isin(["Mulima", "Bactrian-X"])].copy()
    data["fine_tune_data"] = data["fine_tune_data"].replace("Mulima", "Lima-X")
    return data

def get_colors():
    """Return the plot palette: label -> RGB tuple with channels in [0, 1].

    The colours are specified as 8-bit RGB triples and divided by 255.
    """
    rgb_255 = {
        "sampled": (52, 204, 240),
        "ENDEFRITES": (15, 134, 105),
        "DE": (212, 48, 31),
        "ES": (69, 173, 42),
        "IT": (255, 176, 0),
        "FR": (254, 97, 0),
        "EN": (120, 94, 240),
        "pre-train": (0, 0, 0),
    }
    palette = {}
    for label, channels in rgb_255.items():
        palette[label] = tuple(channel / 255 for channel in channels)
    return palette

def viz_catplot(data, is_mulima, do_turn_wise = True, ) -> sns.FacetGrid:
    """Draw one bar-chart panel per benchmark language for a single dataset.

    Bars are coloured by fine-tuning language; rows whose ``fine_tune_lang``
    is ``"24EU-pre-train"`` are excluded.

    Args:
        data: Frame with ``score``, ``bench_lang`` and ``fine_tune_lang``
            columns (plus ``turn`` when ``do_turn_wise`` is true).
        is_mulima: Selects the layout variant: when true the panel titles are
            blanked, the y ticks are trimmed, the x ticks are relabeled and
            the hatched legend is kept; when false the panels are titled
            ``MT-Bench-<lang>`` and the legend is removed.
        do_turn_wise: Put ``turn`` on the x axis.

    Returns:
        The seaborn grid holding the figure.
    """
    sns.reset_defaults()
    sns.set_style("whitegrid")
    # No plt.figure() here: catplot creates its own figure, so an extra call
    # would only leave an empty figure open.
    extra_kwargs = {"x": "turn"} if do_turn_wise else {}
    data = data[data.fine_tune_lang != "24EU-pre-train"]
    g = sns.catplot(
        data,
        kind="bar",
        hue="fine_tune_lang",
        y="score",
        col="bench_lang",
        height=2.5,
        aspect=0.7,
        sharey=True,
        sharex=False,
        margin_titles=True,
        legend=True,
        hue_order=["EN", "DE", "FR", "IT", "ES", "sampled", "ENDEFRITES"],
        col_order="EN DE FR IT ES".split(),
        palette=get_colors(),
        **extra_kwargs
    )
    g.despine(left=True)
    g.set_axis_labels("Turn", "Score")
    if is_mulima:
        g.set_titles(col_template="", row_template="")
    else:
        g.set_titles(col_template="MT-Bench-{col_name}", row_template="{row_name}")
    g.set(ylim=(1, 5), yticks=np.linspace(1,5,9))
    if is_mulima:
        first_ax = g.axes[0][0]
        # Order matters: label all nine ticks first, then keep only the five
        # lowest tick positions (1.0 to 3.0).
        first_ax.set_yticklabels([1.0,1.5,2.0,2.5,3.0,"","","", ""])
        first_ax.set_yticks(first_ax.get_yticks()[:-4])

        for axis in g.axes[0]:
            axis.set_xticklabels(["Turn 1", "Turn 2"])
            axis.set(xlabel=None)
    g.figure.subplots_adjust(hspace=0.2)

    legend = _add_hatches_to_legend(g, bbox_to_anchor=(-1.85, 0.69), ncol=7)
    if not is_mulima:
        legend.remove()
    return g

def set_size(w,h, ax=None):
    """Resize the figure of ``ax`` so that its subplot area is ``w`` x ``h`` inches.

    The requested size is divided by the fraction of the figure taken up by
    the subplot area (``right - left`` and ``top - bottom`` of the figure's
    subplot parameters).  When ``ax`` is not given, the current axes are used.
    """
    ax = ax or plt.gca()
    figure = ax.figure
    params = figure.subplotpars
    width_fraction = params.right - params.left
    height_fraction = params.top - params.bottom
    figure.set_size_inches(float(w) / width_fraction, float(h) / height_fraction)

def _add_hatches_to_legend(g, bbox_to_anchor=(-1.85, 0.9), ncol=4):
    """Create a hatched pyplot legend and hatch the bars of ``g`` to match.

    A legend titled "Fine-tuning Language" is created with ``plt.legend``.
    Each legend handle gets a hatch pattern chosen by its label and a hatch
    colour that is its face colour darkened by 0.3 per RGB channel.  Every bar
    in every axes of ``g`` then receives the hatch of the legend entry with
    the same face colour.  The legend seaborn attached to the grid, if any,
    is removed.

    Args:
        g: Grid whose axes contain the bars.
        bbox_to_anchor: Anchor passed to ``plt.legend``.
        ncol: Number of legend columns.

    Returns:
        The newly created legend.

    Raises:
        KeyError: if a bar's face colour matches no legend handle.
    """
    legend = plt.legend(title="Fine-tuning Language", loc="center", bbox_to_anchor=bbox_to_anchor, ncol=ncol)

    hatch_dict = {"ENDEFRITES": "xx", "sampled": "xx", "DE": "//", "EN": "//", "FR": "//", "IT": "//", "ES": "//"}
    # face colour -> (hatch pattern, hatch colour), filled from the legend
    style_by_color = {}
    for handle, label in zip(legend.legend_handles, legend.get_texts()):
        hatch = hatch_dict.get(label.get_text())
        handle.set_hatch(hatch)
        color = handle.get_facecolor()
        # Darken the RGB channels (clamped at 0) and force full opacity.
        hatch_color = tuple([max(0, x - 0.3) for x in color[:-1]] + [1])
        handle._hatch_color = hatch_color
        style_by_color[color] = (hatch, hatch_color)

    for ax in g.axes.flat:
        for bars in ax.containers:
            for bar in bars.patches:
                hatch, hatch_color = style_by_color[bar.get_facecolor()]
                bar.set_hatch(hatch)
                bar._hatch_color = hatch_color
        ax.margins(y=0.2)
    # The grid has no seaborn legend when it was drawn with legend=False.
    grid_legend = getattr(g, "_legend", None)
    if grid_legend is not None:
        grid_legend.remove()
    return legend

def show_improvementes(data, lang_mix, do_turn_wise = True):
    """Print and return the score gain of a multilingual mix over monolingual models.

    ``data`` is pivoted to rows of ``bench_lang`` (plus ``turn`` when
    ``do_turn_wise`` is true) and columns of
    ``(fine_tune_data, fine_tune_lang)``.  For every benchmark language code,
    the scores in the columns whose ``fine_tune_lang`` equals that code are
    subtracted from the scores in the ``lang_mix`` columns.  The differences
    are reported as absolute values and as percentages (difference divided
    by 10, times 100).  Both tables, extended by average rows, are printed as
    LaTeX.

    Args:
        data: Frame with ``bench_lang``, ``fine_tune_data``,
            ``fine_tune_lang`` and ``score`` columns (plus ``turn`` when
            ``do_turn_wise`` is true).
        lang_mix: ``fine_tune_lang`` value of the model to compare.
        do_turn_wise: Compute the differences and averages per turn.

    Returns:
        A dict with the keys ``"abs"`` and ``"percent"`` holding the averages
        of the absolute and percentage differences.
    """
    turn = ["turn"] if do_turn_wise else []

    pivot = data.pivot(index=["bench_lang"] + turn, columns=["fine_tune_data", "fine_tune_lang"], values="score")
    percentages = pd.DataFrame()
    absolute_diff = pd.DataFrame()
    # Benchmark language codes; with turns the index has a second level, so
    # the unique values of its first level are taken.
    lang_codes = list(set(pivot.index.get_level_values(0).tolist())) if do_turn_wise else pivot.index.tolist()
    for lang_code in lang_codes:
        # .values makes the subtraction positional; the result columns are
        # relabeled from lang_mix to the language being compared against.
        absolute_diff_per_lang = (pivot.loc[:, (slice(None), lang_mix)] - pivot.loc[:, (slice(None), lang_code)].values).rename(columns={lang_mix: lang_code})
        percentage_per_lang = (
             absolute_diff_per_lang / 10
        ).rename(columns={lang_mix: lang_code})
        percentages = pd.concat([percentages, percentage_per_lang], axis=1)
        absolute_diff = pd.concat([absolute_diff, absolute_diff_per_lang], axis=1)
    absolute_diff = absolute_diff.sort_index(axis=1) 
    percentages = percentages.sort_index(axis=1) * 100
    avgs = {}
    for key, df in zip(["abs", "percent"],[absolute_diff, percentages]):
        if do_turn_wise:
            # Mean per turn (index level 1), appended as rows labeled
            # ("Avg.", 1) and ("Avg.", 2).
            avg = df.groupby(level=[1]).mean()
            df = pd.concat([df, avg], axis=0).rename(index={1: ("Avg.", 1), 2: ("Avg.", 2)})
            df.index = pd.MultiIndex.from_tuples(df.index.to_list())
        else:
            # The column means are computed before the "Avg." row is added.
            avg = df.mean(axis=0)
            df.loc["Avg."] = avg
            # Returned as a one-row frame indexed by turn="mean".
            avg = pd.DataFrame(avg).T
            avg.index = pd.Index(["mean"], name="turn")
        print(df.to_latex(float_format="%.1f"))
        avgs[key] = avg
    return avgs

def viz_improvements(avgs: Dict[str, pd.DataFrame], do_turn_wise=True, is_abs=False):
    """Plot the average improvements per language mix and save them as PDFs.

    The frames in ``avgs`` are stacked with their dict keys as an outer row
    level named ``lang_mix`` (the inner level is named ``turn``).  For the
    Bactrian-X and Lima-X columns one bar chart is drawn per turn variation:
    turn 1, turn 2 and all rows together when ``do_turn_wise`` is true,
    otherwise a single chart.  Each chart is saved next to this file as
    ``lang_mix[_abs]_improvement_<turn_ref>_<datasets>.pdf``.

    Args:
        avgs: Maps a language-mix name to its frame of average improvements,
            with ``(fine_tune_data, fine_tune_lang)`` columns.
        do_turn_wise: Whether ``avgs`` holds per-turn averages.
        is_abs: Whether the values are absolute differences, not percentages;
            selects the y range, axis label, bar-label format and file name.
    """
    avg = pd.concat(list(avgs.values()), keys=avgs.keys())
    sns.set(font_scale=1.7)
    sns.set_style("whitegrid")

    turn_variations = [1, 2, None] if do_turn_wise else [None]
    melt_ids = ["lang_mix", "turn"]
    multicol_melt_ids = [(x, "") for x in melt_ids]
    file_prefix = "lang_mix_abs_improvement" if is_abs else "lang_mix_improvement"

    for dataset_names in [["Bactrian-X", "Lima-X"]]:
        # No plt.figure() here: catplot creates its own figure, so an extra
        # call would only leave an empty figure open.
        data = avg[dataset_names].copy()
        # The long-form table is the same for every turn; build it once.
        long_form = pd.melt(data.rename_axis(melt_ids, axis=0).reset_index(), id_vars=multicol_melt_ids)
        long_form = long_form.rename(columns={x: x[0] for x in multicol_melt_ids})
        dataset_ref = dataset_names if isinstance(dataset_names, str) else '_'.join(dataset_names)

        for turn in turn_variations:
            melt = long_form
            catplot_kwargs = {}
            if turn is not None:
                melt = long_form[long_form.turn == turn]
                catplot_kwargs["row"] = "turn"
            if isinstance(dataset_names, list):
                catplot_kwargs["x"] = "fine_tune_data"
            g = sns.catplot(
                melt,
                **catplot_kwargs,
                kind="bar",
                y="value",
                col="lang_mix",
                hue="fine_tune_lang",
                # NOTE(review): the empty hue entries presumably pad the bar
                # groups - confirm.
                hue_order=["", "EN", "DE", "FR", "IT", "ES", "", ""],
                errorbar=None,
                facet_kws=dict(gridspec_kws={"wspace":-0.07}),
                width=1.3,
                margin_titles=True,
                palette=get_colors())
            legend = _add_hatches_to_legend(g, bbox_to_anchor=(0.0,0.87), ncol=3)
            plt.setp(legend.get_texts(), fontsize='14') # for legend text
            plt.setp(legend.get_title(), fontsize='16') # for legend title
            g.set_titles(col_template="{col_name}", row_template="Turn {row_name}")

            if is_abs:
                g.set(ylim=(-0.4, 0.55))
                g.set_axis_labels("", "Abs. Improvement")
            else:
                g.set(ylim=(-4, 5.5))
                g.set_axis_labels("", "Improvement [%]")

            for ax in g.axes.ravel():
                # add annotations
                for c in ax.containers:
                    if is_abs:
                        # Two decimals without the leading zero (".25", "-.25").
                        labels = [f'{(v.get_height()):.2f}'.lstrip('0').replace('-0', '-', 1) for v in c]
                    else:
                        labels = [f'{(v.get_height()):.1f}' for v in c]
                    ax.bar_label(c, labels=labels, label_type='edge', fontsize=14)
                ax.margins(y=0.2)
            g.figure.tight_layout()

            if not do_turn_wise:
                turn_ref = "turn_mean"
            elif turn is not None:
                turn_ref = f"turn_{turn}"
            else:
                turn_ref = "turn_both"

            file_path = f"{file_prefix}_{turn_ref}_{dataset_ref}.pdf"
            save_fig(g, Path(__file__).parent / file_path)

def save_fig(g: sns.FacetGrid, save_path: Path):
    """Write the figure held by ``g`` to ``save_path`` using ``g.savefig``."""
    g.savefig(save_path)


if __name__ == "__main__":

    # The judgment files do not change between passes, so they are read once;
    # prep_data builds its result from a column subset and leaves this frame
    # untouched.
    raw_data = read_data()

    # We cannot get the mean turn-wise, as we would average over the relative
    # improvements of turns, which is different from calculating the relative
    # improvement of averaged turns. Hence with turn_wise=False we calculate
    # the mean, and with turn_wise=True the relative improvement per turn.
    for do_turn_wise in [False, True]:
        data = prep_data(raw_data, do_turn_wise=do_turn_wise)
        lang_mixes = ["ENDEFRITES", "sampled"]
        avgs = {lang_mix: show_improvementes(data, do_turn_wise=do_turn_wise, lang_mix=lang_mix) for lang_mix in lang_mixes}
        viz_improvements({lang_mix: avgs[lang_mix]["percent"] for lang_mix in lang_mixes}, do_turn_wise=do_turn_wise, is_abs=False)
        viz_improvements({lang_mix: avgs[lang_mix]["abs"] for lang_mix in lang_mixes}, do_turn_wise=do_turn_wise, is_abs=True)

    # One score chart per fine-tuning dataset, always split by turn.
    do_turn_wise = True
    data = prep_data(raw_data, do_turn_wise=do_turn_wise)
    for name, group in data.groupby("fine_tune_data"):
        g = viz_catplot(group, do_turn_wise=do_turn_wise, is_mulima=name == "Lima-X")
        save_fig(g, save_path=Path(__file__).parent / f"catplot_{name}.pdf")
