import random

# The star import stays ahead of the explicit imports below so that they keep
# taking precedence over any same-named objects data_utils re-exports.
from data_utils import *
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer, RobertaTokenizer


class IESDataset(Dataset):
    """Map-style dataset of (context, response, label) items encoded for RoBERTa.

    Each item is encoded by ``encode_truncate`` (expected from ``data_utils``),
    which receives the tokenizer, the context/response pair, and the two
    token-length limits.

    Args:
        contexts: Indexable sequence of contexts; its length defines ``len(self)``.
        responses: Indexable sequence of responses aligned with ``contexts``.
            It also serves as the pool sampled by ``_get_fake_response``.
        labels: Indexable sequence of integer labels aligned with ``contexts``.
            ``__getitem__`` returns ``labels[index] - 1``.
        ctx_token_len: Context token-length limit forwarded to ``encode_truncate``.
        res_token_len: Response token-length limit forwarded to ``encode_truncate``.
        tokenizer_path: Location passed to ``RobertaTokenizer.from_pretrained``
            when ``tokenizer`` is not given. The default is the previously
            hard-coded checkpoint directory.
        tokenizer: Optional pre-built tokenizer. When given, it is used as-is and
            nothing is loaded, so several datasets can share one tokenizer.
    """

    def __init__(self, contexts, responses, labels, ctx_token_len=25, res_token_len=25,
                 tokenizer_path='../../ckpt/roberta-base', tokenizer=None):
        self.ctx_token_len = ctx_token_len
        self.res_token_len = res_token_len
        self.contexts = contexts
        self.responses = responses
        self.labels = labels
        if tokenizer is None:
            # A relative path resolves against the current working directory,
            # not this file; pass tokenizer_path/tokenizer when running elsewhere.
            tokenizer = RobertaTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer = tokenizer
        self._build_response_pool()

    def _build_response_pool(self):
        """Set ``self.res_pool``, the pool ``_get_fake_response`` samples from."""
        # Aliases (does not copy) the responses sequence.
        self.res_pool = self.responses

    def _get_fake_response(self):
        """Return a response drawn uniformly at random from ``self.res_pool``.

        Presumably used to build negative (mismatched) pairs; no caller is
        visible here.

        Raises:
            ValueError: If the response pool is empty.
        """
        # len() rather than truthiness: the pool may be an array-like whose
        # truth value is ambiguous.
        if len(self.res_pool) == 0:
            raise ValueError('response pool is empty')
        return random.choice(self.res_pool)

    def __len__(self):
        """Return the number of items, i.e. the number of contexts."""
        return len(self.contexts)

    def __getitem__(self, index):
        """Encode the context/response pair at ``index``.

        Returns:
            Tuple ``(input_ids, token_type_ids, mask_tokens, pos_ids, label)``.
            The first four values come unchanged from ``encode_truncate``, and
            ``label`` is ``self.labels[index] - 1``.
        """
        ctx = self.contexts[index]
        res = self.responses[index]

        # NOTE(review): the -1 shift suggests raw labels are 1-based and the
        # model expects 0-based class ids; confirm against the data source.
        label = self.labels[index] - 1

        input_ids, token_type_ids, mask_tokens, pos_ids = encode_truncate(
            self.tokenizer,
            ctx, res,
            ctx_token_len=self.ctx_token_len,
            res_token_len=self.res_token_len,
        )

        return input_ids, token_type_ids, mask_tokens, pos_ids, label
