from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch
from transformers import Trainer, TrainingArguments
import numpy as np


class StoriumDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves tokenizer encodings one example at a time.

    Only the ``"encodings"`` entry of the dict passed to the constructor is
    used; it must be a mapping from field name to a per-example sequence and
    must contain an ``'input_ids'`` field, which defines the dataset length.
    """

    def __init__(self, encodings_labels_dict):
        """Keep a reference to ``encodings_labels_dict["encodings"]``."""
        self.encodings = encodings_labels_dict["encodings"]

    def __len__(self):
        """Return the number of examples (length of the ``'input_ids'`` field)."""
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict mapping each field to a tensor."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        return sample

class RobertaPredictor:
    """Score (persona, story) pairs with a fine-tuned RoBERTa sequence classifier.

    The model and tokenizer are loaded from the same checkpoint directory and
    batched inference is delegated to a HuggingFace ``Trainer``.
    """

    def __init__(self, load_path="models/persona/version_7", gpu=None, bc=None):
        """Load the checkpoint and build the prediction ``Trainer``.

        Args:
            load_path: Checkpoint path passed to ``from_pretrained`` for both
                the model and the tokenizer.
            gpu: CUDA device index the model is moved to; ``None`` moves it to
                the CPU.
            bc: Per-device evaluation batch size; ``None`` means 8.
        """
        self.model = RobertaForSequenceClassification.from_pretrained(load_path)
        self.model.eval()
        # NOTE(review): Trainer may relocate the model to the device selected
        # by its TrainingArguments, which could override this placement --
        # confirm that `gpu` is actually honored at predict time.
        self.model.to(f"cuda:{gpu}" if gpu is not None else "cpu")
        self.tokenizer = RobertaTokenizer.from_pretrained(load_path)
        batch_size = bc if bc is not None else 8
        self.trainer = Trainer(
            model=self.model,
            # Despite the ".txt" suffix, this value is used as output_dir.
            args=TrainingArguments(per_device_eval_batch_size=batch_size, output_dir='predict_output.txt'),
            # Passing the tokenizer presumably lets Trainer pad each batch
            # (preprocess() leaves sequences unpadded) -- TODO confirm.
            tokenizer=self.tokenizer,
        )

    def preprocess(self, stories, personas):
        """Tokenize ``persona + '</s>' + story`` for every input pair.

        Inputs are truncated to at most 512 tokens and left unpadded.

        Args:
            stories: Iterable of story strings.
            personas: Iterable of persona strings, aligned with ``stories``.

        Returns:
            ``{"encodings": <tokenizer output>}``, the shape expected by
            ``StoriumDataset``.

        Raises:
            ValueError: If ``stories`` and ``personas`` differ in length.
        """
        stories, personas = list(stories), list(personas)
        # zip() would silently drop unmatched trailing items, leaving the
        # returned scores misaligned with the caller's inputs.
        if len(stories) != len(personas):
            raise ValueError(
                f"stories and personas must have the same length, "
                f"got {len(stories)} and {len(personas)}"
            )
        texts = [persona + '</s>' + story for story, persona in zip(stories, personas)]
        return {
            "encodings": self.tokenizer(texts, truncation=True, padding=False, max_length=512),
        }

    def score(self, stories, personas):
        """Return the softmax probability of class index 1 for every pair.

        Args:
            stories: Iterable of story strings.
            personas: Iterable of persona strings, aligned with ``stories``.

        Returns:
            1-D NumPy array with one probability per input pair, in input
            order. Empty (float32) when both inputs are empty.

        Raises:
            ValueError: If ``stories`` and ``personas`` differ in length.
        """
        stories, personas = list(stories), list(personas)
        if not stories and not personas:
            # Nothing to score: short-circuit instead of pushing an empty
            # dataset through the tokenizer and Trainer.
            return np.empty(0, dtype=np.float32)
        dataset = StoriumDataset(self.preprocess(stories, personas))
        # Trainer.predict returns raw logits, so softmax is applied manually.
        logits = self.trainer.predict(test_dataset=dataset).predictions
        probabilities = torch.softmax(torch.tensor(logits), dim=-1)
        return probabilities[:, 1].numpy()

if __name__ == '__main__':
    # Smoke test: load the default checkpoint and score two identical
    # (persona, story) pairs.
    predictor = RobertaPredictor()
    stories = ['I am a student', 'I am a student']
    personas = ['clever', 'clever']
    scores = predictor.score(stories, personas)
    # Print the result; otherwise running the script produces no visible output.
    print('scores = ', scores)


