import numpy as np
import scipy

from project_root import join_with_root
import pandas as pd


def _regex_names(df):
    """Return a copy of *df* whose ``regex`` column holds plain names.

    Entries that are dicts are replaced by their ``"name"`` value; every other
    entry is kept unchanged, so already-converted or mixed columns are handled.
    The input frame is not modified.

    Raises:
        KeyError: if *df* has no ``regex`` column, or a dict entry has no
            ``"name"`` key.
    """
    names = [
        entry["name"] if isinstance(entry, dict) else entry
        for entry in df["regex"].tolist()
    ]
    return df.assign(regex=names)


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN when *values* is empty.

    NaN entries propagate into the mean (unlike pandas' skipna mean). The empty
    case returns NaN directly instead of triggering NumPy's
    "Mean of empty slice" RuntimeWarning.
    """
    return np.mean(values) if len(values) > 0 else np.nan


def correlate(df1, df2):
    """Print and return the Kendall tau between two tables' per-configuration scores.

    A configuration is a unique combination of the ``regex``,
    ``task_description``, ``prompt`` and ``model`` columns. Every configuration
    that has rows in *df2* contributes one pair: the mean ``kendall`` value of
    its rows in *df1* and the mean of its rows in *df2*. A mean that is NaN is
    replaced by 0 before correlating. A mean is NaN when the configuration has
    no rows in *df1*, or when its scores contain NaN.

    Neither input DataFrame is modified.

    Args:
        df1: Results table with the four configuration columns and a
            ``kendall`` column. ``regex`` entries may be plain values or dicts
            with a ``"name"`` key.
        df2: Results table of the same layout. Its configurations define the
            set that is compared.

    Returns:
        The Kendall tau statistic between the two lists of means. The value is
        also printed.
    """
    # Explicit submodule import: on older SciPy versions a bare ``import scipy``
    # does not make ``scipy.stats`` available.
    from scipy import stats

    keys = ["regex", "task_description", "prompt", "model"]
    df1 = _regex_names(df1)
    df2 = _regex_names(df2)

    # Index df1's scores by configuration once, instead of re-filtering the
    # whole frame for every combination of df2's unique values.
    df1_scores = {
        config: group["kendall"].tolist()
        for config, group in df1.groupby(keys, sort=False, observed=True)
    }

    kd_1, kd_2 = [], []
    # Grouping df2 yields exactly the configurations that have rows in df2.
    # Rows with a missing key value are dropped; equality filtering could never
    # match them either.
    for config, group in df2.groupby(keys, sort=False, observed=True):
        kd_2.append(_mean_or_nan(group["kendall"].tolist()))
        kd_1.append(_mean_or_nan(df1_scores.get(config, [])))

    # NOTE(review): configurations missing from df1 are scored as 0 rather than
    # dropped. This keeps the script's existing behaviour but may bias tau.
    kd_1 = [0 if np.isnan(v) else v for v in kd_1]
    kd_2 = [0 if np.isnan(v) else v for v in kd_2]

    tau = stats.kendalltau(kd_1, kd_2, nan_policy="raise").statistic
    print(tau)
    return tau




if __name__ == '__main__':
    def _load_results(relative_paths):
        """Read each repo-relative JSON results file and stack them into one frame."""
        frames = [pd.read_json(join_with_root(path)) for path in relative_paths]
        return pd.concat(frames).reset_index()

    # Zero-shot correlation results per split, read from the *_avg_no_emotion files.
    train_df = _load_results(["outputs/evaluation/corr_zero_shot_train_avg_no_emotion.json"])
    dev_df = _load_results(["outputs/evaluation/corr_zero_shot_dev_avg_no_emotion.json"])
    test_df = _load_results(["outputs/evaluation/corr_zero_shot_test_avg_no_emotion.json"])

    # Compare each pair of splits: train/dev, then train/test, then dev/test.
    split_pairs = ((train_df, dev_df), (train_df, test_df), (dev_df, test_df))
    for reference_df, other_df in split_pairs:
        correlate(reference_df, other_df)
