import json
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import Dict
import math

def get_max_average_diff(setup: Dict):
    """Summarise how far the other values lie from the first value in ``setup``.

    The first value of ``setup`` (in insertion order) is used as the reference;
    the absolute difference between it and every other value is computed.
    Any number of entries >= 2 is accepted, so the function no longer depends
    on the mapping holding exactly five entries.

    Args:
        setup: Mapping whose values are numbers. The keys are not inspected
            (presumably sentences, judging by the original variable names --
            confirm against the input file).

    Returns:
        A ``(max, min, mean)`` tuple of the absolute differences.

    Raises:
        ValueError: If ``setup`` holds fewer than two entries, because there
            is then nothing to compare the reference value against.
    """
    values = list(setup.values())
    if len(values) < 2:
        raise ValueError(
            'setup must contain at least two entries, got %d' % len(values)
        )

    reference, *others = values
    # math.fabs always yields a float, even when the stored values are ints.
    diffs = np.array([math.fabs(reference - other) for other in others])
    return diffs.max(), diffs.min(), diffs.mean()


def main(input_path: str, output_path: str):
    """Parse a JSON results file into a per-model CSV and a per-sentence CSV.

    The JSON file is read as a two-level mapping ``sent_id -> mod_name -> setup``
    where every ``setup`` is handed to ``get_max_average_diff``.

    Two files are written:

    * ``output_path``: one row per ``(sent_id, model)`` pair with the
      ``max_diff``, ``min_diff`` and ``average_diff`` of that pair.
    * a sibling file whose name ends in ``_global.csv``: one row per
      ``sent_id`` holding the mean of each metric over all of its models.

    Args:
        input_path: Path of the JSON file to read.
        output_path: Path of the per-model CSV file to write. The global CSV
            path is derived from it by swapping a trailing ``.csv`` for
            ``_global.csv`` (or appending ``_global.csv`` when there is no
            such suffix).
    """
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(input_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    parsed_results = defaultdict(list)
    for sent_id, models in results.items():
        for mod_name, setup in models.items():
            max_diff, min_diff, average_diff = get_max_average_diff(setup)
            parsed_results['sent_id'].append(sent_id)
            parsed_results['model'].append(mod_name)
            parsed_results['max_diff'].append(max_diff)
            parsed_results['min_diff'].append(min_diff)
            parsed_results['average_diff'].append(average_diff)

    df = pd.DataFrame.from_dict(parsed_results)
    df.to_csv(output_path)

    global_parsed_results = defaultdict(list)
    for sent_id in results:
        df_sent_id = df[df['sent_id'] == sent_id]
        global_parsed_results['sent_id'].append(sent_id)
        global_parsed_results['max_diff'].append(df_sent_id['max_diff'].mean())
        global_parsed_results['min_diff'].append(df_sent_id['min_diff'].mean())
        global_parsed_results['average_diff'].append(df_sent_id['average_diff'].mean())

    # Only touch the file suffix: str.replace would also rewrite a '.csv'
    # appearing in a directory name, and would return the path unchanged
    # (overwriting the per-model file) when it contains no '.csv' at all.
    suffix = '.csv'
    if output_path.endswith(suffix):
        global_output_path = output_path[:-len(suffix)] + '_global.csv'
    else:
        global_output_path = output_path + '_global.csv'

    global_df = pd.DataFrame.from_dict(global_parsed_results)
    global_df.to_csv(global_output_path)

# Runs the parsing pipeline on the multiberts pretest results. This call sits
# at module top level, so it executes whenever the file is run or imported.
main(
    input_path='/results/prolific/llm_pretest/bert/pretest_multiberts.json',
    output_path='/results/prolific/llm_pretest/bert/pretest_multiberts_parsed.csv',
)