import itertools
import math

import pandas as pd
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from project_root import join_with_root

# Alternative input: outputs/evaluation/new_test_few_shot.json
# Load the zero-shot correlation results. Missing Kendall scores are treated as
# zero correlation so that sorting and max-aggregation below never see NaN.
df = pd.read_json(join_with_root("outputs/evaluation/corr_zero_shot.json"))
df["kendall"] = df["kendall"].fillna(0)

# Print the five best-scoring rows for each task.
for task_name in ("en_de", "zh_en", "summarization"):
    print(df[df["task"] == task_name].sort_values(by="kendall", ascending=False).head(5).to_string())

# Plot every Kendall score in ascending order on a single line. A fresh figure
# is opened and closed explicitly so that, on non-interactive backends where
# plt.show() is a no-op, later plots do not draw on top of this one.
l1 = np.sort(df["kendall"].to_list())
plt.figure()
plt.plot(l1, c="blue")
plt.savefig(join_with_root("outputs/plots/all_on_one_line.pdf"))
plt.show()
plt.close()

# Hand-assigned 2-D coordinates for each emotional task description
# (what the x/y axes mean is not defined in this file).
coordinates = {
    "Neutral": {"x": 0, "y": 0},
    "Polite": {"x": 1, "y": -1},
    "Command": {"x": -1, "y": 1},
    "Threat": {"x": -2, "y": 2},
    "Urgent situation": {"x": -1, "y": 2},
    "Praising": {"x": 1, "y": -1},
    "Emphasis": {"x": 0, "y": 0},
    "Question": {"x": 0, "y": 0},
    "Provocative": {"x": 1, "y": -1},
    "Reward": {"x": 1, "y": -1},
    "Empathetic": {"x": 1, "y": -1},
    "Excited": {"x": 2, "y": 2},
    "Curious": {"x": 1, "y": -1},
    "Casual": {"x": 1, "y": -1},
    "Appreciative": {"x": 1, "y": -1},
    "Formal Request": {"x": 1, "y": -1},
    "Enthusiastic": {"x": 2, "y": 2},
    "Collaborative": {"x": 1, "y": -1},
    "Skeptical": {"x": -1, "y": 0},
    "Instructive": {"x": 1, "y": -1},
    "Encouraging": {"x": 1, "y": -1},
    "Strong Urgency": {"x": -2, "y": 2},
    "Pressing Matter": {"x": -2, "y": 2},
    "Serious Consequences": {"x": -2, "y": 2},
    "Immediate Action": {"x": -2, "y": 2},
    "Dire Warning": {"x": -2, "y": 2},
}

# Group the task descriptions that share the same (x, y) coordinate.
reverse_coordinates = {}
for name, coord in coordinates.items():
    reverse_coordinates.setdefault((coord["x"], coord["y"]), []).append(name)

# Best Kendall score reached by any task description at each coordinate.
# pandas' max() skips NaN, so a description that is absent from df no longer
# turns the whole coordinate into NaN; a coordinate with no rows at all is NaN.
plot_values = {
    coord: df.loc[df["task_description"].isin(names), "kendall"].max()
    for coord, names in reverse_coordinates.items()
}

x = [coord[0] for coord in plot_values]
y = [coord[1] for coord in plot_values]
plt.figure()
# NOTE(review): s = kendall ** -8 draws larger markers for smaller scores and is
# infinite for a score of 0 -- confirm this inverse scaling is intended.
plt.scatter(x, y, s=(np.array(list(plot_values.values())) ** -8))
plt.axhline(0, color='black', linewidth=.5)
plt.axvline(0, color='black', linewidth=.5)
plt.tight_layout()
plt.savefig(join_with_root("outputs/plots/emotion_influence.pdf"))
plt.show()
plt.close()

print(plot_values)

def _plot_kendall_density(column, title, filename, *, legend_kwargs, legend_fontsize,
                          fill=True, tight_layout=False):
    """Draw one KDE curve of df["kendall"] per distinct value of ``column`` and save it.

    Args:
        column: name of the df column whose values become the legend entries.
        title: figure title.
        filename: output path relative to the project root.
        legend_kwargs: extra keyword arguments for ``sns.move_legend``.
        legend_fontsize: font size applied to the legend texts.
        fill: whether the KDE curves are filled.
        tight_layout: call ``plt.tight_layout()`` before saving.
    """
    # Group by a scalar key: with a one-element list key, pandas >= 2 yields
    # 1-tuples as group names, which would end up as tuple legend labels.
    densities = {name: group["kendall"].tolist() for name, group in df.groupby(column)}
    # Groups have different sizes; the wide frame pads the shorter columns with NaN.
    expl_density = pd.DataFrame(list(densities.values())).transpose()
    expl_density.columns = list(densities.keys())
    # Fresh figure so nothing left over from a previous plot is drawn again.
    plt.figure()
    plot = sns.kdeplot(data=expl_density, bw_adjust=0.9, fill=fill)
    sns.move_legend(plot, "upper left", **legend_kwargs)
    plt.axvline(0, color="black")
    plot.set_yticklabels([])
    plot.set_ylabel("")
    plt.setp(plot.get_legend().get_texts(), fontsize=legend_fontsize)
    plt.title(title)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(join_with_root(filename))
    plt.show()
    plt.close()


_plot_kendall_density("task", "Kendall per LP", "outputs/plots/kendall_per_lp.pdf",
                      legend_kwargs={"title": "Dataset"}, legend_fontsize='8')
_plot_kendall_density("model", "Kendall per model", "outputs/plots/kendall_per_model.pdf",
                      legend_kwargs={"title": "Model"}, legend_fontsize='8')
_plot_kendall_density("prompt", "Kendall per major prompt", "outputs/plots/kendall_per_major_prompt.pdf",
                      legend_kwargs={"bbox_to_anchor": (1, 1)}, legend_fontsize='6',
                      tight_layout=True)
_plot_kendall_density("task_description", "Kendall per minor prompt",
                      "outputs/plots/kendall_per_minor_prompt.pdf",
                      legend_kwargs={"bbox_to_anchor": (1, 1)}, legend_fontsize='6',
                      fill=False, tight_layout=True)

# The following grid figure uses a 12x12 canvas.
sns.set(rc={'figure.figsize': (12, 12)})

fig, axs = plt.subplots(3, 3)

# One subplot per (prompt, task) pair; the prompt varies slowest.
prompt_names = ["Zero-Shot", "Zero-Shot-Cot", "Zero-Shot-Cot-Emotion"]
task_names = df["task"].unique().tolist()
combis = [(prompt_name, task_name) for prompt_name in prompt_names for task_name in task_names]

# Independent copy of the rows belonging to each pair, in the same order as combis.
sub_dfs = []
for prompt_name, task_name in combis:
    mask = (df["prompt"] == prompt_name) & (df["task"] == task_name)
    sub_dfs.append(df[mask].copy())

models = df["model"].unique()
# Marker colours, paired position-by-position with `models` further below.
colors = ["green", "blue", "red"]
cnt = 0

ordered = None


def save_max(i):
    """Return the maximum of ``i``, falling back to 0 when there is no usable maximum.

    Args:
        i: an iterable of numbers (any iterable, including one-shot generators).

    Returns:
        The largest element of ``i``. Returns 0 if ``i`` is empty, if its maximum
        is NaN, or if the elements cannot be compared or NaN-checked as numbers.
    """
    try:
        # Evaluate max() once: a second call would find a generator exhausted.
        a = max(i)
        # Return the NaN-sanitised value rather than re-computing max(i).
        if math.isnan(a):
            return 0
        return a
    except (ValueError, TypeError, OverflowError):
        # ValueError: empty input; TypeError: non-numeric or incomparable values;
        # OverflowError: an int too large for math.isnan's float conversion.
        return 0

# When True, subplots after the first reuse the first subplot's row order
# instead of sorting their own task descriptions independently.
order = False
correlations = {}
for y in range(3):
    for z in range(3):
        plots = []
        sub_df = sub_dfs[cnt]
        # Best Kendall score of each model per task description. A scalar groupby
        # key keeps group names as plain strings on every pandas version.
        densities = {
            name: [save_max(group[group["model"] == m]["kendall"].tolist()) for m in models]
            for name, group in sub_df.groupby("task_description")
        }
        if order and ordered:
            densities = {o: densities[o] for o in ordered}
        else:
            # Rows sorted by the score of the second model in `models`.
            densities = dict(sorted(densities.items(), key=lambda item: item[1][1]))
            ordered = list(densities.keys())

        positions = list(range(len(densities)))
        vals = {}
        # zip pairs each model with its colour and stops at the shorter list.
        for i, (model, c) in enumerate(zip(models, colors)):
            a = [d[i] for d in densities.values()]
            plots.append(axs[z, y].scatter(a, positions, c=c))
            vals[model] = a

        # Per-model score vectors, consumed by the correlation heatmap.
        correlations[combis[cnt]] = vals

        axs[z, y].set_xlim([-0.3, 0.55])

        axs[z, y].set_title(combis[cnt])
        axs[z, y].set_yticks(positions, list(densities.keys()))

        cnt += 1
axs[0, 0].legend(plots,
                 models,
                 scatterpoints=1,
                 loc='lower left',
                 fontsize=6)

plt.tight_layout()
plt.savefig(join_with_root("outputs/plots/scatter_per_task_desc.pdf"))
plt.show()
plt.close()


# The heatmap grid uses a 15x15 canvas.
sns.set(rc={'figure.figsize': (15, 15)})

fig, axs = plt.subplots(3, 3)

cnt = 0
for y in range(3):
    for z in range(3):
        scores = correlations[combis[cnt]]
        # corrs[i][j]: Pearson correlation of model i's scores with model j's scores.
        corrs = [
            [scipy.stats.pearsonr(scores[row_model], scores[col_model]).statistic for col_model in models]
            for row_model in models
        ]
        no_labels = [""] * len(models)

        # Transposed matrix; model names are shown only on the left column of
        # subplots (y == 0, row labels) and the bottom row (z == 2, column labels).
        cdf = pd.DataFrame(corrs).T
        cdf.index = models if y == 0 else no_labels
        cdf.columns = models if z == 2 else no_labels

        ax = sns.heatmap(cdf, annot=True, annot_kws={"fontsize": 20}, ax=axs[z, y], cbar=False, linewidths=.03)

        ax.set_xticklabels(ax.get_xticklabels(), verticalalignment='center',
                           horizontalalignment='center')
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
        ax.tick_params(axis='both', which='major', labelsize=18)
        axs[z, y].set_title(", ".join(combis[cnt]), fontsize=20)
        ax.set_xlabel("")
        ax.set_ylabel("")

        cnt += 1
fig.suptitle("Model Performance Correlation Based on Emotion", fontsize=28, fontweight='bold')
plt.tight_layout()
plt.savefig(join_with_root("outputs/plots/task_desc_corr.pdf"))
plt.show()


sns.set(rc={'figure.figsize': (12, 12)})

fig, axs = plt.subplots(3, 3)

# Fresh copies per (prompt, task) pair: the loop below adds a "regex_name"
# column to each subset.
combis = list(itertools.product(*[["Zero-Shot",
                                   "Zero-Shot-Cot",
                                   "Zero-Shot-Cot-Emotion"], df["task"].unique().tolist()]))
sub_dfs = [df[(df["prompt"] == d) & (df["task"] == t)].copy() for d, t in combis]

models = df["model"].unique()
colors = ["green", "blue", "red"]
cnt = 0

# Row order fixed by the first (non-empty) subplot and reused by the later
# ones, so the same format prompt sits on the same row in every subplot.
ordered = None
correlations = {}
for y in range(3):
    for z in range(3):
        plots = []
        sub_df = sub_dfs[cnt]
        # Assign positionally (.tolist()) because json_normalize returns a new
        # RangeIndex that would not align with sub_df's filtered index.
        sub_df["regex_name"] = pd.json_normalize(sub_df['regex'])["name"].tolist()
        # Scalar groupby key keeps group names as plain strings on every pandas version.
        densities = {
            name: [save_max(group[group["model"] == m]["kendall"].tolist()) for m in models]
            for name, group in sub_df.groupby("regex_name")
        }
        if not ordered:
            # Rows sorted by the score of the first model in `models`.
            densities = dict(sorted(densities.items(), key=lambda item: item[1][0]))
            ordered = list(densities.keys())
        else:
            # Follow the shared order. Names missing from this subset are skipped
            # (instead of raising KeyError), and names only present here are
            # appended, sorted the same way (instead of being dropped silently).
            seen = set(ordered)
            known = [o for o in ordered if o in densities]
            extra = sorted((k for k in densities if k not in seen), key=lambda k: densities[k][0])
            densities = {k: densities[k] for k in known + extra}

        positions = list(range(len(densities)))
        vals = {}
        # zip pairs each model with its colour and stops at the shorter list.
        for i, (model, c) in enumerate(zip(models, colors)):
            a = [d[i] for d in densities.values()]
            plots.append(axs[z, y].scatter(a, positions, c=c))
            vals[model] = a

        # Per-model score vectors, consumed by the correlation heatmap.
        correlations[combis[cnt]] = vals
        axs[z, y].set_title(combis[cnt])
        axs[z, y].set_yticks(positions, list(densities.keys()))
        axs[z, y].set_xlim([-0.3, 0.55])

        cnt += 1
axs[0, 0].legend(plots,
                 models,
                 scatterpoints=1,
                 loc='lower left',
                 fontsize=6)

plt.tight_layout()
plt.savefig(join_with_root("outputs/plots/scatter_by_format_prompt.pdf"))
plt.show()
plt.close()

# The heatmap grid uses a 15x15 canvas.
sns.set(rc={'figure.figsize': (15, 15)})

fig, axs = plt.subplots(3, 3)

cnt = 0
for y in range(3):
    for z in range(3):
        scores = correlations[combis[cnt]]
        # corrs[i][j]: Kendall tau between model i's scores and model j's scores.
        corrs = [
            [scipy.stats.kendalltau(scores[row_model], scores[col_model]).statistic for col_model in models]
            for row_model in models
        ]
        no_labels = [""] * len(models)

        # Transposed matrix; model names are shown only on the left column of
        # subplots (y == 0, row labels) and the bottom row (z == 2, column labels).
        cdf = pd.DataFrame(corrs).T
        cdf.index = models if y == 0 else no_labels
        cdf.columns = models if z == 2 else no_labels

        ax = sns.heatmap(cdf, annot=True, annot_kws={"fontsize": 20}, ax=axs[z, y], cbar=False, linewidths=.03)

        ax.set_xticklabels(ax.get_xticklabels(), verticalalignment='center',
                           horizontalalignment='center')
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
        ax.tick_params(axis='both', which='major', labelsize=18)
        axs[z, y].set_title(", ".join(combis[cnt]), fontsize=20)
        ax.set_xlabel("")
        ax.set_ylabel("")

        cnt += 1
fig.suptitle("Model Performance Correlation Based on Format Prompts", fontsize=28, fontweight='bold')
plt.tight_layout()
plt.savefig(join_with_root("outputs/plots/format_prompt_heatmap.pdf"))
plt.show()

print("DONE")
