from transformers import PegasusTokenizer
from transformers import PegasusForConditionalGeneration
import torch
import numpy as np
from scipy.stats import pearsonr, spearmanr, kendalltau

task = "senti"

# <task>_ar.txt: one example per line, tab-separated. The last two columns are
# an integer label index and a float score (the score that the computed
# per-example scores are later correlated against). Every earlier column is
# part of the example text and is rejoined with single spaces.
data, label, rel_score = [], [], []
ar_path = task + '_ar.txt'
with open(ar_path, 'r', encoding='utf-8') as f_in:
    for line_no, line in enumerate(f_in, start=1):
        fields = line.strip().split('\t')  # parse each line once
        if fields == ['']:
            continue  # blank line, e.g. a trailing newline at EOF
        if len(fields) < 2:
            raise ValueError(
                f'{ar_path}:{line_no}: expected label and score columns, got {fields!r}'
            )
        data.append(' '.join(fields[:-2]))
        label.append(int(fields[-2]))
        rel_score.append(float(fields[-1]))

# One prompt template per non-blank line. The scoring step substitutes the
# '<gen_result>' and '<mask_token>' placeholders in each template.
with open('prompt_list_' + task + '.txt', 'r', encoding='utf-8') as f_in_pr:
    prompt_str = [p for p in (line.strip() for line in f_in_pr) if p]

print('number of total data: ', len(data))
print('number of prompts: ', len(prompt_str))

# Load the tokenizer and seq2seq model from the local './pegasus' checkpoint
# directory, placing the model on the GPU when CUDA is available.
model_name = './pegasus'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Verbalizers: each inner list maps a label index to the word the model is
# scored on for that label. "senti" uses three alternative word sets; every
# other task uses a single four-way set.
if task != "senti":
    verbal = [['computers', 'politics', 'religion', 'science']]
else:
    verbal = [['bad', 'good'], ['negative', 'positive'], ['terrible', 'great']]
num_label = len(verbal[0])  # 2 for "senti", 4 otherwise

# Score every example. For each (prompt, verbalizer set) pair, measure how
# likely the model finds each label word as its output. Normalize those
# likelihoods into a label distribution. Then mix all pair distributions,
# weighting each pair by its summed likelihood. The mixed probability of the
# gold label is the example's score, appended to result_score.
model.eval()  # inference only; set once rather than before every forward pass

# Label words do not depend on the example, so tokenize them once up front.
# verbal_tgt_ids[v][k] holds the target input_ids for label k of verbalizer set v.
verbal_tgt_ids = [
    [
        tokenizer([tgt_text_ele[verbal_id]], truncation=True, padding='do_not_pad',
                  add_special_tokens=False, return_tensors="pt").to(device)['input_ids']
        for verbal_id in range(num_label)
    ]
    for tgt_text_ele in verbal
]

result_score = []
with torch.no_grad():
    for data_id, data_ele in enumerate(data):
        match_score = []  # per (prompt, verbalizer) pair: summed label likelihood
        prob_score = []   # per (prompt, verbalizer) pair: normalized label distribution
        for src_text_ele in prompt_str:
            src_text = [src_text_ele.replace('<gen_result>', data_ele).replace('<mask_token>', '<mask_1>')]
            # NOTE(review): truncation=True may cut the tail of a long prompt,
            # including the mask token; confirm inputs fit the model's max length.
            src_input_ids = tokenizer(src_text, truncation=True, padding='do_not_pad',
                                      return_tensors="pt").to(device)['input_ids']
            for tgt_ids_per_label in verbal_tgt_ids:
                # Output [0] is the loss when labels are passed. If it is the HF
                # default mean token cross-entropy, exp(-loss) is a length-normalized
                # likelihood of the label word; confirm for multi-token words.
                prob = [
                    torch.exp(-model(input_ids=src_input_ids, labels=tgt_ids)[0]).item()
                    for tgt_ids in tgt_ids_per_label
                ]
                prob_sum = sum(prob)
                match_score.append(prob_sum)
                prob_score.append([p / prob_sum for p in prob])

        # Mix the per-pair distributions, each weighted by its share of the
        # total likelihood over all pairs for this example.
        match_sum = sum(match_score)
        total_score = np.zeros(num_label)
        for weight, prob_dis in zip(match_score, prob_score):
            total_score += (weight / match_sum) * np.array(prob_dis)

        # label[data_id] is used as an index into the num_label-long distribution.
        result_score.append(total_score[label[data_id]])

        if data_id % 10 == 0:
            print('processing data: ', data_id)

# Report how well the computed per-example scores track the reference scores
# read from the data file, under three correlation measures.
for header, corr_fn in (('AR pearson: ', pearsonr),
                        ('AR spearman: ', spearmanr),
                        ('AR kendall: ', kendalltau)):
    print(header, corr_fn(result_score, rel_score))
