'''
Evaluation metrics for text2text generation tasks.
'''
from sklearn.metrics import matthews_corrcoef, f1_score
from scipy.stats import pearsonr, spearmanr

# entry function: dispatch to the metric of a single task
def evaluate_task(prediction, output_sequences, task, tasks=None):
    '''
    Score predictions with the metric that belongs to `task`.

    'cola' is routed to evaluate_cola, 'sts-b' to evaluate_stsb and
    'multitask' to evaluate_multitask, which is the only branch that
    uses `tasks`. Every other task name is scored with evaluate_accuracy.
    '''
    if task == 'cola':
        return evaluate_cola(prediction, output_sequences)
    if task == 'sts-b':
        return evaluate_stsb(prediction, output_sequences)
    if task == "multitask":
        return evaluate_multitask(prediction, output_sequences, tasks)
    # fallback for all remaining tasks
    return evaluate_accuracy(prediction, output_sequences)

# entry function for multitask learning
def evaluate_multitask(prediction, output_sequences, tasks):
    '''
    Score a mixed-task batch and return the sample-weighted mean of the
    per-task metrics.

    prediction, output_sequences and tasks are indexed in parallel:
    tasks[i] names the task that sample i belongs to. Samples are grouped
    by task, each group is scored with evaluate_task, and the group scores
    are averaged using the group sizes as weights.

    Returns 0.0 when there are no samples, instead of dividing by zero.
    '''
    # bucket the samples by task name; dict order keeps first-seen task order
    grouped = {}
    for i, predicted in enumerate(prediction):
        preds, refs = grouped.setdefault(tasks[i], ([], []))
        preds.append(predicted)
        refs.append(output_sequences[i])

    # nothing to score: the weighted average below would divide by zero
    if not grouped:
        return 0.0

    weighted_sum = 0.0
    count_total = 0
    for task, (preds, refs) in grouped.items():
        count = len(preds)
        weighted_sum += evaluate_task(preds, refs, task) * count
        count_total += count
    # return average result
    return weighted_sum / count_total

# eval metric for cola: matthews_corrcoef
def evaluate_cola(prediction, output_sequences):
    '''
    Return the Matthews correlation coefficient of the predictions,
    with output_sequences passed as the true labels.
    '''
    return matthews_corrcoef(y_true=output_sequences, y_pred=prediction)

# eval metric for stsb: pearsonr & spearmanr
def evaluate_stsb(prediction, output_sequences):
    '''
    Return the mean of the Pearson and Spearman correlations between the
    predictions and output_sequences.

    Every element of both inputs is converted with float() first, so an
    element that float() cannot parse raises. Both correlations are also
    printed with four decimals.
    '''
    predicted_scores = list(map(float, prediction))
    target_scores = list(map(float, output_sequences))

    pearson = pearsonr(target_scores, predicted_scores)[0]
    print("pearson r: %.4f" % pearson)

    spearman = spearmanr(target_scores, predicted_scores)[0]
    print("spearman r: %.4f" % spearman)

    return (pearson + spearman) / 2

# eval metric for other tasks: accuracy
def evaluate_accuracy(prediction, output_sequences):
    '''
    Return the fraction of positions i at which output_sequences[i]
    equals prediction[i].

    The denominator is len(prediction). Returns 0.0 for an empty
    prediction sequence instead of raising ZeroDivisionError.
    '''
    total = len(prediction)
    if total == 0:
        return 0.0
    # index into output_sequences (rather than zip) so that a too-short
    # reference sequence still raises IndexError instead of being truncated
    correct = sum(
        1 for i, predicted in enumerate(prediction)
        if output_sequences[i] == predicted
    )
    return correct / total
