import os
import glob
import sys
import json
import numpy as np

import pandas as pd

import matplotlib.pyplot as plt
        

# Result-file suffix -> task name under which that file's results are stored.
_SUFFIX_TO_TASK = {
    '.informal_eval': 'informal',
    '.formal_eval': 'formal',
}

# Raw metric key in a result file -> human-readable label used in the plots.
_METRIC_TO_LABEL = {
    'holdout_accuracy': 'Style Accuracy',
    'similarity': 'Meaning Preservation',
}

# Single lookup table combining both renamings (suffixes first, then metrics).
key_rename = {**_SUFFIX_TO_TASK, **_METRIC_TO_LABEL}

def args_to_key(dict):
    """Return the key used to group a run: its ``interp_percent`` setting.

    Args:
        dict: Parsed run configuration (a mapping) that contains an
            ``'interp_percent'`` entry. The parameter name shadows the
            builtin; it is kept so existing keyword callers keep working.

    Returns:
        The value stored under ``'interp_percent'``.

    Raises:
        KeyError: If the configuration has no ``'interp_percent'`` entry.
    """
    run_config = dict
    return run_config['interp_percent']

def fuse_results(result_parent_dir, suffixes=('.informal_eval', '.formal_eval')):
    """Collect per-run evaluation results, grouped by run key.

    Every immediate sub-directory of ``result_parent_dir`` is treated as one
    run. Its configuration is read from ``args.json`` (falling back to
    ``hparams.json``) and reduced to a grouping key with ``args_to_key``.
    For each entry of ``suffixes``, the file in the run directory whose name
    ends with that suffix (if there is one) is parsed as JSON and stored
    under the task name given by ``key_rename``; a suffix that has no entry
    in ``key_rename`` is used as the task name unchanged.

    Args:
        result_parent_dir: Directory whose sub-directories are the runs.
        suffixes: Iterable of file-name suffixes identifying result files.

    Returns:
        dict: ``{key: {task_name: parsed_json}}``. Runs that share a key
        are merged into the same inner dict.

    Raises:
        AssertionError: If a run directory has neither ``args.json`` nor
            ``hparams.json``, if more than one file matches a suffix, or if
            two runs with the same key both provide results for the same
            task. These are raised explicitly (not via ``assert``) so the
            checks still run under ``python -O``.
    """
    key_to_results = {}

    # Sorted so the grouping order does not depend on the order in which the
    # filesystem happens to enumerate the run directories.
    for folder_path in sorted(glob.glob(os.path.join(result_parent_dir, '*'))):
        if not os.path.isdir(folder_path):
            continue

        args_fname = os.path.join(folder_path, 'args.json')
        if not os.path.exists(args_fname):
            args_fname = os.path.join(folder_path, 'hparams.json')
            if not os.path.exists(args_fname):
                raise AssertionError(
                    f'no args.json or hparams.json found in {folder_path}'
                )

        with open(args_fname, 'r') as f:
            args = json.load(f)

        key = args_to_key(args)
        run_results = key_to_results.setdefault(key, {})

        for suffix in suffixes:
            result_files = [
                os.path.join(folder_path, name)
                for name in os.listdir(folder_path)
                if name.endswith(suffix)
            ]
            if len(result_files) > 1:
                # Ambiguous: refuse to pick one of several candidates.
                raise AssertionError(result_files)
            if not result_files:
                continue

            with open(result_files[0], 'r') as f:
                results = json.load(f)

            task = key_rename.get(suffix, suffix)
            if task in run_results:
                # Two runs map to the same key; storing would silently
                # discard the earlier run's results.
                raise AssertionError(
                    f'duplicate {task!r} results for key {key!r} '
                    f'(second occurrence in {folder_path})'
                )
            run_results[task] = results

    return key_to_results

def average_over_tasks(results):
    """Add a ``'combined'`` entry to every method: per-metric mean over tasks.

    ``results`` has the shape ``{method: {task: {metric: value}}}`` and is
    modified in place. For each method, every metric is averaged over the
    tasks that report it, and the averages are stored as
    ``results[method]['combined']`` (replacing any previous value).

    An existing ``'combined'`` entry is not treated as a task, so calling
    this again after adding tasks, or on data that already carries an
    aggregate, does not fold the old aggregate into the new mean.

    Args:
        results: Nested mapping as described above.

    Returns:
        None. The aggregate is written into ``results``.
    """
    for per_task in results.values():
        values_by_metric = {}
        for task, task_metrics in per_task.items():
            if task == 'combined':
                # Previous aggregate, not a real task.
                continue
            for metric, value in task_metrics.items():
                values_by_metric.setdefault(metric, []).append(value)

        per_task['combined'] = {
            metric: np.mean(values)
            for metric, values in values_by_metric.items()
        }
       


def main(parent_folder='sft_v2_outputs_interp_v2',
         output_path='combined_metrics_vs_interp_percent.png'):
    """Plot task-averaged metrics against interpolation percent.

    Loads the runs under ``parent_folder`` with ``fuse_results``, keeps the
    ``'decoded'`` metrics named in ``key_rename``, averages them over tasks,
    and draws one subplot per metric with the hard-coded baselines shown as
    horizontal dashed lines. The figure is written to ``output_path``.

    Args:
        parent_folder: Directory containing one sub-directory per run.
        output_path: File the figure is saved to.

    Raises:
        SystemExit: If no runs are found under ``parent_folder``.
    """
    exp_to_results = fuse_results(parent_folder)
    if not exp_to_results:
        raise SystemExit(f'No experiment results found under {parent_folder!r}')

    # Reference scores of other systems, per task.
    baselines = {
        'GPT-3.5': {
            'formal': {'Style Accuracy': 0.97, 'Meaning Preservation': 0.86},
            'informal': {'Style Accuracy': 0.82, 'Meaning Preservation': 0.87},
        },
        # Disabled baseline, kept for reference:
        # 'GPT-4': {
        #     'formal': {'Style Accuracy': 0.99, 'Meaning Preservation': 0.87},
        #     'informal': {'Style Accuracy': 0.91, 'Meaning Preservation': 0.91},
        # },
        'PGuide (200)': {
            'formal': {'Style Accuracy': 0.91, 'Meaning Preservation': 0.61},
            'informal': {'Style Accuracy': 0.96, 'Meaning Preservation': 0.69},
        },
        'M&M (HAM)': {
            'formal': {'Style Accuracy': 0.08, 'Meaning Preservation': 0.56},
            'informal': {'Style Accuracy': 0.90, 'Meaning Preservation': 0.57},
        },
    }
    average_over_tasks(baselines)

    metrics = ['Style Accuracy', 'Meaning Preservation']
    # The task list is taken from the first experiment; every other
    # experiment is expected to provide the same tasks.
    tasks = list(next(iter(exp_to_results.values())))

    # Replace each task's raw result blob with just the renamed metrics of
    # interest from its 'decoded' section.
    for exp in exp_to_results:
        for task in tasks:
            decoded = exp_to_results[exp][task]['decoded']
            exp_to_results[exp][task] = {
                key_rename[k]: v
                for k, v in decoded.items()
                if k in key_rename and key_rename[k] in metrics
            }

    average_over_tasks(exp_to_results)
    tasks = ['combined'] + tasks

    # One curve per (task, metric), ordered by increasing interpolation value.
    percent_step = sorted(exp_to_results)
    task_to_metrics = {
        task: {
            metric: [exp_to_results[percent][task][metric] for percent in percent_step]
            for metric in metrics
        }
        for task in tasks
    }

    plt.rcParams['font.family'] = 'sans-serif'
    fig, axs = plt.subplots(2, 1, figsize=(5, 8))
    # Padded tight layout, applied before anything is drawn on the axes.
    plt.tight_layout(pad=2.0)

    # Keys are fractions; the x axis is shown in percent.
    x_values = np.array(percent_step) * 100

    for ax, metric in zip(axs, metrics):
        # Thicker axes border.
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)

        ax.plot(x_values, task_to_metrics['combined'][metric],
                marker='x', linestyle='-', color='blue')
        ax.set_title(metric)
        ax.title.set_fontsize(12)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=8)

        for model in baselines:
            baseline_value = baselines[model]['combined'][metric]
            ax.axhline(y=baseline_value, color='black', linestyle='--',
                       label=f'{model} {metric}')
            # Name the baseline just above its line, at the left edge.
            ax.text(0.001, baseline_value, model, color='black',
                    fontsize=8, verticalalignment='bottom')

    # Only the bottom subplot carries the x-axis label.
    axs[0].set_xlabel('')
    axs[1].set_xlabel('Interpolation towards Target Style (%)')

    plt.savefig(output_path)
    plt.close(fig)


if __name__ == '__main__':
    main()
