import json
import argparse
import os.path

# Benchmarks to summarize, in the order they appear in the "m/h/a/t/w/g" line.
# Each entry is (file suffix, key inside "results", metric name). A key of
# None means "average the metric over every entry in results" (used for MMLU,
# whose results file holds one entry per subtask).
_TASKS = (
    ("mmlu", None, "acc"),
    ("hellaswag", "hellaswag", "acc_norm"),
    ("arc", "arc_challenge", "acc_norm"),
    ("truthfulqa", "truthfulqa_mc", "mc2"),
    ("winogrande", "winogrande", "acc"),
    ("gsm8k", "gsm8k", "acc"),
)


def _load_results(path):
    """Return the ``results`` mapping stored in the JSON file at *path*.

    Returns None when *path* is not an existing file. Raises KeyError if the
    file has no top-level ``results`` key.
    """
    if not os.path.isfile(path):
        return None
    with open(path, 'r') as f:
        return json.load(f)['results']


def _average_metric(results, metric):
    """Return the mean of *metric* over all entries of *results* as a percentage.

    Prints the raw average and the number of entries. Returns None (after
    printing a note) when *results* is empty, instead of dividing by zero.
    """
    count = len(results)
    if count == 0:
        print("mmlu: results contain no subtasks, skipping")
        return None

    # Accumulate left to right rather than calling the builtin sum(): newer
    # Python versions sum floats with compensation, which could change the
    # trailing digits of the printed average.
    total = 0
    for entry in results.values():
        total += entry[metric]

    print("average score: ", total / count)
    print("count: ", count)
    return round((total / count) * 100, 2)


def collect_scores(prefix):
    """Return one score per entry of ``_TASKS`` for the files named after *prefix*.

    For each task the file ``{prefix}-{suffix}.json`` is read. The returned
    list is in ``_TASKS`` order; an element is the score as a percentage
    rounded to two decimals, or None when the file does not exist (or, for
    MMLU, when it contains no subtasks).
    """
    scores = []
    for suffix, task_key, metric in _TASKS:
        results = _load_results(f'{prefix}-{suffix}.json')
        if results is None:
            scores.append(None)
        elif task_key is None:
            scores.append(_average_metric(results, metric))
        else:
            scores.append(round(results[task_key][metric] * 100, 2))
    return scores


def main(argv=None):
    """Parse ``--file``, then print every task score and their average.

    *argv* is the argument list handed to argparse; None means the process
    command line. Missing tasks are shown as 0 and are excluded from the
    divisor of the average. Prints "error during few-shot" when no task
    produced a score.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str)
    args = parser.parse_args(argv)

    scores = collect_scores(args.file)
    task_exist = len([s for s in scores if s is not None])

    print("")
    if not task_exist:
        print("error during few-shot")
        return

    # Missing tasks are displayed as the integer 0.
    shown = [0 if s is None else s for s in scores]
    print('m/h/a/t/w/g : ' + ' / '.join(str(s) for s in shown))

    # Same left-to-right addition order as the printed line, so the digits of
    # the average do not depend on the summation algorithm.
    total = 0
    for s in shown:
        total += s
    all_avg = total / task_exist
    print(f'avg: {all_avg}')


if __name__ == "__main__":
    main()