
import ast

import matplotlib.pyplot as plt
import numpy as np


# Global matplotlib style, applied once at import time so that every figure
# created afterwards in this module shares the same look.
plt.rcParams.update({
    # --- Text rendering and fonts ---
    # Text is typeset through LaTeX (a working LaTeX installation is needed).
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.size": 12,
    # --- Figure and axes ---
    # Default figure size in inches (width, height).
    "figure.figsize": (6, 4),
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    # Tick marks straddle the axis line instead of pointing only in or out.
    "xtick.direction": "inout",
    "ytick.direction": "inout",
    # --- Lines and markers ---
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    # --- Legend defaults (individual legend() calls may still override them) ---
    "legend.frameon": False,
    "legend.loc": "upper right",
    "legend.handlelength": 1.5,
    "legend.handletextpad": 0.5,
    "legend.labelspacing": 0.5,
    "legend.columnspacing": 1.5,
    "legend.borderaxespad": 0.5,
    "legend.borderpad": 0.5,
    # --- Saving figures ---
    "savefig.dpi": 300,
    # Crop saved figures to their content, leaving only a minimal margin.
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.01,
    "savefig.transparent": False,
    # --- PDF backend ---
    "pdf.fonttype": 42,
    "pdf.compression": 9,
    # --- PGF backend ---
    # NOTE(review): these presumably only matter when a figure is saved
    # through the pgf backend -- confirm whether they are still needed.
    "pgf.texsystem": "pdflatex",
    "pgf.preamble": r"\usepackage{amsmath}",
    "pgf.rcfonts": False,

})



def str_to_value(s: str):
    """Convert a textual value from a results file into a Python object.

    Conversion rules, tried in order:

    1. A string ending in ``%`` is read as a percentage and returned as a
       fraction (``"50%"`` -> ``0.5``).
    2. A Python literal (int, float, bool, ``None``, quoted string, list,
       tuple, dict, ...) is parsed with ``ast.literal_eval``.
    3. Anything ``float()`` understands, which covers ``nan`` and ``inf``.
    4. Otherwise the string is returned unchanged (e.g. ``"adam"``).

    Args:
        s: The raw text of the value.

    Returns:
        The converted value, or ``s`` itself when it is not a literal.

    Raises:
        ValueError: If ``s`` ends in ``%`` but the part before it is not a
            number.
    """
    if s.endswith('%'):
        return float(s[:-1]) / 100

    # SECURITY: this used to be ``eval(s)``, which executes arbitrary code
    # found in the input file. ``literal_eval`` only accepts literals;
    # arithmetic expressions such as ``1/2`` are therefore no longer
    # evaluated and come back as plain strings.
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError):
        pass

    # ``nan`` / ``inf`` are not Python literals but are valid floats.
    try:
        return float(s)
    except ValueError:
        return s


def load_runs(*files):
    """Parse result text files into a list of run records.

    The parser is line oriented and recognises three kinds of lines:

    * a line starting with ``====`` opens a new run;
    * a line starting with two spaces and containing ``:`` is a
      ``key: value`` configuration entry of the current run;
    * a line starting with ``Epoch `` carries comma-separated
      ``name: value`` pairs after its first field; each value is appended
      to the list stored under that metric name.

    Empty lines and lines matching none of the above are ignored. Values
    are converted with ``str_to_value``.

    The current run is not reset between files, so data lines at the top
    of a later file are attached to the last run of the previous file.

    Args:
        *files: Paths of the result files, read in the given order.

    Returns:
        A list of dicts ``{'config': {key: value}, 'metrics': {name: [values]}}``
        in order of appearance.

    Raises:
        ValueError: If a configuration or epoch line appears before any
            ``====`` header, or an epoch line has a field without a partner.
        AssertionError: If a configuration key is repeated within one run.
    """
    runs = []
    run = None  # run currently being filled; None until the first header
    for file in files:
        with open(file, "r") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip('\n')
                if not line:
                    continue

                if line.startswith("===="):
                    run = {'config': {}, 'metrics': {}}
                    runs.append(run)
                    continue

                is_config = line.startswith("  ") and ':' in line
                is_epoch = not is_config and line.startswith("Epoch ")
                if not (is_config or is_epoch):
                    continue

                if run is None:
                    # Without this check the code below would fail with an
                    # UnboundLocalError that says nothing about the input.
                    raise ValueError(
                        f"{file}:{lineno}: data line found before any "
                        f"'====' run header: {line!r}"
                    )

                if is_config:
                    key, value = line.split(':', 1)
                    key = key.strip()
                    if key in run['config']:
                        # Raised explicitly (not via ``assert``) so the
                        # check is kept under ``python -O``; the exception
                        # type is the one an ``assert`` would have raised.
                        raise AssertionError(
                            f"{file}:{lineno}: duplicate config key {key!r}"
                        )
                    run['config'][key] = str_to_value(value.strip())
                else:
                    # Turn "Epoch 3: a: 1, b: 2" into alternating
                    # name/value fields; the leading "Epoch 3" is dropped.
                    fields = line.replace(':', ',').split(',')[1:]
                    if len(fields) % 2:
                        raise ValueError(
                            f"{file}:{lineno}: epoch line has an unpaired "
                            f"field: {line!r}"
                        )
                    for key, value in zip(fields[::2], fields[1::2]):
                        history = run['metrics'].setdefault(key.strip(), [])
                        history.append(str_to_value(value.strip()))

    return runs



if __name__ == "__main__":

    # Every run found in the two result files, in file order.
    all_runs = load_runs(
        "toy_task/results/results1.txt",
        "toy_task/results/results2.txt",
    )

    # Alternative palette (disabled):
    # palette = {
    #     "1 layer": "#0063e6",
    #     "2 layers": "#00d2e6",
    #     "3 layers": "#00e68a",
    #     "4 layers": "#6fe600",
    #     "5 layers": "#dee600",#"#00cf3b",
    #     "1 layer (ours)": "#e01502",
    # }

    # One fixed colour per legend entry; a label missing here is a KeyError.
    palette = {
        "1 layer": "#d9d902",
        "2 layers": "#00d466",
        "3 layers": "#0dbaff",
        "4 layers": "#0d62ff",
        "5 layers": "#b30bd9",
        "1 layer (ours)": "#e01502",
    }

    # Bucket the runs by legend label. Insertion order is kept, so curves are
    # drawn (and listed in the legend) in order of first appearance.
    groups = {}
    for entry in all_runs:
        cfg = entry['config']
        name = f"{cfg['layers']} layer{'s' if cfg['layers'] > 1 else ''}"
        if cfg.get('custom_attention'):
            name += " (ours)"
        groups.setdefault(name, []).append(entry)

    fig, ax = plt.subplots()

    for name, members in groups.items():
        shade = palette[name]
        # Runs with custom attention get a solid line, all others a dashed one.
        ours = bool(members[0]['config'].get('custom_attention'))
        stroke = '--' if not ours else '-'

        # Per-epoch mean plus min/max envelope over the runs of this group.
        # The envelope starts from the bounds 100 and 0, so it is clipped to
        # at most 100 from below and at least 0 from above.
        total, lowest, highest = 0, 100, 0
        for member in members:
            # Accuracies are scaled by 100 for plotting.
            curve = np.array(member['metrics']['accuracy']) * 100
            total = total + curve
            lowest = np.minimum(lowest, curve)
            highest = np.maximum(highest, curve)
        average = total / len(members)

        # Epochs are numbered from 1 on the x axis.
        xs = list(range(1, 1 + len(average)))
        ax.plot(xs, average, label=name, color=shade, linestyle=stroke)
        ax.fill_between(xs, lowest, highest, color=shade, alpha=0.15, linewidth=0)

    ax.set_xlabel("Epochs")
    ax.set_ylabel("Accuracy")
    ax.legend(loc="lower right", ncol=2)
    ax.grid(alpha=0.1)

    ax.set_xlim(1, 30)
    ax.set_ylim(0, 103)

    fig.tight_layout()
    fig.savefig('toy_task/results/plot.pdf')