from statistics import mean

from project_root import ROOT_DIR
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
from sklearn import linear_model

# This file contains methods to calculate the scores and results of our paper on Datasets without splits

def explore_all(output_dir=None):
    """Summarise how BMX changes metric correlations over all tested (p, w) settings.

    Loads the correlation tables produced by the "test_pmeans" experiment files from
    ``ROOT_DIR/xaiMetrics/outputs/experiment_results/<tag>__<explainer>__<metric>__<lp>.tsv``.
    In every table the row with ``w == 1`` is the original correlation; all other rows are kept
    together with their difference to it. The function then does three things:

    * prints summary statistics;
    * prints the coefficients of a Ridge regression of that difference on the remaining
      (partly one-hot encoded) columns;
    * saves a density plot of the differences per explainer.

    :param output_dir: Directory for ``3_All_Correlations.pdf``. Defaults to
        ``ROOT_DIR/xaiMetrics/outputs/Images_Paper_Auto_Gen`` and is created if missing.
    """
    sns.set_theme()

    # Combinations for which the experiments produced files with correlation scores
    explainer_names = ["ErasureExplainer", "LimeExplainer", "ShapExplainer"]
    alt_explainer_names = ["Erasure", "LIME", "SHAP"]
    ref_free = ["BERTSCORE_REF_FREE_XNLI", "XLMRSBERT"]

    config_dict = {
        "wmt17": {
            "lps": ["cs-en", "de-en", "fi-en", "lv-en", "ru-en", "tr-en", "zh-en"],
            "explainers": explainer_names,
            "metrics": ref_free + ["BERTSCORE_REF_BASED_ROBERTA_DEFAULT", "SentenceBLEU_REF_BASED"]
        },
        "eval4_nlp_21_test": {
            "lps": ["ro-en", "et-en", "ru-de", "de-zh"],
            "explainers": explainer_names,
            "metrics": ref_free
        },
        "wmt_22_expl_train": {
            "lps": ["ro-en", "et-en", "ru-en", "si-en", "ne-en", "en-zh", "en-de"],
            "explainers": explainer_names,
            "metrics": ref_free
        },
    }

    frames = []
    for tag, config in config_dict.items():
        for lp in config["lps"]:
            for explainer in config["explainers"]:
                for metric in config["metrics"]:
                    # Load df with resulting correlations for w and p values
                    filepath = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "experiment_results",
                                            "__".join([tag, explainer, metric, lp]) + ".tsv")
                    df = pd.read_csv(filepath, delimiter='\t')
                    df["explainer"] = explainer
                    df["metric"] = metric
                    df["lp"] = lp
                    df["tag"] = tag
                    orig = df[df["w"] == 1]["corr"].iat[0]

                    # Keep only the weighted settings. Copy them so the new columns are written to
                    # an independent frame, not a view of the original (avoids SettingWithCopyWarning).
                    df = df[df["w"] != 1].copy()
                    df["diff"] = df["corr"] - orig
                    df["orig"] = orig
                    frames.append(df)

    # A single concat replaces repeated pairwise concatenation (quadratic copying)
    all_df = pd.concat(frames)

    print("Number of overall datapoints: ", len(all_df))
    all_df = all_df[all_df["orig"] > 0]
    print("Number of overall datapoints: ", len(all_df))
    print(all_df.groupby("explainer").max())
    # metric/lp/tag are string columns without a median; pandas >= 2 raises instead of dropping them
    print(all_df.groupby("explainer").median(numeric_only=True))

    # Get Regressors: every column except the index artefact, the raw correlation and the target
    all_df_dummy = pd.get_dummies(data=all_df, columns=["explainer", "metric", "lp", "tag"])
    cols = [c for c in all_df_dummy.columns if c not in ["Unnamed: 0", "corr", "diff"]]
    x = all_df_dummy[cols]
    y = all_df_dummy['diff']

    regr = linear_model.Ridge()
    regr.fit(x, y)

    for n, c in zip(cols, regr.coef_.tolist()):
        print(n, round(c, 3))
    print(regr.intercept_)

    # Plot distribution of the differences per explainer (shorter columns are NaN-padded)
    diffs_per_explainer = [all_df.loc[all_df["explainer"] == e, "diff"].tolist() for e in explainer_names]
    density_df = pd.DataFrame(diffs_per_explainer).transpose()
    density_df.columns = alt_explainer_names
    plot = sns.kdeplot(data=density_df, bw_adjust=0.9, fill=True)
    sns.move_legend(plot, "upper left", title='Explainer')
    plt.axvline(0, color="black")
    plt.xlim(-0.4, 0.25)
    plot.set_yticklabels([])
    plot.set_ylabel("")
    plt.tight_layout()

    if output_dir is None:
        output_dir = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "Images_Paper_Auto_Gen")
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, "3_All_Correlations.pdf"))
    plt.show()


def find_best_p_and_w(only_hyp=False, output_dir=None):
    """Select a fixed p and w per metric and explainer from the settings that improved correlation.

    For every dataset/lp/explainer/metric combination in ``config_dict``:

    * load the correlation table;
    * collect every (p, w) row whose correlation exceeds the original one (the row with ``w == 1``);
    * drop combinations whose original correlation is not positive.

    Then, for each metric, draw box plots of the collected w and p values per explainer and return
    their medians. These medians are the values that ``evaluate_p_w_best_fix_2`` tests.

    :param only_hyp: Read the ``__only_hyp`` variants of the result files.
    :param output_dir: Directory for the ``4_best_p_and_w_<metric>.pdf`` plots. Defaults to
        ``ROOT_DIR/xaiMetrics/outputs/Images_Paper_Auto_Gen`` and is created if missing.
    :return: ``{metric: {explainer: {"p": median_p, "w": median_w}}}`` with medians rounded to 3
        decimals. An explainer without any improving setting for a metric is absent from the inner dict.
    """
    sns.set_theme()
    explainer_names = ["ErasureExplainer", "LimeExplainer", "ShapExplainer"]
    alt_explainer_names = ["Erasure", "LIME", "SHAP"]
    ref_free = ["BERTSCORE_REF_FREE_XNLI", "XLMRSBERT", "XMoverScore_No_Mapping", "Transquest", "COMET"]
    ref_based = ["BERTSCORE_REF_BASED_ROBERTA_DEFAULT", "SentenceBLEU_REF_BASED"]

    config_dict = {
        "wmt17": {
            "lps": ["cs-en", "de-en", "fi-en", "lv-en", "ru-en", "tr-en", "zh-en"],
            "explainers": explainer_names,
            "metrics": ref_free + ref_based
        },
        "eval4_nlp_21_test": {
            "lps": ["ro-en", "et-en", "ru-de", "de-zh"],
            "explainers": explainer_names,
            "metrics": ref_free
        },
        "wmt_22_expl_train": {
            "lps": ["ro-en", "et-en", "ru-en", "si-en", "ne-en", "en-zh", "en-de"],
            "explainers": explainer_names,
            "metrics": ref_free
        }
    }

    # One row per improving (p, w) setting
    rows = []
    for tag, config in config_dict.items():
        for lp in config["lps"]:
            for explainer in config["explainers"]:
                for metric in config["metrics"]:
                    # Load df with resulting correlations for w and p values
                    name_parts = [tag, explainer, metric, lp] + (["only_hyp"] if only_hyp else [])
                    filepath = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "experiment_results",
                                            "__".join(name_parts) + ".tsv")
                    try:
                        df = pd.read_csv(filepath, delimiter='\t')
                    except Exception as e:
                        print(str(e) + "This is expected for Transquest and COMET with other explainers")
                        continue

                    orig = df[df["w"] == 1]["corr"].iat[0]  # The original correlation is the w == 1 row
                    best_corr = df["corr"].max()  # Maximum correlation for this combination
                    improved = df[df["corr"] > orig]
                    for p, w, corr in zip(improved["p"], improved["w"], improved["corr"]):
                        # Store the scalar improvement, not a whole frame minus orig
                        rows.append((tag, lp, explainer, metric, corr - orig, p, w, orig, best_corr))

    print("Settings with improvement:", len(rows))

    # Built from rows directly so numeric columns keep a numeric dtype
    res_df = pd.DataFrame(rows, columns=["tag", "lp", "explainer", "metric", "diff", "p", "w", "orig", "max"])
    res_df = res_df[res_df["orig"] > 0]
    print("Settings with improvement:", len(res_df))

    if output_dir is None:
        output_dir = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "Images_Paper_Auto_Gen")
    os.makedirs(output_dir, exist_ok=True)

    # Create Box Plots and get medians from them
    p_w_per_metric = {}
    for metric_name in ref_free + ref_based:
        fig, axs = plt.subplots(2, len(explainer_names))
        p_w_dict = {}
        for x, explainer in enumerate(explainer_names):
            explainer_df = res_df[(res_df["explainer"] == explainer) & (res_df["metric"] == metric_name)]
            if len(explainer_df) == 0:
                continue
            _, w_dict = explainer_df["w"].plot.box(ax=axs[0, x], return_type="both")
            _, p_dict = explainer_df["p"].plot.box(ax=axs[1, x], return_type="both")
            p_w_dict[explainer] = {"p": round(p_dict['medians'][0].get_ydata()[0], 3),
                                   "w": round(w_dict['medians'][0].get_ydata()[0], 3)}
            axs[0, x].set_ylim(-0.1, 1.1)
            axs[1, x].set_ylim(-34, 34)
            axs[0, x].set_xticklabels([])
            axs[1, x].set_xticklabels([])
            if x != 0:
                axs[0, x].set_yticklabels([])
                axs[1, x].set_yticklabels([])
            else:
                axs[0, x].set_ylabel("w")
                axs[1, x].set_ylabel("p")
            axs[0, x].set_title('Md: ' + str(p_w_dict[explainer]["w"]), fontsize=8)
            axs[1, x].set_title('Md: ' + str(p_w_dict[explainer]["p"]), fontsize=8)
            axs[0, x].annotate(alt_explainer_names[x], xy=(0.5, 1), xytext=(0, 20),
                               xycoords='axes fraction', textcoords='offset points',
                               size='medium', ha='center', va='baseline')

        fig.subplots_adjust(hspace=0.3)
        plt.tight_layout()
        # One file per metric; a single fixed file name would only keep the last metric's plot
        plt.savefig(os.path.join(output_dir, "4_best_p_and_w_" + metric_name + ".pdf"))
        plt.show()
        plt.close(fig)  # Avoid accumulating one open figure per metric

        print("Best p and w dict (medians):", p_w_dict)
        p_w_per_metric[metric_name] = p_w_dict

    return p_w_per_metric

def _format_score_cell(scores):
    """Format one LaTeX cell ``$orig$/$new_1$/...`` from a list of rounded correlations.

    ``scores[0]`` is the original correlation; the remaining entries are the correlations after
    applying the fixed p and w for each explainer. The maximum is bolded (on a tie with the original,
    only the original is bolded), and every entry that beats the original is wrapped in ``\\tablegreen``.
    """
    max_score_indices = np.argwhere(np.asarray(scores) == np.amax(scores)).flatten().tolist()
    cell = ""
    for i, score in enumerate(scores):
        val = str(score)
        if i in max_score_indices and (i == 0 or 0 not in max_score_indices):
            val = "\\textbf{" + val + "}"
        if i != 0 and score > scores[0]:
            val = "\\tablegreen{" + val + "}"
        val = "$" + val + "$"
        if i != 0:
            val = "/" + val
        cell += val
    return cell


def _print_latex_tables(fix_df, config_dict):
    """Print one LaTeX table per dataset tag.

    Each table has one row per lp and one column per metric, plus an AVG row. A cell reads
    ``orig/new_1/...`` and shows ``-`` when no result exists for the lp/metric combination.
    """
    alt_names = {"BERTSCORE_REF_FREE_XNLI": "XBERTScore",
                 "XLMRSBERT": "XLMR-SBERT",
                 "Transquest": "TRANSQUEST",
                 "COMET": "COMET",
                 "BERTSCORE_REF_BASED_ROBERTA_DEFAULT": "BERTScore",
                 "SentenceBLEU_REF_BASED": "SentenceBLEU"}

    for tag, config in config_dict.items():
        metrics = config["metrics"]
        print("\\begin{table*}\n\\centering\\small\n\\begin{tabular}{l|" + "c" * len(metrics) + "}\\toprule")
        print("LP & " + " & ".join(["\\textbf{" + alt_names[m] + "}" for m in metrics]))
        print("\\\\\\midrule")
        scores_per_metric = {metric: [] for metric in metrics}

        for lp in config["lps"]:
            cells = []
            for metric in metrics:
                scores = []
                for explainer in config["explainers"]:
                    selection = fix_df[(fix_df["tag"] == tag) & (fix_df["metric"] == metric)
                                       & (fix_df["lp"] == lp) & (fix_df["explainer"] == explainer)]
                    if len(selection) == 0:
                        continue
                    if not scores:
                        # The original correlation is shared by all explainers; store it once, first
                        scores.append(round(selection["orig"].tolist()[0], 3))
                    scores.append(round(selection["new_correlation"].tolist()[0], 3))
                if not scores:
                    # No result for this combination; np.amax([]) would raise
                    cells.append("-")
                    continue
                cells.append(_format_score_cell(scores))
                scores_per_metric[metric].append(scores)
            print(lp + " & " + " & ".join(cells) + " \\\\")

        avg_cells = []
        for metric_scores in scores_per_metric.values():
            if not metric_scores:
                avg_cells.append("-")
                continue
            # Column-wise mean over all lps (original first, then each explainer)
            averages = [round(mean(column), 3) for column in zip(*metric_scores)]
            avg_cells.append(_format_score_cell(averages))
        print("AVG & " + " & ".join(avg_cells) + "\\\\")

        print("\\bottomrule\n\\end{tabular}\\end{table*}")


def _plot_regression_weights(fix_df, output_path):
    """Regress the correlation difference on the one-hot encoded settings and plot the extreme weights.

    Fits a Ridge regression, prints the 3 largest and 3 smallest coefficients and the intercept, and
    saves them as a horizontal bar plot to ``output_path``.
    """
    fix_df_dummy = pd.get_dummies(data=fix_df, columns=["explainer", "metric", "lp", "tag"])

    cols = [c for c in fix_df_dummy.columns if
            c not in ["Unnamed: 0", "corr", "diff", "new_correlation", "orig", "improvement_percent"]]
    x = fix_df_dummy[cols]
    y = np.array(fix_df_dummy['diff'])

    regr = linear_model.Ridge()
    regr.fit(x, y)

    # Only display 3 best and worst
    tpls = sorted(zip(cols, regr.coef_.tolist()), key=lambda item: item[1], reverse=True)
    tpls = tpls[:3] + tpls[-3:]
    print("Ordered Regression Weights:", tpls)
    print(regr.intercept_)

    rename_dict = {"lp_en-cs": "en-cs",
                   "explainer_LimeExplainer": "LIME",
                   "metric_BERTSCORE_REF_FREE_XNLI": "XBERTScore",
                   "metric_XMoverScore_No_Mapping": "XMoverScore",
                   "explainer_ErasureExplainer": "Erasure",
                   "explainer_InputMarginalizationExplainer": "IM",
                   "lp_en-yo": "en-yo",
                   "lp_de-zh": "de-zh",
                   "lp_tr-en": "tr-en",
                   "tag_wmt_22_expl_train": "MLQE-PE",
                   "lp_si-en": "si-en",
                   "lp_en-zh": "en-zh",
                   "lp_en-de": "en-de",
                   "lp_de-en": "de-en",
                   "lp_zh-en": "zh-en",
                   "lp_ro-en": "ro-en",
                   "lp_km-en": "km-en",
                   "metric_BERTSCORE_REF_BASED_ROBERTA_DEFAULT": "BERTScore",
                   "tag_mqm21": "mqm21",
                   "tag_wmt17": "wmt17",
                   "metric_XLMRSBERT": "XLMR-SBERT",
                   "metric_Transquest": "Transquest",
                   "metric_COMET": "COMET",
                   "lp_en-mr": "en-mr",
                   "lp_ru-de": "ru-de",
                   "explainer_ShapExplainer": "SHAP"}

    labels = []
    for col, _ in tpls:
        label = rename_dict.get(col)
        if label is None:
            # Unlisted regressor (e.g. lp_ps-en): fall back to the name without its one-hot prefix
            label = col
            for prefix in ("explainer_", "metric_", "lp_", "tag_"):
                if col.startswith(prefix):
                    label = col[len(prefix):]
                    break
        labels.append(label)

    weights = [weight for _, weight in tpls]
    y_pos = np.arange(len(weights))
    colors = np.array(["b" if w >= 0 else "r" for w in weights])
    sns.set(rc={'figure.figsize': (3.5, 1.5)})

    plt.barh(y_pos, weights, color=colors)
    plt.yticks(y_pos, labels)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()


def evaluate_p_w_best_fix_2(p_w_dict, w_threshold=0.01, p_threshold=0.01, only_hyp=False, wmt22=False, mqm=False,
                            output_dir=None):
    """Evaluate BMX with one fixed (p, w) per metric and explainer on the evaluation datasets.

    For each configured dataset/lp/explainer/metric, load the result table. Then compare the
    correlation at the grid point closest to ``p_w_dict[metric][explainer]`` with the original
    correlation (the row with ``w == 1``). The function prints:

    * the per-setting results;
    * LaTeX tables (original/new correlations per lp and metric, plus an average row);
    * the three largest and three smallest weights of a Ridge regression of the difference on the
      one-hot encoded settings. These weights are also plotted.

    :param p_w_dict: ``{metric: {explainer: {"p": ..., "w": ...}}}``, e.g. from ``find_best_p_and_w``.
    :param w_threshold: Tolerance when matching the snapped w value against the ``w`` column.
    :param p_threshold: Tolerance when matching the snapped p value against the ``p`` column.
    :param only_hyp: Read the ``__only_hyp`` variants of the result files.
    :param wmt22: Evaluate on the WMT22 test sets instead of WMT17/Eval4NLP21/MLQE-PE.
    :param mqm: Evaluate LIME on MQM21 instead; takes precedence over ``wmt22``.
    :param output_dir: Directory for ``12_regressors.pdf``. Defaults to
        ``ROOT_DIR/xaiMetrics/outputs/Images_Paper_Auto_Gen`` and is created if missing.
    """
    print(p_w_dict)
    sns.set_theme()
    explainer_names = ["ErasureExplainer", "LimeExplainer", "ShapExplainer"]
    ref_free = ["BERTSCORE_REF_FREE_XNLI", "XLMRSBERT", "Transquest", "COMET"]
    ref_based = ["BERTSCORE_REF_BASED_ROBERTA_DEFAULT", "SentenceBLEU_REF_BASED"]

    config_dict = {
        "wmt17": {
            "lps": ["cs-en", "de-en", "fi-en", "lv-en", "ru-en", "tr-en", "zh-en"],
            "explainers": explainer_names,
            "metrics": ref_free + ref_based
        },
        "eval4_nlp_21_test": {
            "lps": ["ro-en", "et-en", "ru-de", "de-zh"],
            "explainers": explainer_names,
            "metrics": ref_free
        },
        "wmt_22_expl_train": {
            "lps": ["ro-en", "et-en", "ru-en", "si-en", "ne-en", "en-zh", "en-de"],
            "explainers": explainer_names,
            "metrics": ref_free
        },
    }

    if wmt22:
        config_dict = {"wmt_22_test_sets-spearman-seg": {
            "lps": ["en-cs", "en-ja", "en-mr", "en-yo", "km-en", "ps-en"],
            "explainers": explainer_names,
            "metrics": ref_free
        }}

    # The MQM analysis overrides any other selection
    if mqm:
        config_dict = {"mqm21": {
            "lps": ["en-de", "zh-en"],
            "explainers": ["LimeExplainer"],
            "metrics": ref_free
        }}

    # Grids the fixed p and w values are snapped to before matching the result tables
    p_grid = np.arange(-30, 30, 0.1)
    w_grid = [0, 0.2, 0.4, 0.6, 0.8, 1]

    # ------------------- Get dataframe with fixed p and w selection ------------------------
    rows = []
    max_non_fix = []  # (best diff, best corr, orig) over all settings, only for orig > 0
    max_perc_non_fix = []  # (corr / orig * 100, orig, corr) maximised over all settings

    for tag, config in config_dict.items():
        for lp in config["lps"]:
            for explainer in config["explainers"]:
                for metric in config["metrics"]:
                    # Load df with resulting correlations for w and p values
                    name_parts = [tag, explainer, metric, lp] + (["only_hyp"] if only_hyp else [])
                    filepath = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "experiment_results",
                                            "__".join(name_parts) + ".tsv")
                    try:
                        df = pd.read_csv(filepath, delimiter='\t')
                    except Exception as e:
                        print(str(e) + "This is expected for Transquest and COMET with other explainers")
                        continue

                    try:
                        fix_p = p_w_dict[metric][explainer]["p"]
                        fix_w = p_w_dict[metric][explainer]["w"]
                    except KeyError:
                        print("No fixed p and w for", metric, explainer, "- skipping", tag, lp)
                        continue

                    # find closest grid points to the given medians (first one on ties)
                    p_target = p_grid[np.argmin(np.abs(p_grid - fix_p))]
                    w_target = min(w_grid, key=lambda v: abs(v - fix_w))
                    p_condition = (df["p"] - p_target).abs() < p_threshold
                    w_condition = (df["w"] - w_target).abs() < w_threshold
                    matches = df.loc[p_condition & w_condition, "corr"].tolist()
                    if not matches:
                        print("No result row for p =", p_target, "w =", w_target, "in", filepath)
                        continue

                    selected_correlation = matches[0]
                    orig = df[df["w"] == 1]["corr"].iat[0]  # Determine the original correlation
                    diff = selected_correlation - orig

                    correlations = df["corr"].tolist()
                    if orig > 0:
                        best_corr = max(correlations)
                        max_non_fix.append((best_corr - orig, best_corr, orig))
                    max_perc_non_fix.append(
                        max([((s / orig) * 100, orig, s) if s > 0 and orig > 0 else (0, orig, s)
                             for s in correlations],
                            key=lambda item: item[0]))

                    # Relative improvement is only meaningful for a positive original correlation
                    improvement_percent = (diff / orig) * 100 if selected_correlation >= 0 and orig > 0 else None
                    rows.append((tag, lp, explainer, metric, diff, orig, selected_correlation, improvement_percent))

    if max_non_fix:
        print("Max list: ", max(max_non_fix, key=lambda x: x[0]), max_non_fix)
    if max_perc_non_fix:
        print("Max perc list: ", max(max_perc_non_fix), max_perc_non_fix)

    # Build a pandas df of the result values for each metric, explainer, lp combination
    fix_df = pd.DataFrame(rows, columns=["tag", "lp", "explainer", "metric", "diff", "orig", "new_correlation",
                                         "improvement_percent"])
    print(fix_df)
    if fix_df.empty:
        print("No results found for the selected configuration.")
        return

    # Produce tables for the paper
    _print_latex_tables(fix_df, config_dict)

    # One hot encode and perform regression of most important parameters
    if output_dir is None:
        output_dir = os.path.join(ROOT_DIR, "xaiMetrics", "outputs", "Images_Paper_Auto_Gen")
    os.makedirs(output_dir, exist_ok=True)
    _plot_regression_weights(fix_df, os.path.join(output_dir, "12_regressors.pdf"))


if __name__ == '__main__':
    # explore_all() (overview over all p/w settings) is not run by default; its plot used figsize (3.15, 2.5).

    # Step 1: derive the median p and w per metric/explainer from the settings that improved correlation.
    sns.set(rc={'figure.figsize': (3.15, 3)})
    selected_p_w = find_best_p_and_w()

    # Step 2: evaluate those fixed values on the WMT22 test sets (LaTeX tables and regression plot).
    sns.set(rc={'figure.figsize': (3.15, 2)})
    evaluate_p_w_best_fix_2(selected_p_w, wmt22=True, mqm=False)
