import numpy as np
import torch
from collections import defaultdict, Counter
import pdb
import json
import os
import argparse
from config import *

# Experiment grid. Each (model, task) cell is expected to contain
# N_exps = N_seeds * N_formats accuracy values; `compact` reports any cell
# whose count differs. (Presumably one run per seed x prompt format — confirm.)
N_seeds = 5
N_formats = 3
N_exps = N_seeds * N_formats
# (model family, size in billions of parameters) pairs; see `model_map`.
MODEL_LIST = [
    ('llama', 7),
    ('llama', 13),
    ('mistral', 7),
]
# Task names, taken from the keys of `Tasks` (provided by `config`).
TASK_LIST = [task_name for task_name in Tasks.keys()]

def save_results(fn, RES):
    """Serialise ``RES`` as JSON to the file path ``fn``, overwriting any existing file.

    Missing parent directories are created first, so writing into an output
    folder such as ``Summary/`` does not fail when that folder is absent.

    Args:
        fn: destination file path.
        RES: any JSON-serialisable object. Dict subclasses such as
            ``defaultdict`` are written like plain dicts.
    """
    parent = os.path.dirname(fn)
    if parent:
        # os.makedirs('') raises, so only create a directory when fn names one.
        os.makedirs(parent, exist_ok=True)
    # Explicit encoding instead of relying on the platform default (PEP 597).
    with open(fn, "w", encoding="utf-8") as f:
        json.dump(RES, f)


def model_map(model, size):
    """Translate a (model family, size) pair into a Hugging Face checkpoint id.

    ``'llama'`` maps to the base Llama-2 checkpoint of the requested size,
    e.g. ``('llama', 13)`` -> ``'meta-llama/Llama-2-13b-hf'``. Every other
    family name maps to the single Mistral-7B-Instruct-v0.1 checkpoint, and
    ``size`` is ignored on that path.
    """
    if model != 'llama':
        return 'mistralai/Mistral-7B-Instruct-v0.1'
    return f'meta-llama/Llama-2-{size}b-hf'


def compact(RES, fn, out_dir='Summary'):
    """Summarise per-task accuracies as mean ± std and write both views to JSON.

    Args:
        RES: nested mapping ``{model: {task: [acc, ...]}}``. Each accuracy list
            is expected to hold ``N_exps`` values. A mismatch is printed, not
            raised. Accuracies are presumably fractions in [0, 1], since they
            are scaled by 100 for display (confirm against the producer).
        fn: filename stem. The outputs are ``<out_dir>/<fn>_flatten.json``
            (``RES`` unchanged) and ``<out_dir>/<fn>_mean_std.json`` (the
            summary).
        out_dir: directory to write into. Defaults to ``'Summary'``, the same
            default used by ``load``.

    Returns:
        ``summary[model][task]``: a ``'<mean> ± <std>'`` string, both multiplied
        by 100 and shown to 1 decimal place. The std is NumPy's default
        population std (ddof=0).

        ``summary[model]['Task_Avg']``: the unweighted mean of the per-task
        means, formatted as a percentage (e.g. ``'61.3%'``).
    """
    summary = defaultdict(lambda: defaultdict(str))
    for md, task2accs in RES.items():
        task_avg = []
        for task, accs in task2accs.items():
            # Explicit check instead of an `assert` inside a bare try/except.
            # Asserts are stripped under `python -O`, which would silently
            # disable this report, and the bare except could mask unrelated
            # errors.
            if len(accs) != N_exps:
                print(fn, md, task, len(accs))
            accs = np.array(accs)
            summary[md][task] = f'{100*accs.mean():.1f} ± {100*accs.std():.1f}'
            task_avg.append(accs.mean())
        summary[md]['Task_Avg'] = f'{np.array(task_avg).mean():.1%}'

    save_results(os.path.join(out_dir, f'{fn}_flatten.json'), RES)
    save_results(os.path.join(out_dir, f'{fn}_mean_std.json'), summary)
    return summary


def load(fn, out_dir='Summary'):
    """Read back a result set stored as two JSON files in ``out_dir``.

    The files are read in this order:

    * ``<out_dir>/<fn>_flatten.json``
    * ``<out_dir>/<fn>_mean_std.json``

    Returns:
        ``(RES, summary)``, the decoded contents of those two files.
    """
    loaded = []
    for suffix in ('flatten', 'mean_std'):
        path = os.path.join(out_dir, f'{fn}_{suffix}.json')
        with open(path) as handle:
            loaded.append(json.load(handle))
    RES, summary = loaded
    return RES, summary
