import os
import glob
import sys
import json
import numpy as np

import pandas as pd
        

def args_to_key(dict, important_keys=('model_name', 'lr', 'mean_sample', 'do_sample', 'config', 'do_lower', 'checkpoint', 'interp_percent')):
    """Build a canonical string key from the identifying entries of a run's arguments.

    Args:
        dict: Mapping of argument names to values (as loaded from a run's
            args/hparams JSON file). The name shadows the builtin but is kept
            so existing keyword callers keep working. It is not modified.
        important_keys: Names of the entries that identify a configuration.
            Entries absent from ``dict`` are skipped.

    Returns:
        A JSON string of the selected entries with keys sorted, so two
        argument mappings that agree on the important keys yield the same key.
    """
    # Work on a fresh mapping so the caller's arguments are left untouched.
    selected = {k: dict[k] for k in important_keys if k in dict}

    if 'lr' in selected:
        # int() truncates, so e.g. 200.0 and 200 produce the same key.
        # NOTE(review): a fractional lr loses its fractional part here.
        selected['lr'] = int(selected['lr'])

    if 'checkpoint' in selected:
        # Keep only the name of the directory that contains the checkpoint.
        selected['checkpoint'] = os.path.basename(os.path.dirname(selected['checkpoint']))

    return json.dumps(selected, sort_keys=True)

def fuse_results(result_parent_dir, suffixes=('.informal_eval', '.formal_eval')):
    """Gather the evaluation results of every run folder under a parent directory.

    Each sub-directory of ``result_parent_dir`` must contain ``args.json`` or,
    failing that, ``hparams.json``. The arguments are reduced to a key with
    ``args_to_key``, and for each suffix the single file in the folder whose
    name ends with that suffix (if any) is loaded as JSON.

    Args:
        result_parent_dir: Directory whose immediate sub-directories are runs.
        suffixes: File-name suffixes identifying the result files to load.

    Returns:
        A dict mapping configuration key -> {suffix: loaded results}. A suffix
        with no matching file in a folder is simply absent from the inner dict.

    Raises:
        AssertionError: If a folder has neither arguments file, if a folder has
            more than one file for a suffix, or if two folders with the same
            key both provide results for the same suffix.
    """
    key_to_results = {}
    # glob yields entries in arbitrary order; sort so the order of the
    # returned keys (and anything derived from it) is reproducible.
    for folder_path in sorted(glob.glob(os.path.join(result_parent_dir, '*'))):
        if not os.path.isdir(folder_path):
            continue

        args_fname = os.path.join(folder_path, 'args.json')
        if not os.path.exists(args_fname):
            args_fname = os.path.join(folder_path, 'hparams.json')
            assert os.path.exists(args_fname), \
                f'no args.json or hparams.json in {folder_path}'

        with open(args_fname, 'r') as f:
            args = json.load(f)

        key = args_to_key(args)
        folder_results = key_to_results.setdefault(key, {})

        # List the folder once instead of once per suffix.
        filenames = os.listdir(folder_path)
        for suffix in suffixes:
            result_files = [os.path.join(folder_path, name)
                            for name in filenames if name.endswith(suffix)]
            assert len(result_files) <= 1, result_files
            if not result_files:
                continue
            with open(result_files[0], 'r') as f:
                results = json.load(f)

            # Two folders that reduce to the same key must not both supply
            # results for the same suffix.
            assert suffix not in folder_results, \
                f'duplicate {suffix} results for key {key} (folder {folder_path})'
            folder_results[suffix] = results

    return key_to_results





def _format_metric(task_results, task, metric):
    """Return ``task_results[task]['decoded'][metric]`` as a 2-decimal string, or None if absent."""
    try:
        value = task_results[task]['decoded'][metric]
    except KeyError:
        return None
    return '{:0.2f}'.format(round(value, 2))


def main():
    """Fuse the results of all experiment folders, print LaTeX rows and write results.csv.

    For every experiment and metric the output cell has the form
    ``mean (v1, v2, ...)`` with one value per task, in sorted task order.
    A value that is missing for an experiment is shown as ``n/a``, left out
    of the mean, and reported on stderr instead of aborting the whole run.
    """
    missing_marker = 'n/a'

    folder_list = ['chatgpt_fixed_prompt','baseline_copy','mix_match', 'paraguide','sft_v1_outputs','sft_v2_outputs']
    # Alternative: folder_list = ['sft_v2_outputs_interp_v2']

    combined_results = {}
    for parent_folder in folder_list:
        for key, value in fuse_results(parent_folder).items():
            # Prefix with the folder name so equal keys from different folders stay distinct.
            combined_results[f'{parent_folder}_{key}'] = value

    # Alternative: prepend 'accuracy' to also report it.
    metrics = ['holdout_accuracy','similarity',  'cola', 'holdout_joint_gm', 'perplexity_median']

    # Tasks are the union over all experiments, so a given experiment may lack some.
    tasks = sorted(set().union(*(results.keys() for results in combined_results.values())))
    print(tasks)

    combined_metrics = {}
    for experiment_config, task_results in combined_results.items():
        per_task = {}
        for task in tasks:
            per_task[task] = {}
            for metric in metrics:
                formatted = _format_metric(task_results, task, metric)
                if formatted is None:
                    print(f'warning: no {metric!r} for task {task!r} in {experiment_config}',
                          file=sys.stderr)
                per_task[task][metric] = formatted
        combined_metrics[experiment_config] = per_task
    print(combined_metrics)

    columns = ['method'] + metrics
    rows = []
    for experiment_config, per_task in combined_metrics.items():
        row = [experiment_config]
        for metric in metrics:
            metric_results = [per_task[task][metric] for task in tasks]
            # The mean is taken over the already-rounded values that are present.
            present = [float(x) for x in metric_results if x is not None]
            mean = round(np.mean(present), 2) if present else float('nan')
            cells = [missing_marker if x is None else x for x in metric_results]
            row.append(f'{mean:.2f} (' + ', '.join(cells) + ')')

        rows.append(row)
        # LaTeX table row.
        print(' & '.join(row) + ' \\\\')

    # Build the frame once rather than enlarging it one row at a time.
    pd.DataFrame(rows, columns=columns).to_csv('results.csv', index=False)


if __name__ == '__main__':
    main()