
import matplotlib.pyplot as plt
import numpy as np


# Publication-style Matplotlib defaults, applied in thematic groups.
# Each key is set exactly once, so the resulting rcParams are the same as
# applying all of them in a single update call.

# Text rendering (LaTeX, serif Times) and font sizes.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.size": 12,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
})

# Figure size, tick direction and default line/marker sizes.
plt.rcParams.update({
    "figure.figsize": (6, 4),
    "xtick.direction": "inout",
    "ytick.direction": "inout",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
})

# Legend layout: frameless, with compact spacing.
plt.rcParams.update({
    "legend.frameon": False,
    "legend.loc": "upper right",
    "legend.handlelength": 1.5,
    "legend.handletextpad": 0.5,
    "legend.labelspacing": 0.5,
    "legend.columnspacing": 1.5,
    "legend.borderaxespad": 0.5,
    "legend.borderpad": 0.5,
})

# Output: tight-cropped 300-dpi saves; PDF and PGF backend settings.
plt.rcParams.update({
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.01,
    "savefig.transparent": False,
    "pdf.fonttype": 42,
    "pdf.compression": 9,
    "pgf.texsystem": "pdflatex",
    "pgf.preamble": r"\usepackage{amsmath}",
    "pgf.rcfonts": False,
})


def load_data(*files):
    """Load one or more comma-separated metric files into per-column lists.

    The first line of each file is the header. Every following non-blank
    line is split on commas, and each cell is converted to ``float`` after
    removing double quotes. Empty cells become NaN.

    Header keys are kept exactly as written in the file, including any
    surrounding double quotes. Callers therefore look columns up as e.g.
    ``runs['"train/global_step"']``.

    NOTE: cells are split on every comma; CSV quoting is not honoured. A
    quoted header or value that itself contains a comma will misalign the
    columns.

    Args:
        *files: Paths of the files to read.

    Returns:
        A dict mapping each path to ``{header_key: [float, ...]}``. An empty
        file maps to an empty dict.

    Raises:
        ValueError: If a non-empty cell cannot be parsed as a float.
        IndexError: If a data row has fewer cells than the header.
    """
    data = {}
    for file in files:
        with open(file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        if not lines:
            # No header at all: nothing to parse.
            data[file] = {}
            continue

        header = lines[0].strip().split(",")
        runs = {key: [] for key in header}

        for line in lines[1:]:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. trailing ones) would otherwise hit float('').
                continue
            cells = stripped.split(",")
            for i, key in enumerate(header):
                value = cells[i].replace('"', "").strip()
                # Missing measurements are exported as empty cells; keep the
                # row aligned with the step column by recording NaN.
                runs[key].append(float(value) if value else float("nan"))

        data[file] = runs
    return data



if __name__ == "__main__":
    # Plot the mean eval exact-match of each model variant against the
    # training step. Each curve gets a shaded band between its __MIN and
    # __MAX columns, and the figure is saved as a PDF.

    # Alternative input:
    # file = "results/boxes_dataset_advanced/boxes_advanced_test_loss.csv"
    file = "results/boxes_dataset_advanced/boxes_advanced_test_exact_match.csv"
    runs = load_data(file)[file]

    # Alternative palette kept for reference:
    # colors = {
    #     "2 layers": "#0063e6",
    #     "3 layers": "#00d2e6",
    #     "4 layers": "#6fe600",
    #     "5 layers": "#dee600",#"#00cf3b",
    #     "2 layers (ours)": "#e01502",
    # }
    colors = {
        "2 layers": "#d9d902",
        "3 layers": "#00d466",
        "4 layers": "#0dbaff",
        "5 layers": "#0d62ff",
        "2 layers (ours)": "#e01502",
    }

    # load_data keeps the quotes from the file in its keys, hence '"..."'.
    X = runs['"train/global_step"']
    for header, run in runs.items():
        header = header.replace('"', '')

        # Only the mean columns; their __MIN/__MAX companions are looked up
        # explicitly below.
        if not header.endswith(" - eval/exact_match"):
            continue

        # Assumes run names look like "...L=<n>-..." — TODO confirm.
        n_layers = int(header.split('-')[0].split('L=')[-1])
        is_custom = 'custom' in header
        label = f"{n_layers} layers" + (" (ours)" if is_custom else "")

        color = colors[label]
        linestyle = '-' if is_custom else '--'

        # Scale to percent (presumably fractions in [0, 1]; the y-limit
        # below is 0-103).
        mean = np.array(run) * 100
        mini = np.array(runs[f'"{header}__MIN"']) * 100
        maxi = np.array(runs[f'"{header}__MAX"']) * 100

        plt.plot(X, mean, label=label, color=color, linestyle=linestyle)
        plt.fill_between(X, mini, maxi, color=color, alpha=0.15, linewidth=0)

    plt.xlabel("Steps")
    plt.ylabel("Exact match")
    plt.grid(alpha=0.1)

    # Reorder legend entries by hand. The indices refer to plotting order,
    # which follows the column order of this particular export.
    handles, legend_labels = plt.gca().get_legend_handles_labels()
    order = [4, 2, 1, 0, 3]
    if sorted(order) != list(range(len(handles))):
        # A different export (other runs or a different count) would make
        # the hand-tuned order go out of range or drop entries; fall back
        # to plotting order instead.
        order = list(range(len(handles)))
    plt.legend(
        [handles[idx] for idx in order],
        [legend_labels[idx] for idx in order],
        loc="lower right",
        ncol=1,
    )

    plt.xlim(1, 25000)
    # Slight headroom above 100 so curves at 100% are not clipped by the frame.
    plt.ylim(0, 103)

    plt.tight_layout()
    plt.savefig('results/boxes_advanced_exact_match.pdf')