import numpy as np
import pandas as pd
import scipy

import seaborn as sns
import matplotlib.pyplot as plt
from project_root import join_with_root


def gen_corr_heatmap(df, corr_measure="kendall", main_axis="regex", sub_axis="task", measure=np.median):
    """Plot how stable the ranking of ``main_axis`` values is across ``sub_axis`` values.

    For every (sub_axis value, main_axis value) pair the ``corr_measure`` column is aggregated
    with ``measure`` (all other dimensions are collapsed by that aggregation). Each sub_axis
    value thereby gets one score per main_axis value, and the score vectors of every pair of
    sub_axis values are compared with Kendall's tau. The resulting matrix is drawn as a
    lower-triangular heatmap, saved as a PDF below ``outputs/plots/`` and shown.

    Args:
        df: DataFrame containing the columns ``main_axis``, ``sub_axis`` and ``corr_measure``.
            It is modified in place: "summarization" is replaced by "summ" in the ``task``
            column (if present), and dict-valued axis columns are replaced by their "name" entries.
        corr_measure: Name of the column holding the values to aggregate.
        main_axis: Column whose values are ranked.
        sub_axis: Column across whose values the rankings are compared.
        measure: Callable reducing a Series of ``corr_measure`` values to a scalar. Its
            ``__name__`` becomes part of the output file name.

    Returns:
        The (in-place normalised) input DataFrame.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the ``stats`` submodule is loaded.
    from scipy.stats import kendalltau

    sns.set(rc={'figure.figsize': (2, 2)})

    # Preferred display order of the sub-axis values (only applies if all of them are listed here).
    order = ["en_de", "en_es", "en_zh", "zh_en", "summ"]
    if "task" in df.columns:
        df["task"] = [task.replace("summarization", "summ") for task in df["task"]]

    # Axis columns may hold dict records; reduce those to their "name" entry.
    for axis in (main_axis, sub_axis):
        if len(df) > 0 and isinstance(df[axis].iloc[0], dict):
            df[axis] = [entry["name"] for entry in df[axis]]

    main_values = df[main_axis].unique()
    sub_values = df[sub_axis].unique()
    try:
        sub_values = sorted(sub_values, key=order.index)
    except ValueError:
        # At least one sub-axis value is not in ``order``: keep the order of first appearance.
        pass

    # Aggregate the correlation column for every (sub_axis, main_axis) combination.
    scores = {}
    for sub_value in sub_values:
        in_sub = df[df[sub_axis] == sub_value]
        per_main = []
        for main_value in main_values:
            aggregated = measure(in_sub[in_sub[main_axis] == main_value][corr_measure])
            # Combinations without usable data aggregate to NaN; they count as 0.
            per_main.append(0 if np.isnan(aggregated) else aggregated)
        scores[sub_value] = per_main

    # Pairwise Kendall rank correlation between the score vectors of the sub-axis values.
    corr_dict = {
        row: [kendalltau(scores[row], scores[col], nan_policy="raise").statistic for col in sub_values]
        for row in sub_values
    }

    # Abbreviate known long names; rows keep the (abbreviated) name, columns are cut to 4 characters.
    short_names = {
        "Nous": "N", "OpenOrca": "O", "Platypus70B": "P",
        "Zero-Shot": "ZS", "Zero-Shot-Cot": "ZSC", "Zero-Shot-Cot-Emotion": "ZSEC",
    }
    corr_df = pd.DataFrame(corr_dict).rename(columns=short_names).T
    corr_df.columns = [str(label)[:4] for label in corr_df.index]

    # Hide the strict upper triangle: the matrix is symmetric, so it only repeats the lower half.
    upper_mask = np.triu(np.ones(corr_df.shape, dtype=bool), k=1)
    a = sns.heatmap(corr_df, annot=True, annot_kws={"fontsize": 9}, cbar=False, linewidths=.03, mask=upper_mask)
    a.set_xticklabels(a.get_xticklabels(), verticalalignment='center',
                      horizontalalignment='center', rotation=80)
    a.set_yticklabels(a.get_yticklabels(), rotation=0)
    a.tick_params(axis='both', which='major', labelsize=9)
    a.tick_params(axis='x', which='major', pad=10)

    plt.tight_layout()
    plt.savefig(join_with_root(f"outputs/plots/{main_axis}_along_{sub_axis}_{measure.__name__}_corr.pdf"))
    plt.show()

    return df

if __name__ == '__main__':
    # Evaluation result files to plot. The few-shot results
    # ("outputs/evaluation/corr_few_shot_train_avg.json") are currently left out.
    train_files = [
        join_with_root("outputs/evaluation/corr_zero_shot_train_avg_recombined.json"),
    ]

    # Load every file and stack them into a single frame with a fresh running index.
    frames = [pd.read_json(path) for path in train_files]
    df = pd.concat(frames).reset_index()

    # Compare regex rankings across tasks, aggregating the Kendall column by its maximum.
    gen_corr_heatmap(
        df,
        corr_measure="kendall",
        main_axis="regex",
        sub_axis="task",
        measure=np.max,
    )