
import ast

import matplotlib.pyplot as plt
import numpy as np


# Global matplotlib defaults for every figure produced by this script.
# The settings are grouped by the rcParams namespace they belong to and
# merged into a single update call.
_FONT_SIZE = 12

_TEXT_RC = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.size": _FONT_SIZE,
}

_FIGURE_AND_AXES_RC = {
    "figure.figsize": (6, 4),
    "axes.titlesize": _FONT_SIZE,
    "axes.labelsize": _FONT_SIZE,
    "xtick.labelsize": _FONT_SIZE,
    "ytick.labelsize": _FONT_SIZE,
    "xtick.direction": "inout",
    "ytick.direction": "inout",
}

_LINES_RC = {
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
}

_LEGEND_RC = {
    "legend.fontsize": _FONT_SIZE,
    "legend.frameon": False,
    "legend.loc": "upper right",
    "legend.handlelength": 1.5,
    "legend.handletextpad": 0.5,
    "legend.labelspacing": 0.5,
    "legend.columnspacing": 1.5,
    "legend.borderaxespad": 0.5,
    "legend.borderpad": 0.5,
}

_SAVEFIG_RC = {
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.01,
    "savefig.transparent": False,
}

_PDF_AND_PGF_RC = {
    "pdf.fonttype": 42,
    "pdf.compression": 9,
    "pgf.texsystem": "pdflatex",
    "pgf.preamble": r"\usepackage{amsmath}",
    "pgf.rcfonts": False,
}

plt.rcParams.update({
    **_TEXT_RC,
    **_FIGURE_AND_AXES_RC,
    **_LINES_RC,
    **_LEGEND_RC,
    **_SAVEFIG_RC,
    **_PDF_AND_PGF_RC,
})



def str_to_value(s: str):
    """Convert a textual value from a results log into a Python object.

    Surrounding whitespace is ignored.  A trailing ``%`` marks a percentage
    and is returned as a fraction (``"95%"`` -> ``0.95``).  Any other text is
    parsed as a Python literal (numbers, booleans, ``None``, quoted strings,
    tuples, lists, dicts, sets).  Text that is not a valid literal, such as a
    bare word or an empty string, is returned unchanged as a ``str``.

    Args:
        s: The raw value text.

    Returns:
        The parsed value, or the stripped input string if it is not a literal.

    Raises:
        ValueError: If the text before a trailing ``%`` is not a number.
    """
    s = s.strip()
    if s.endswith('%'):
        return float(s[:-1]) / 100

    # NOTE: this used to call eval(), which executes arbitrary expressions
    # found in the log file.  literal_eval only accepts literals, so nothing
    # read from disk is ever executed.
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal (e.g. "adam"): keep the raw text.
        return s


def load_runs(*files):
    """Parse result log files into a list of run records.

    The files are read in order, line by line.  Recognised lines:

    * ``====...``: starts a new run.
    * two leading spaces and a ``:``: a ``key: value`` config entry.
    * ``Epoch ...``: after the first field, alternating metric names and
      values separated by ``:`` or ``,``; each value is appended to that
      metric's history.
    * ``Model has <N> parameters``: stores ``N`` (thousands separators
      removed) as ``config['n_params']``.

    Every other line is ignored.  The run being filled is not reset between
    files, so recognised lines at the top of a later file attach to the last
    run of the previous file.

    Args:
        *files: Paths of the log files to read.

    Returns:
        A list of ``{'config': dict, 'metrics': dict[str, list]}`` records,
        one per ``====`` header, in the order encountered.

    Raises:
        ValueError: If a recognised data line appears before any ``====``
            header, if a config key is repeated within one run, or if an
            ``Epoch`` line has a metric name without a value.
    """
    runs = []
    run = None  # run currently being filled; set by the latest "====" line

    for file in files:
        with open(file, "r") as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.rstrip("\r\n")
                if not line:
                    continue

                if line.startswith("===="):
                    run = {'config': {}, 'metrics': {}}
                    runs.append(run)
                    continue

                is_config = line.startswith("  ") and ':' in line
                is_epoch = line.startswith("Epoch ")
                is_model = line.startswith("Model has")
                if not (is_config or is_epoch or is_model):
                    continue

                if run is None:
                    raise ValueError(
                        f"{file}:{lineno}: data line found before any "
                        f"'====' run header: {line!r}"
                    )

                if is_config:
                    key, value = line.split(':', 1)
                    key = key.strip()
                    # Explicit check instead of assert, which -O would strip.
                    if key in run['config']:
                        raise ValueError(
                            f"{file}:{lineno}: duplicate config key {key!r}"
                        )
                    run['config'][key] = str_to_value(value.strip())

                elif is_epoch:
                    # Treat ':' and ',' alike; the first field is the
                    # "Epoch N" prefix, the rest alternate name, value.
                    fields = line.replace(':', ',').split(',')[1:]
                    if len(fields) % 2:
                        raise ValueError(
                            f"{file}:{lineno}: metric without a value in "
                            f"epoch line: {line!r}"
                        )
                    for name, value in zip(fields[0::2], fields[1::2]):
                        history = run['metrics'].setdefault(name.strip(), [])
                        history.append(str_to_value(value.strip()))

                else:  # Model has 3,292,672 parameters
                    parts = line.split()
                    run['config']['n_params'] = int(parts[2].replace(',', ''))

    return runs



if __name__ == "__main__":
    # Load the result logs, group the runs by layer count and attention
    # variant, and print one LaTeX table row per group with the mean and
    # standard deviation of the final-epoch loss and accuracy.

    all_runs = load_runs(
        "toy_task/results/results1.txt",
        "toy_task/results/results2.txt",
    )

    # Group key: "<n> layer(s)", plus " (ours)" for custom-attention runs.
    labels = {}
    for run in all_runs:
        config = run['config']
        layers = config['layers']
        label = f"{layers} layer{'s' if layers > 1 else ''}"
        if config.get('custom_attention'):
            label += " (ours)"
        labels.setdefault(label, []).append(run)

    for runs in labels.values():
        # The first run of a group supplies the values shared by the row.
        config = runs[0]['config']
        is_custom = config.get('custom_attention')
        n_layers = config['layers']
        n_params = config['n_params']

        # Only custom-attention rows and the 3-layer baseline row get a name;
        # other baseline rows leave the column blank
        # (presumably so "Standard" appears once in the table; confirm).
        attention = ""
        if is_custom:
            attention = "GPAL (ours)"
        elif n_layers == 3:
            attention = "Standard"

        # Final-epoch values of each run; accuracy is converted to percent.
        losses = [np.array(run['metrics']['loss'][-1]) for run in runs]
        accs = [np.array(run['metrics']['accuracy'][-1]) * 100 for run in runs]

        # Backslashes are escaped explicitly: "\p" and "\%" are invalid
        # escape sequences in a non-raw string.  The printed text is unchanged.
        print(
            f"{attention.ljust(12)} & {n_layers} & {round(n_params / 1e6, 1)}M"
            f" & {np.mean(losses):.2f} $\\pm$ {np.std(losses):.2f}"
            f" & {np.mean(accs):.1f}\\% $\\pm$ {np.std(accs):.2f} \\\\"
        )