import numpy as np
import pandas as pd

from project_root import join_with_root

# Outer LaTeX table skeleton. Literal LaTeX braces are doubled because the template
# is filled with str.format; {CONTENT} receives the concatenated per-task blocks.
BASE_TABLE = '''\\begin{{table*}}[htb]
    \\centering
    \\begin{{tabular}}{{|l|l|l|l|l|l|}}
        \\hline
        \\textbf{{Model}} & \\textbf{{Prompt}} & \\textbf{{KD}} & \\textbf{{PE}} & \\textbf{{SP}} & \\textbf{{ACC}}\\\\
        {CONTENT}\\hline
    \\end{{tabular}}
    \\caption{{Best performing prompts of the phase 2 evaluation on the Eval4NLP dev set. We present the \\textbf{{K}}en\\textbf{{D}}all, \\textbf{{SP}}earman and \\textbf{{PE}}arson, as well as the tie calibrated pair-wise \\textbf{{ACC}}uracy.}}
    \\label{{tab:bestPhase2}}
\\end{{table*}}'''

# Per-task section: a header row with the task name, followed by the score rows.
# The \textbf braces are spelled as |||...::: so str.format does not interpret them;
# callers replace those markers with "{" and "}" after formatting.
TASK_BLOCK = '''\\hline
\\textbf|||{TASKNAME}::: & & & & & \\\\
{SCORES}'''

def add_0(string, decimals=3):
    """Pad a numeric string with trailing zeros to exactly ``decimals`` fractional digits.

    Strings that already have more fractional digits are left as they are.
    Integers gain a decimal point ("1" -> "1.000"). Signs and multi-digit
    integer parts are handled correctly ("-0.12" -> "-0.120", "10.5" -> "10.500").
    Non-numeric text without a decimal point (e.g. "nan", "inf", "1e-05") is
    returned unchanged.

    :param string: value to format; converted with ``str()`` first.
    :param decimals: minimum number of digits after the decimal point.
    :return: the padded string; an empty input is returned as-is.
    """
    string = str(string)
    if not string:
        return string
    head, dot, tail = string.partition(".")
    # Without a decimal point, only pad plain (optionally negative) integers.
    if not dot and not head.lstrip("-").isdigit():
        return string
    return head + "." + tail.ljust(decimals, "0")

def shortname(model):
    """Map a model identifier to the small-caps LaTeX label used in the tables.

    Known models are matched by substring. Any other model falls back to its own
    name, with underscores escaped for LaTeX, rather than ``None``. A ``None``
    return would break the string concatenation and joins done by the table code.

    :param model: model identifier (substring-matched).
    :return: a ``\\textsc{...}`` LaTeX string.
    """
    if "Platypus2-70B" in model or "Platypus70B" in model:
        return "\\textsc{Platypus2-70B}"
    if "OpenOrca" in model:
        return "\\textsc{OrcaPlt-13B}"
    if "Nous" in model:
        return "\\textsc{NousHermes-13B}"
    return "\\textsc{" + str(model).replace("_", "\\_") + "}"

def _best_prompt_rows(group):
    """Build one row per model holding that model's highest-Kendall prompt configuration."""
    rows = []
    for _, model_group in group.groupby(["model"]):
        best = dict(model_group.loc[model_group["kendall"].idxmax()])

        spearman = best["spearman"]
        # Some result files store spearman wrapped in a list; use its first entry.
        if isinstance(spearman, list):
            spearman = spearman[0]

        prompt = (
            best["prompt"].replace("Zero-Shot", "ZS").replace("One-Shot", "OS").replace("emotion", "EM")
            + ", " + best["task_description"]
            + ", " + best["regex"]["name"]
        )
        if best["mode"] == "FS":
            prompt = prompt.replace("ZS", "OS")

        rows.append([
            shortname(best["name"][1]),
            prompt,
            round(best["kendall"], 3),
            round(best["pearson"], 3),
            round(spearman, 3),
            round(best["kendall_tie_corrected"], 3),
        ])
    return rows


def _baseline_rows(baseline_df, task, dataset):
    """Build rows for the baselines of ``task`` on ``dataset``.

    DSBA is kept only for Platypus2-70B and Gemba only for OpenOrca. All other
    baseline approaches are kept for every model.
    """
    selected = baseline_df[(baseline_df["task"] == task) & (baseline_df["dataset"] == dataset)]
    rows = []
    for _, row in selected.iterrows():
        approach, model = row["approach"], row["model"]
        is_dsba = "DSBA" in approach
        is_gemba = "Gemba" in approach
        if is_dsba and "Platypus2-70B" not in model:
            continue
        if is_gemba and "OpenOrca" not in model:
            continue
        model_block = "Model:" + shortname(model) if (is_dsba or is_gemba) else ""
        rows.append([
            "B:" + approach[:10],
            model_block,
            round(row["kendall"], 3),
            round(row["pearson"], 3),
            round(row["spearman"], 3),
            round(row["kendall_tie_corrected"], 3),
        ])
    return rows


def _format_score_columns(rows, n_scores=4):
    """In place, format the last ``n_scores`` columns as strings and bold the top two per column.

    The top two cells are chosen with a stable ranking, so with ties the earlier
    row wins. Two distinct rows are always bolded, or one if only one row exists.
    """
    for col in range(-n_scores, 0):
        values = [row[col] for row in rows]
        ranked = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
        best = set(ranked[:2])
        for i, row in enumerate(rows):
            cell = add_0(str(row[col]))
            row[col] = "\\textbf{" + cell + "}" if i in best else cell


def print_table(prex_df, baseline_df, dataset):
    """Render, print and return the LaTeX table of best prompts and baselines per task.

    :param prex_df: prompt-evaluation results. Must have a unique index and the
        columns task, model, kendall, pearson, spearman, kendall_tie_corrected,
        prompt, task_description, regex, name and mode.
    :param baseline_df: baseline results. Must have the columns task, dataset,
        approach, model, kendall, pearson, spearman and kendall_tie_corrected.
    :param dataset: dataset split used to filter ``baseline_df`` (e.g. "train").
    :return: the LaTeX table string (it is also printed).
    """
    order = ["en_de", "en_es", "en_zh", "zh_en", "summarization"]
    content_blocks = []
    for key, group in prex_df.groupby(["task"]):
        # pandas >= 2 yields 1-tuples when grouping by a list; older versions yield scalars.
        task = key[0] if isinstance(key, tuple) else key

        rows = _best_prompt_rows(group) + _baseline_rows(baseline_df, task, dataset)
        _format_score_columns(rows)

        scores = "".join(" & ".join(str(cell) for cell in row) + "\\\\\n" for row in rows)
        block = TASK_BLOCK.format(TASKNAME=task.replace("_", "\\_"), SCORES=scores)
        # Restore the \textbf braces that TASK_BLOCK encodes as |||...::: to survive str.format.
        block = block.replace("|||", "{").replace(":::", "}")
        content_blocks.append((task, block))

    # Known tasks follow the fixed order; any other task is appended at the end.
    content_blocks.sort(key=lambda item: order.index(item[0]) if item[0] in order else len(order))
    table = BASE_TABLE.format(CONTENT="".join(block for _, block in content_blocks))
    print(table)

    return table


if __name__ == '__main__':
    separator = "\n\n\n----------------------------------\n\n\n"

    def load_predictions(*sources):
        """Read each (relative_path, mode) JSON source, tag its rows with ``mode`` and stack them.

        The stacked frame goes through ``reset_index()`` so row labels are unique
        across the concatenated files.
        """
        frames = []
        for rel_path, mode in sources:
            frame = pd.read_json(join_with_root(rel_path))
            frame["mode"] = mode
            frames.append(frame)
        return pd.concat(frames).reset_index()

    baseline_df = pd.read_json(join_with_root("outputs/baseline_correlations/baseline_correlations.json"))

    train_df = load_predictions(
        ("outputs/evaluation/corr_zero_shot_train_avg_recombined.json", "ZS"),
        ("outputs/evaluation/corr_few_shot_train_avg.json", "FS"),
    )
    print_table(train_df, baseline_df, "train")

    print(separator)

    # Only zero-shot results are loaded for the dev split.
    dev_df = load_predictions(("outputs/evaluation/corr_zero_shot_dev_avg.json", "ZS"))
    print_table(dev_df, baseline_df, "dev")

    print(separator)

    test_df = load_predictions(
        ("outputs/evaluation/corr_zero_shot_test_avg.json", "ZS"),
        ("outputs/evaluation/corr_few_shot_test_avg.json", "FS"),
    )
    # Keep only the first row for each distinct Kendall value.
    test_df = test_df.drop_duplicates(subset=['kendall'])

    print_table(test_df, baseline_df, "test")