from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
from nltk.tokenize import sent_tokenize
import numpy as np
from scipy.stats import pearsonr, spearmanr, kendalltau

task = "senti"

# Each non-blank line of '<task>_coh.txt' holds one example:
#   "<document text>\t<human coherence score>"
# The document is split into sentences; the score is parsed as a float.
data, human_score = [], []
with open(task + '_coh.txt', 'r', encoding='utf-8') as f_in:
    for line_no, line in enumerate(f_in, start=1):
        stripped = line.strip()
        if not stripped:
            # Tolerate blank lines (e.g. a trailing newline at EOF) instead of
            # failing the field-count check below.
            continue
        line_split = stripped.split('\t')
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if len(line_split) != 2:
            raise ValueError(
                f"{f_in.name}:{line_no}: expected 2 tab-separated fields, "
                f"got {len(line_split)}"
            )
        text, rating = line_split
        data.append(sent_tokenize(text))
        human_score.append(float(rating))

# One IWF weight per line; the list is later indexed directly by tokenizer
# token id, so line i (0-based) must hold the weight for token id i.
with open('iwf_full.txt', 'r', encoding='utf-8') as f_in_iwf:
    iwf_score = [float(line.strip()) for line in f_in_iwf]

print('number of total data: ', len(data))

# Load the tokenizer and the conditional-generation model from a local
# checkpoint directory; run on the GPU when CUDA is available, else the CPU.
model_name = './pegasus'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)


def get_batched_data(sent_list):
    """Build masked-sentence (source, target) pairs for one document.

    For each position ``i``, the source string is the document with sentence
    ``i`` replaced by the literal ``<mask_1>`` token. The remaining sentences
    are joined by single spaces. The target is sentence ``i`` itself.

    Args:
        sent_list: the document as a sequence of sentence strings.

    Returns:
        ``(src_text, tgt_text)``: two lists of equal length, one entry per
        sentence. The first source always starts with a space and the last one
        always ends with a space, because the padding around ``<mask_1>`` is
        kept even when the text before/after the mask is empty.
    """
    targets = list(sent_list)
    sources = [
        ' '.join(targets[:pos]) + ' <mask_1> ' + ' '.join(targets[pos + 1:])
        for pos, _sentence in enumerate(targets)
    ]
    return sources, targets


def _score_sentences(sentences):
    """Score every sentence of one document against its masked context.

    Each sentence in turn is replaced by ``<mask_1>`` (via ``get_batched_data``)
    and the model is asked to regenerate it from the remaining sentences.
    Must be called with autograd disabled (see the ``torch.no_grad()`` below).

    Args:
        sentences: the document as a list of sentence strings.

    Returns:
        ``(score, tgt_score)``: two parallel lists with one entry per sentence.
        ``score`` is the negated model loss (higher means the sentence was
        easier to reconstruct). ``tgt_score`` is the largest ``iwf_score``
        value among the sentence's label token ids and is used as its weight.
    """
    src_text_list, tgt_text_list = get_batched_data(sentences)
    score, tgt_score = [], []
    for src_text_ele, tgt_text_ele in zip(src_text_list, tgt_text_list):
        batch = tokenizer([src_text_ele], truncation=True, padding='longest', return_tensors="pt").to(device)
        labels = tokenizer([tgt_text_ele], truncation=True, padding='longest', return_tensors="pt").to(device)
        label_ids = labels['input_ids']
        # NOTE(review): label ids may include special tokens added by the
        # tokenizer; confirm iwf_full.txt weights those as intended.
        tgt_score.append(max(iwf_score[token_id] for token_id in label_ids.cpu().numpy()[0]))
        loss = model(input_ids=batch['input_ids'], labels=label_ids)[0]
        score.append(-loss.detach().cpu().numpy())
    return score, tgt_score


# Inference only: switch to eval mode once and disable autograd for the whole
# scoring pass instead of per sentence.
model.eval()

total_score = []
with torch.no_grad():
    for data_id, data_ele in enumerate(data):
        score, tgt_score = _score_sentences(data_ele)
        weight_sum = sum(tgt_score)
        if weight_sum > 0:
            # IWF-weighted average of the per-sentence scores.
            tmp_score = np.dot(score, tgt_score) / weight_sum
        elif score:
            # No positive weights: fall back to the unweighted mean.
            tmp_score = np.mean(score)
        else:
            # Document with no sentences: nothing to score. NaN is what
            # np.mean([]) produced (with a RuntimeWarning); made explicit here.
            # A NaN propagates into the correlations below.
            tmp_score = np.nan

        total_score.append(tmp_score)

        if data_id % 10 == 0:
            print('processing data: ', data_id)

print('coherence pearson: ', pearsonr(total_score, human_score))
print('coherence spearman: ', spearmanr(total_score, human_score))
print('coherence kendall: ', kendalltau(total_score, human_score))
