import argparse
from generator.vanilla_trainer import train
from utils.util import load_config
from datasets import load_dataset
from tqdm import tqdm
from collections import Counter
 
def _tally(statistics, is_correct):
    """Increment ``'total'`` and either ``'correct'`` or ``'incorrect'`` in *statistics*."""
    statistics["correct" if is_correct else "incorrect"] += 1
    statistics["total"] += 1


def bioasq_metrics(args):
    """Score BioASQ generations stored in a JSON file and return accuracy counts.

    Loads ``args.solution_dir`` with ``load_dataset('json', ...)`` (split
    ``args.split``) and scores it according to ``args.style``:

    * ``'interleave'`` -- score rows ``0, k, 2k, ...`` (``k = args.k``);
      presumably each question's ``k`` samples are stored consecutively.
    * ``'repeat'`` -- score the first ``len(dataset) // args.k`` rows;
      presumably the file is ``k`` back-to-back copies of the question set.
    * ``'first'`` -- score only the first row seen for each distinct question.
    * ``'self-consistency'`` -- majority-vote the extracted answers of all rows
      that share a question and score the winning answer.

    Each row is read through the keys ``'question'``, ``'answer'`` and
    ``'generation'``.

    Args:
        args: namespace with ``solution_dir``, ``split``, ``style`` and, for
            ``'interleave'``/``'repeat'``, ``k``.

    Returns:
        dict with integer counts ``'correct'``, ``'incorrect'``, ``'incomplete'``
        (never incremented here; kept for callers that read the key) and
        ``'total'``.

    Raises:
        ValueError: if ``args.style`` is not one of the styles listed above.
    """
    styles = ('interleave', 'repeat', 'first', 'self-consistency')
    if args.style not in styles:
        # Previously an unknown style silently produced all-zero statistics.
        raise ValueError(f"unknown style {args.style!r}; expected one of {styles}")

    dataset = load_dataset('json', data_files=args.solution_dir, split=args.split)
    if len(dataset) > 0:
        # Debug preview of the first record (guarded so an empty file does not crash).
        print(dataset[0])

    statistics = {"correct": 0, "incorrect": 0, "incomplete": 0, "total": 0}

    if args.style in ('interleave', 'repeat'):
        if args.style == 'interleave':
            index_list = range(0, len(dataset), args.k)
        else:
            index_list = range(len(dataset) // args.k)
        for idx in tqdm(index_list):
            sample = dataset[idx]
            _, label = bioasq_judge(sample['answer'], sample['generation'])
            _tally(statistics, label == 1)

    elif args.style == 'first':
        seen_questions = set()  # set membership instead of an O(n) list scan per row
        for idx in tqdm(range(len(dataset))):
            sample = dataset[idx]
            if sample['question'] in seen_questions:
                continue
            seen_questions.add(sample['question'])
            _, label = bioasq_judge(sample['answer'], sample['generation'])
            _tally(statistics, label == 1)

    else:  # 'self-consistency'
        votes = {}  # question -> list of extracted predicted answers, in row order
        gold = {}   # question -> normalised gold answer (last row for a question wins)
        for idx in tqdm(range(len(dataset))):
            sample = dataset[idx]
            question = sample['question']
            prediction, _ = bioasq_judge(sample['answer'], sample['generation'])
            # bioasq_judge rewrites complete generations as
            # '<text>\n#### <answer>.'; drop the trailing '.' to recover <answer>.
            # Incomplete generations contribute '' to the vote.
            if '\n#### ' in prediction:
                pred_answer = prediction.split('\n#### ')[-1][:-1]
            else:
                pred_answer = ''
            votes.setdefault(question, []).append(pred_answer)
            gold[question] = sample['answer'].split('\n#### ')[-1].split('.')[0].strip()

        for question, predictions in votes.items():
            answer = gold[question]
            # Ties resolve to the answer encountered first.
            winner = Counter(predictions).most_common(1)[0][0]
            # `'' in s` is always True, so an empty gold answer must not count as a match.
            _tally(statistics, bool(answer) and answer.lower() in winner.lower())

    return statistics

def bioasq_judge(answer, prediction):
    """Extract the final answer from a generation and compare it with the gold answer.

    The predicted answer is the text following the first ``'\\n#### '`` marker in
    *prediction* (falling back to a bare ``'####'`` marker), up to the next
    marker. It is normalised by stripping leading whitespace (an empty result
    becomes ``'EMPTY'``), truncating at the first newline, period and comma, and
    deleting the characters ``()[]{},.``. The gold answer is the text after the
    last ``'\\n#### '`` in *answer*, truncated at its first period and stripped
    of surrounding whitespace.

    Args:
        answer: gold solution text, expected to contain ``'\\n#### <answer>'``.
        prediction: model generation.

    Returns:
        ``(prediction, label)``. When no ``'####'`` marker is present the
        generation is treated as incomplete and returned unchanged with label 0.
        Otherwise the generation is rewritten as
        ``'<text before marker>\\n#### <normalised answer>.'`` and label is 1 if
        the gold answer occurs case-insensitively as a substring of the
        normalised predicted answer, else 0. An empty gold answer never matches.
    """
    answer = answer.split('\n#### ')[-1].split('.')[0].strip()

    marker = '\n#### ' if '\n#### ' in prediction else '####'
    parts = prediction.split(marker)
    if len(parts) == 1:
        # No answer marker at all: incomplete generation.
        return prediction, 0

    # lstrip() (all whitespace, incl. '\r' and '\t') so e.g. '#### \r\nyes' parses as 'yes'.
    pred_answer = parts[1].lstrip() or 'EMPTY'
    # Keep only the first line / sentence / comma-separated item.
    pred_answer = pred_answer.split('\n')[0].split('.')[0].split(',')[0]
    pred_answer = pred_answer.translate(str.maketrans('', '', '\n()[]{},.'))

    prediction = parts[0] + '\n#### ' + pred_answer + '.'
    # `'' in s` is always True, so an empty gold answer must not count as a match.
    if answer and answer.lower() in pred_answer.lower():
        return prediction, 1
    return prediction, 0