import os
import math
import numpy as np
import pandas as pd
import torch
from itertools import combinations
import matplotlib.pyplot as plt

# Results of every (gtype, results file, test split) combination, keyed by
# "<gtype> <fname> <test_split>" and holding [mean series, error series].
# The summary plot at the end of the script reads from this dict.
fname2results = {}

# "pre"/"post" plot one of the two stored states; "diff" plots post - pre.
gtypes = ["diff", "pre", "post"]

# Directory holding the .npy result arrays; plots/CSVs are written to
# "<data_dir><gtype>/<file stem>/". Directories used in earlier runs:
# "downstream/logic/output/implicit_rule/" and
# "downstream/logic/output/statement_start/".
data_dir = "downstream/logic/output/rank/statement/"

# Result arrays to process (earlier runs used "res_256*.npy" variants).
results_fnames = ["test.npy", "test_unused.npy", "test_wordnet.npy"]

rule_types = ["statement", "property", "implicit_rule"]
# Only models trained on statements are analysed here. (Earlier code
# enumerated subsets of rule_types with itertools.combinations.)
training_sets = [["statement"]]

# Leading extents read from each results array, in its axis order
# (checkpoint, repeat, state, test rule). State 0 is "pre", 1 is "post".
# The repeat axis is averaged over -- presumably independent seeds; confirm.
NUM_CHECKPOINTS = 20
NUM_REPEATS = 5
NUM_STATES = 2
# Error bars are the across-repeat std multiplied by this factor.
# NOTE(review): the meaning of 0.33 is not stated here -- confirm.
ERROR_SCALE = 0.33

# One column per (training set, test rule). The separator is a literal
# backslash surrounded by spaces, written escaped ("\\") so it is not an
# invalid escape sequence.
headers = [
    "train:" + ":".join(train_set) + " \\ test:" + rule
    for train_set in training_sets
    for rule in rule_types
]
test_splits = ["test:" + rule for rule in rule_types]

# Checkpoint step of each row: 0, 49999, 99999, ..., 949999.
iterations = [0] + [49999 + 50000 * step for step in range(NUM_CHECKPOINTS - 1)]

for gtype in gtypes:
    # Last table written for each results file (kept for interactive use).
    dfs = {}

    for fname in results_fnames:
        out_dir = data_dir + fname  # path of the .npy input file
        fin_dir = f"{data_dir}{gtype}/{os.path.splitext(os.path.basename(out_dir))[0]}"
        os.makedirs(fin_dir, exist_ok=True)

        raw = np.load(out_dir)
        needed = (NUM_CHECKPOINTS, NUM_REPEATS, NUM_STATES, len(rule_types))
        if raw.ndim != len(needed) or any(
            have < need for have, need in zip(raw.shape, needed)
        ):
            raise ValueError(
                f"{out_dir}: expected a 4-D array of at least shape {needed}, "
                f"got {raw.shape}"
            )

        # raw[checkpoint, repeat, state, test_rule]
        #   -> res[checkpoint, test_rule, state, repeat],
        # taking 1/x of every entry (the entries are presumably ranks, making
        # these reciprocal ranks -- the summary plot labels them MRR).
        window = raw[:NUM_CHECKPOINTS, :NUM_REPEATS, :NUM_STATES, :len(rule_types)]
        res = (1.0 / window.transpose(0, 3, 2, 1)).astype(np.float64)

        pre = res[:, :, 0, :]
        post = res[:, :, 1, :]
        if gtype == "pre":
            res = pre
        elif gtype == "post":
            res = post
        else:
            res = post - pre

        # Aggregate over repeats: rows are checkpoints, columns are headers.
        res_error = res.std(axis=-1) * ERROR_SCALE
        res = res.mean(axis=-1)
        df = pd.DataFrame(res, columns=headers)
        df_error = pd.DataFrame(res_error, columns=headers)

        for test_split in test_splits:
            split_name = test_split.replace("test:", "")

            # Columns evaluated on this split, prefixed by the checkpoint steps.
            df2 = df.filter(like=test_split).copy()
            df2_error = df_error.filter(like=test_split).copy()
            df2.insert(0, "Iterations", iterations)
            df2_error.insert(0, "Iterations", iterations)

            # Column 0 is "Iterations"; column 1 is the first training set.
            mean_curve = df2[df2.columns[1]]
            err_curve = df2_error[df2_error.columns[1]]

            fig, ax = plt.subplots()
            ax.errorbar(df2["Iterations"], mean_curve, yerr=err_curve)
            fig.savefig(f"{fin_dir}/nll_{split_name}.png", dpi=300)
            plt.close(fig)

            df2.to_csv(f"{fin_dir}/nll_{split_name}.csv")

            dfs[fname] = df2
            fname2results[gtype + " " + fname + " " + test_split] = [
                mean_curve,
                err_curve,
            ]


# Convert checkpoint steps into compact tick labels ("0K", "50K", ...).
iterations = [f"{(x + 1) // 1000}K" for x in iterations]

# Global style for the summary figure. Set it before the figure is created
# so that "figure.autolayout" actually applies to this figure.
plt.rcParams.update({"font.size": 18})
plt.rcParams.update({"figure.autolayout": True})

plt.figure()

# (results key, legend label, line style) for each plotted curve.
for key, label, style in (
    ("diff test.npy test:implicit_rule", "Class-relation",
     {"color": "tab:orange"}),
    ("diff test_wordnet.npy test:implicit_rule", "Class-relation Control",
     {"color": "tab:cyan", "ls": "--"}),
):
    a, b = fname2results[key]  # mean curve and its error band
    plt.errorbar(iterations, a, yerr=b, label=label, **style)

plt.legend()
plt.ylabel("Change in MRR")
plt.xlabel("Training Iteration")

# The x values are strings, so matplotlib places them at positions 0..N-1;
# label only every 4th one to avoid overlapping ticks.
plt.xticks(np.arange(len(iterations))[::4], iterations[::4])
# Lay out only after the tick labels are final so nothing is clipped.
plt.tight_layout()

plt.savefig(f'{data_dir}custom.png', dpi=300)
plt.savefig(f'{data_dir}relation_mrr.pdf', dpi=300)

print(f'{data_dir}custom.png')

plt.close()