import os
import re

import matplotlib.pyplot as plt
import numpy as np

# The bundled seaborn style sheets were renamed with a 'seaborn-v0_8-' prefix
# in matplotlib 3.6 and the old names were later removed. Prefer the new name
# when this matplotlib provides it; otherwise fall back to the legacy name so
# older installations keep working exactly as before.
_WHITEGRID_STYLE = 'seaborn-v0_8-whitegrid'
plt.style.use(_WHITEGRID_STYLE if _WHITEGRID_STYLE in plt.style.available else 'seaborn-whitegrid')

# Report, for every (dataset, experiment directory) pair and every run index
# 1..3 of the chosen algorithm, the number of the last epoch found in that
# run's log.txt. Runs whose directory does not exist are skipped.
ALGO = 'IRM'

# Epoch header lines look like "Epoch [<n>]:" at the start of a line; the
# leading '\n' means a header on the very first line of the file is not matched.
EPOCH_RE = re.compile(r'\nEpoch \[(\d+)\]\:')

for dataset, directory in zip(['eurlex', 'eurlex100', 'eurlex500'], ['BATCH_64_SMALL_MORE_GROUPS', 'BATCH_64_GROUPED_LWAN', 'BATCH_64_LWAN']):
    BASE_DIR = f'/home/iliasc/temporal-wilds/logs/{dataset}/{directory}'

    for idx in range(1, 4):
        algo_dir = os.path.join(BASE_DIR, f'{ALGO}_{idx}')
        if not os.path.exists(algo_dir):
            continue

        filename = os.path.join(algo_dir, 'log.txt')
        try:
            with open(filename) as file:
                text = file.read()
        except (OSError, UnicodeDecodeError):
            # Missing or unreadable log: stop scanning this experiment's
            # remaining runs (same control flow as before, but without
            # hiding unrelated errors such as KeyboardInterrupt).
            # NOTE(review): `continue` may be what was intended here - confirm.
            break

        epochs = [m.group(1) for m in EPOCH_RE.finditer(text)]
        if not epochs:
            # No epoch header was logged, so there is nothing to report;
            # stop scanning this experiment's remaining runs.
            break
        print(epochs[-1])

        # Disabled loss/penalty inspection and plotting, kept for reference:
        # print(re.search('\nSd lambda: ([\d\.]+)', text).group(1))
        # losses = [float(m.group(1)) for m in re.finditer('\nobjective: ([\d\.]+)', text)]
        # penalties = [float(m.group(1)) for m in re.finditer('\npenalty: ([\d\.]+)', text)]
        # plt.plot(np.arange(len(losses)), losses, color='blue', label='loss')
        # plt.plot(np.arange(len(penalties)), penalties, color='red', label='penalty')
        # plt.legend()
        # plt.savefig(f'{dataset}_{ALGO}irm_penalties.png')
        # plt.clf()
