import numpy as np
import pandas as pd
import scipy

import seaborn as sns
import matplotlib.pyplot as plt
from project_root import join_with_root


def gen_corr_heatmap(df, corr_measure="kendall", main_axis="regex"):
    """Show a heatmap of pairwise Kendall correlations between the values of ``main_axis``.

    For every distinct value of ``main_axis`` the rows of ``df`` are ordered by the remaining prompt dimensions and
    the ``corr_measure`` column is read off as one score vector. The heatmap then shows Kendall's tau between each
    pair of score vectors, i.e. how similarly two values of the main axis rank the remaining configurations.

    Note that ``df`` is modified in place: "summarization" is shortened to "summ" in the ``task`` column and
    columns holding dicts are replaced by the dicts' ``"name"`` entries.

    :param df: Frame with a ``task`` column, the ``corr_measure`` column and the dimension columns ``regex``,
        ``task_description``, ``task``, ``format_prompt``, ``prompt`` and ``model``.
    :param corr_measure: Name of the column whose scores are compared. It only selects the column; the comparison
        between groups always uses Kendall's tau.
    :param main_axis: Dimension column whose values become the rows/columns of the heatmap. Must be one of the
        dimension columns listed above (otherwise ``list.remove`` raises ``ValueError``).
    :return: The (mutated) input frame.
    """
    # Imported here because a bare ``import scipy`` does not guarantee that the ``scipy.stats`` submodule is loaded.
    from scipy.stats import kendalltau

    sns.set(rc={'figure.figsize': (20, 20)})

    df["task"] = [d.replace("summarization", "summ") for d in df["task"].tolist()]

    # Collapse dict-valued columns to the dict's "name" entry so they can be grouped and sorted.
    for col in df.columns:
        values = df[col].tolist()
        if values and isinstance(values[0], dict):
            df[col] = [u["name"] for u in values]

    # Group by the scalar key (not a one-element list): with a list, newer pandas versions yield 1-tuples as group
    # names, which would defeat the label renaming and truncation below.
    grouped = df.groupby(main_axis)

    sort_keys = ["regex", "task_description", "task", "format_prompt", "prompt", "model"]
    sort_keys.remove(main_axis)

    # One score vector per main-axis value, aligned by sorting on the remaining dimensions; missing scores count as 0.
    res = {
        name: group.sort_values(by=sort_keys)[corr_measure].fillna(0).tolist()
        for name, group in grouped
    }

    # Full pairwise matrix; kendalltau needs equally long inputs, so every group must have the same number of rows.
    corr_dict = {
        k1: [kendalltau(v1, v2, nan_policy="raise").statistic for v2 in res.values()]
        for k1, v1 in res.items()
    }

    short_names = {
        "Nous": "N", "OpenOrca": "O", "Platypus70B": "P",
        "Zero-Shot": "ZS", "Zero-Shot-Cot": "ZSC", "Zero-Shot-Cot-Emotion": "ZSEC",
    }
    corr_df = pd.DataFrame(corr_dict).rename(columns=short_names).T
    # Row labels keep the (shortened) names; column labels are additionally cut to four characters.
    corr_df.columns = [str(u)[:4] for u in corr_df.index]

    # Hide the strict upper triangle: the matrix is symmetric, so only the lower half and the diagonal are drawn.
    matrix = np.triu(np.ones(corr_df.shape, dtype=bool), k=1)
    a = sns.heatmap(corr_df, annot=True, annot_kws={"fontsize": 9}, cbar=False, linewidths=.03, mask=matrix)
    a.set_xticklabels(a.get_xticklabels(), verticalalignment='center',
                      horizontalalignment='center', rotation=80)
    a.set_yticklabels(a.get_yticklabels(), rotation=0)
    a.tick_params(axis='both', which='major', labelsize=9)
    a.tick_params(axis='x', which='major', pad=10)

    plt.tight_layout()
    plt.show()

    return df

if __name__ == '__main__':
    # Result files to visualise. Only the zero-shot correlations are loaded; the few-shot file
    # ("outputs/evaluation/corr_few_shot_train_avg.json") can be appended to this list if needed.
    train_files = [
        join_with_root("outputs/evaluation/corr_zero_shot_train_avg.json"),
    ]

    # Read every file into its own frame, then stack them into a single frame.
    frames = [pd.read_json(path) for path in train_files]
    df = pd.concat(frames).reset_index()

    gen_corr_heatmap(df, corr_measure="kendall", main_axis="task_description")