from transformers import PegasusTokenizer
from transformers import PegasusForConditionalGeneration
import torch
from nltk.tokenize import sent_tokenize
import numpy as np
from scipy.stats import pearsonr, spearmanr, kendalltau

# Name of the evaluation set; it selects the '<task>_cons.txt' input file.
task = "senti"

# Every line of the input file carries three tab-separated fields, in order:
# the full text, the prefix of that text, and the human score.
data, prefix_list, human_score = [], [], []
with open(task + '_cons.txt', 'r', encoding='utf-8') as f_in:
    for line_no, line in enumerate(f_in, start=1):
        line_split = line.strip().split('\t')
        # Validate with a real exception instead of `assert`: asserts are
        # removed under `python -O`, and a row with extra fields would then be
        # mis-parsed silently.
        if len(line_split) != 3:
            raise ValueError(
                'expected 3 tab-separated fields on line %d of %s, got %d'
                % (line_no, task + '_cons.txt', len(line_split))
            )
        text_field, prefix_field, score_field = line_split
        data.append(text_field)
        prefix_list.append(prefix_field)
        human_score.append(float(score_field))

# One float per line; the scoring code later in this file indexes this list by
# token id, so line order matters and blank lines are deliberately not skipped.
with open('iwf_full.txt', 'r', encoding='utf-8') as f_in_iwf:
    iwf_score = [float(line.strip()) for line in f_in_iwf]

print('number of total data: ', len(data))

# Path handed to `from_pretrained` for both the tokenizer and the model.
model_name = './pegasus'

# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)


def get_batched_data(sent_str, final_prefix):
    """Build the two infilling examples for one sentence.

    The sentence is split into ``final_prefix`` and the remainder that follows
    it. Each half is then paired with a source text in which that half is
    replaced by the ``<mask_1>`` placeholder (presumably the PEGASUS sentence
    mask token -- confirm against the tokenizer in use).

    Args:
        sent_str: The full sentence. It must start with ``final_prefix``.
        final_prefix: The leading part of ``sent_str``.

    Returns:
        A tuple ``(src_text, tgt_text, final_prefix)``:
        ``src_text`` is ``[prefix + ' <mask_1>', '<mask_1> ' + remainder]`` and
        ``tgt_text`` is ``[remainder, prefix]``, so ``tgt_text[i]`` is the text
        hidden by the mask in ``src_text[i]``. ``final_prefix`` is returned
        unchanged.

    Raises:
        ValueError: If ``sent_str`` does not start with ``final_prefix``.
    """
    # An explicit check instead of `assert`: asserts are removed under
    # `python -O`, and slicing by len(final_prefix) would then silently cut the
    # sentence at an arbitrary position.
    if not sent_str.startswith(final_prefix):
        raise ValueError(
            'sentence does not start with the given prefix: %r' % (final_prefix,)
        )

    remainder = sent_str[len(final_prefix):]
    src_text = [final_prefix + ' <mask_1>', '<mask_1> ' + remainder]
    tgt_text = [remainder, final_prefix]
    return src_text, tgt_text, final_prefix


# Score every sample: mask each half of the text in turn, take the model's
# negated loss for reconstructing the masked half, and combine the two
# directions with weights derived from iwf_score.

# The mode does not change between samples, so switch to eval mode once
# instead of once per forward pass.
model.eval()

total_score = []
with torch.no_grad():
    for data_id, (data_ele, prefix_ele) in enumerate(zip(data, prefix_list)):
        src_text_list, tgt_text_list, _ = get_batched_data(data_ele, prefix_ele)

        score = []      # negated model loss, one entry per masking direction
        tgt_score = []  # weight per direction: largest iwf_score over the target's token ids

        for src_text_ele, tgt_text_ele in zip(src_text_list, tgt_text_list):
            src_ids = tokenizer([src_text_ele], truncation=True, padding='do_not_pad',
                                return_tensors="pt").to(device)

            # Tokenize the target once; the same ids supply both the weight
            # and the labels fed to the model.
            tgt_token_ids = tokenizer([tgt_text_ele], truncation=True, padding='do_not_pad',
                                      add_special_tokens=False)['input_ids'][0]
            if not tgt_token_ids:
                # max() over no tokens would raise an opaque ValueError; keep
                # the exception type but say which sample is at fault.
                raise ValueError(
                    'target text %r of sample %d produced no tokens'
                    % (tgt_text_ele, data_id)
                )
            tgt_score.append(max(iwf_score[token_id] for token_id in tgt_token_ids))

            labels = torch.tensor([tgt_token_ids], dtype=torch.long, device=device)
            loss = model(input_ids=src_ids['input_ids'], labels=labels)[0]
            # .item() yields a plain Python float, so the averaging below runs
            # on ordinary floats rather than 0-d float32 arrays.
            score.append(-loss.item())

        weight_sum = sum(tgt_score)
        if weight_sum > 0:
            # Weighted average of the two directions.
            tmp_score = np.dot(score, tgt_score) / weight_sum
        else:
            # No positive total weight: fall back to the unweighted mean.
            tmp_score = np.mean(score)

        total_score.append(tmp_score)

        if data_id % 10 == 0:
            print('processing data: ', data_id)

print('consistency pearson: ', pearsonr(total_score, human_score))
print('consistency spearman: ', spearmanr(total_score, human_score))
print('consistency kendall: ', kendalltau(total_score, human_score))
