import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.stats import linregress
import seaborn as sns
from project_root import join_with_root

def fix_nan(l):
    """Return a new list mirroring ``l`` with every NaN entry replaced by 0.

    Non-NaN entries are passed through untouched; the input is not mutated.
    """
    cleaned = []
    for value in l:
        if np.isnan(value):
            cleaned.append(0)
        else:
            cleaned.append(value)
    return cleaned

def compareZSOS(task, model):
    """Plot Kendall correlations of the three zero-shot prompt variants.

    Reads the recombined zero-shot correlation results, pairs up the
    "Zero-Shot", "Zero-Shot-Cot" and "Zero-Shot-Cot-Emotion" rows that share
    the same (normalised) prompt ID, and draws for each variant a scatter of
    its Kendall values, a horizontal median line and a dashed linear
    regression line. Rows are ordered along the x axis by the plain
    "Zero-Shot" Kendall value. The figure is saved to
    ``outputs/plots/CoTVS_train_ges.pdf`` and then shown.

    Args:
        task: Currently unused; the filter that would consume it is
            disabled, so rows of every task are pooled.
        model: Currently unused, for the same reason as ``task``.

    Raises:
        ValueError: If no prompt ID is present for all three variants.
    """
    sns.set_theme()
    results = pd.read_json(join_with_root("outputs/evaluation/corr_zero_shot_train_avg_recombined.json"))
    # NOTE(review): the per-task / per-model filter is disabled, so every call
    # plots the same pooled data and overwrites the same PDF:
    # results = results[(results["task"] == task) & (results["model"] == model)]

    # Explicit copies: the ID column of these two subsets is rewritten below,
    # and assigning into a filtered slice triggers SettingWithCopyWarning.
    cot = results[results["prompt"] == "Zero-Shot-Cot"].copy()
    emotion = results[results["prompt"] == "Zero-Shot-Cot-Emotion"].copy()
    plain = results[results["prompt"] == "Zero-Shot"]

    # Normalise the IDs so that the variants of one prompt share a join key.
    # regex=False keeps this a literal substitution on every pandas version.
    cot["ID"] = cot["ID"].str.replace("Zero-Shot-Cot", "Zero-Shot", regex=False)
    emotion["ID"] = emotion["ID"].str.replace("Zero-Shot-Cot-Emotion", "Zero-Shot", regex=False)

    # First merge suffixes the clashing columns: *_x = CoT, *_y = plain.
    # In the second merge the emotion columns no longer clash, so its
    # "kendall" column keeps its bare name.
    joint_df = pd.merge(left=cot, right=plain, on="ID")
    joint_df = pd.merge(left=emotion, right=joint_df, on="ID")

    if joint_df.empty:
        raise ValueError("No prompt IDs are shared by all three zero-shot variants; nothing to plot.")

    # k0 = CoT = orange, k1 = no CoT = blue, k2 = CoT-Emotion = green
    ks = sorted(
        zip(
            fix_nan(joint_df["kendall_x"].tolist()),
            fix_nan(joint_df["kendall_y"].tolist()),
            fix_nan(joint_df["kendall"].tolist()),
        ),
        key=lambda row: row[1],
    )
    cot_scores, plain_scores, emotion_scores = (list(column) for column in zip(*ks))

    # NOTE(review): this prints a median next to a mean, while the plot uses
    # medians throughout; kept as-is, confirm whether that is intended.
    print(np.median(cot_scores), np.mean(plain_scores))

    # (values, scatter colour, line colour) for each prompt variant.
    series = [
        (cot_scores, "orange", "red"),
        (plain_scores, "blue", "blue"),
        (emotion_scores, "green", "green"),
    ]
    positions = np.arange(len(ks))

    # A dedicated figure per call: drawing on the implicit current figure
    # lets repeated calls pile up on the same axes when show() does not
    # block (e.g. non-interactive backends).
    fig, ax = plt.subplots()
    try:
        for scores, dot_colour, _ in series:
            ax.scatter(positions, scores, s=0.6, color=dot_colour)

        for scores, _, line_colour in series:
            ax.plot([np.median(scores)] * len(scores), color=line_colour, linewidth=2)

        for scores, _, line_colour in series:
            reg = linregress(positions, scores)
            ax.axline(xy1=(0, reg.intercept), slope=reg.slope, linestyle="--", linewidth=2, color=line_colour)

        ax.set_xlabel("Prompt ID")
        ax.set_ylabel("Kendall correlation")

        fig.tight_layout()
        fig.savefig(join_with_root("outputs/plots/CoTVS_train_ges.pdf"))
        plt.show()
    finally:
        # Release the figure even if regression or saving fails.
        plt.close(fig)

if __name__ == '__main__':
    # Run the comparison for every model/task pair, model-major order.
    for model_name in ("Platypus70B", "Nous", "OpenOrca"):
        for task_name in ("en_de", "zh_en", "summarization"):
            compareZSOS(task_name, model_name)