'''
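Evaluation metrics for text-to-text tasks: classification accuracy, CoLA (Matthews correlation), STS-B (Pearson/Spearman correlation) and summarization (ROUGE).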
'''
from sklearn.metrics import matthews_corrcoef, f1_score
from scipy.stats import pearsonr, spearmanr
from rouge import Rouge
import re

# entry point: dispatch to the metric that belongs to a task
def evaluate_t2t_task(prediction, output_sequences, task, tasks=None):
    """Evaluate predictions against targets with the metric chosen by ``task``.

    Args:
        prediction: sequence of model predictions.
        output_sequences: sequence of reference targets, aligned with
            ``prediction``.
        task: task name. ``'cola'``, ``'sts-b'``, ``'multitask'`` and
            ``'cnndm'`` have dedicated metrics; any other value is scored
            with exact-match accuracy.
        tasks: per-example task names; only forwarded when ``task`` is
            ``"multitask"``.

    Returns:
        Whatever the selected metric function returns (a dict of ROUGE
        scores for ``'cnndm'``, the metric's own return value otherwise).
    """
    if task == "multitask":
        return evaluate_t2t_multitask(prediction, output_sequences, tasks)
    if task == 'cola':
        return evaluate_t2t_cola(prediction, output_sequences)
    if task == 'sts-b':
        return evaluate_t2t_stsb(prediction, output_sequences)
    if task == 'cnndm':
        return evaluate_t2t_rouge(prediction, output_sequences)
    # Every remaining task is scored by exact-match accuracy.
    return evaluate_t2t_accuracy(prediction, output_sequences)

# entry function for multitask learning
def evaluate_t2t_multitask(prediction, output_sequences, tasks):
    """Return the example-weighted average of the per-task metrics.

    Examples are grouped by their task name, each group is scored through
    ``evaluate_t2t_task`` and the scores are averaged, weighted by the
    number of examples in each group.

    NOTE: the per-task score is multiplied by a count, so every task in
    ``tasks`` must yield a scalar metric. The ``'cnndm'`` path returns a
    dict of ROUGE scores and is therefore not supported here.

    Args:
        prediction: sequence of model predictions.
        output_sequences: sequence of reference targets, aligned with
            ``prediction``.
        tasks: sequence holding the task name of every example, aligned
            with ``prediction``.

    Returns:
        The weighted average score, or ``0.0`` when ``prediction`` is empty.
    """
    total_count = len(prediction)
    if total_count == 0:
        # Nothing to score; avoid dividing by a zero example count below.
        return 0.0

    # task name -> (predictions, targets); dict order keeps first-seen order.
    grouped = {}
    for i in range(total_count):
        task_predictions, task_targets = grouped.setdefault(tasks[i], ([], []))
        task_predictions.append(prediction[i])
        task_targets.append(output_sequences[i])

    weighted_sum = 0.0
    for task, (task_predictions, task_targets) in grouped.items():
        score = evaluate_t2t_task(task_predictions, task_targets, task)
        weighted_sum += score * len(task_predictions)

    # The group sizes add up to total_count, so this is the weighted mean.
    return weighted_sum / total_count

# eval metric for cola: matthews_corrcoef
def evaluate_t2t_cola(prediction, output_sequences):
    """Return the Matthews correlation coefficient for CoLA.

    Args:
        prediction: predicted labels.
        output_sequences: reference labels, passed as the ground truth.

    Returns:
        The value computed by ``sklearn.metrics.matthews_corrcoef``.
    """
    return matthews_corrcoef(y_true=output_sequences, y_pred=prediction)

# eval metric for stsb: pearsonr & spearmanr
def evaluate_t2t_stsb(prediction, output_sequences):
    """Return the mean of the Pearson and Spearman correlations for STS-B.

    Both inputs are sequences of strings. The first unsigned decimal number
    found in each string is used as its score.

    Args:
        prediction: predicted score strings. A prediction that contains no
            number is scored as ``0.0``.
        output_sequences: reference score strings. Each one must contain a
            number.

    Returns:
        ``(pearson + spearman) / 2``. Both correlations are also printed.

    Raises:
        IndexError: if a reference string contains no number.
    """
    # Raw string: "\d" and "\." are invalid escapes in a normal literal.
    number_pattern = re.compile(r"\d+\.?\d*")

    predicted_scores = []
    for text in prediction:
        found = number_pattern.search(text)
        # Unparseable model output falls back to 0.0 instead of failing.
        predicted_scores.append(float(found.group()) if found else 0.0)

    # References are expected to be well-formed; a missing number raises.
    target_scores = [float(number_pattern.findall(text)[0]) for text in output_sequences]

    pearsonr_corr = pearsonr(target_scores, predicted_scores)[0]
    print("pearson r: %.4f" % pearsonr_corr)
    spearmanr_corr = spearmanr(target_scores, predicted_scores)[0]
    print("spearman r: %.4f" % spearmanr_corr)
    return (pearsonr_corr + spearmanr_corr) / 2

# eval metric for other tasks: accuracy
# only exact match is correct
def evaluate_t2t_accuracy(prediction, output_sequences):
    """Return the exact-match accuracy of predictions against targets.

    Args:
        prediction: sequence of model predictions.
        output_sequences: sequence of reference targets, aligned with
            ``prediction``.

    Returns:
        The fraction of positions where the target equals the prediction,
        or ``0.0`` when ``prediction`` is empty.

    Raises:
        IndexError: if ``output_sequences`` is shorter than ``prediction``.
    """
    total = len(prediction)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    # Index the targets explicitly so a too-short target list still raises
    # instead of being silently truncated.
    matches = sum(
        1 for i, predicted in enumerate(prediction) if output_sequences[i] == predicted
    )
    return matches / total

def evaluate_t2t_rouge(prediction, output_sequences):
    """Return the average ROUGE scores of predictions against references.

    Each prediction/reference pair is scored on its own with the ``rouge``
    package, and the per-pair scores are averaged over all pairs. The
    averages are printed as well as returned.

    Args:
        prediction: sequence of generated texts.
        output_sequences: sequence of reference texts, aligned with
            ``prediction``.

    Returns:
        A dict with keys ``'Rouge_1'``, ``'Rouge_2'``, ``'Rouge_l_f1'``,
        ``'Rouge_l_p'`` and ``'Rouge_l_r'``. All values are ``0.0`` when
        ``prediction`` is empty.
    """
    total = len(prediction)
    if total == 0:
        # Nothing to score; report zeros instead of dividing by zero.
        return {'Rouge_1': 0.0, 'Rouge_2': 0.0, 'Rouge_l_f1': 0.0, 'Rouge_l_p': 0.0, 'Rouge_l_r': 0.0}

    rouge = Rouge()
    rouge_1, rouge_2, rouge_l_f1, rouge_l_p, rouge_l_r = 0, 0, 0, 0, 0
    for i in range(total):
        output_text = prediction[i]
        label_text = output_sequences[i]
        try:
            # avg=True averages over the batch, which here is a single pair.
            result = rouge.get_scores([output_text], [label_text], avg=True)
        except Exception:
            # Scoring failed, presumably because the hypothesis is empty or
            # unusable (TODO confirm). Retry with a placeholder prefix so
            # the pair still gets a score. A bare ``except`` is avoided so
            # KeyboardInterrupt/SystemExit are not swallowed.
            result = rouge.get_scores(['summary' + output_text], [label_text], avg=True)
        rouge_1 += result['rouge-1']['f']
        rouge_2 += result['rouge-2']['f']
        rouge_l_f1 += result['rouge-l']['f']
        rouge_l_p += result['rouge-l']['p']
        rouge_l_r += result['rouge-l']['r']

    averages = {
        'Rouge_1': rouge_1 / total,
        'Rouge_2': rouge_2 / total,
        'Rouge_l_f1': rouge_l_f1 / total,
        'Rouge_l_p': rouge_l_p / total,
        'Rouge_l_r': rouge_l_r / total,
    }
    print('AVG Rouge_1: {}, Rouge_2: {}, Rouge_l_f1: {}, Rouge_l_p: {}, Rouge_l_r: {}'.format(averages['Rouge_1'], averages['Rouge_2'], averages['Rouge_l_f1'], averages['Rouge_l_p'], averages['Rouge_l_r']))
    return averages

if __name__ == "__main__":
    import json
    import os

    # Directory holding the per-index result files written by a previous run.
    result_dir = "../output/t5_base/round2/p20/seed42/Target_100shot/cnndm/prompt_tuning_bikt_ablation"
    # Number of result files to merge (presumably one per distributed rank).
    world_size = 2

    predictions, targets = [], []
    for rank in range(world_size):
        result_path = os.path.join(result_dir, "result_rank{}.json".format(rank))
        with open(result_path, encoding='utf-8') as f:
            records = json.load(f)
        # Every record carries the generated text and its reference.
        predictions.extend(record['prediction'] for record in records)
        targets.extend(record['target'] for record in records)

    # Scores are printed by evaluate_t2t_rouge itself.
    evaluate_t2t_rouge(predictions, targets)