from data_utils import *  # wildcard kept first so the explicit imports below take precedence

import random

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class ABBADataset(Dataset):
    """Sentence-pair dataset whose label says whether the pair is in order.

    Item ``i`` is built from ``contexts[i]`` and ``responses[i]``. With
    probability ``swap_prob`` the two texts are swapped (response first,
    context second) and the label is 0. Otherwise the original order is kept
    and the label is 1.

    Args:
        contexts: Indexable sequence of context texts. Its length defines
            ``len(dataset)``.
        responses: Indexable sequence of response texts, parallel to
            ``contexts``. It must have at least ``len(contexts)`` items.
        ctx_token_len: Forwarded to ``encode_truncate`` as ``ctx_token_len``.
        res_token_len: Forwarded to ``encode_truncate`` as ``res_token_len``.
        tokenizer: Optional pre-built tokenizer passed to ``encode_truncate``.
            When ``None`` (the default), a ``bert-base-uncased``
            ``BertTokenizer`` is loaded, as before.
        swap_prob: Probability in ``[0, 1]`` of emitting a swapped
            (label 0) sample. The default of 0.5 matches the original
            behaviour.

    Raises:
        ValueError: If ``swap_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, contexts, responses, ctx_token_len=25, res_token_len=25,
                 tokenizer=None, swap_prob=0.5):
        if not 0.0 <= swap_prob <= 1.0:
            raise ValueError(f"swap_prob must be in [0, 1], got {swap_prob!r}")
        self.ctx_token_len = ctx_token_len
        self.res_token_len = res_token_len
        self.contexts = contexts
        self.responses = responses
        self.swap_prob = swap_prob
        # Allow callers to share a single tokenizer instead of reloading the
        # pretrained vocabulary for every dataset instance.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self._build_response_pool()

    def _build_response_pool(self):
        """Expose the responses as the pool used by ``_get_fake_response``."""
        self.res_pool = self.responses

    def _get_fake_response(self):
        """Return a uniformly random response from ``res_pool``.

        ``__getitem__`` does not currently call this; it is kept for
        random-response negative sampling. Raises ``IndexError`` if the pool
        is empty.
        """
        return random.choice(self.res_pool)

    def __len__(self):
        """Return the number of samples, i.e. ``len(contexts)``."""
        return len(self.contexts)

    def __getitem__(self, index):
        """Generate one sample of data.

        Returns:
            ``(input_ids, token_type_ids, mask_tokens, pos_ids, label)``. The
            first four items are whatever ``encode_truncate`` returns. The
            label is 1 for (context, response) order and 0 for
            (response, context) order.
        """
        first = self.contexts[index]
        second = self.responses[index]
        label = 1

        # Negative sampling: the same pair in reversed order.
        if random.random() < self.swap_prob:
            first, second = second, first
            label = 0

        # NOTE(review): the truncation budgets follow segment position, not
        # the text's original role, so after a swap the response is truncated
        # with ctx_token_len. This is harmless while both budgets are equal;
        # confirm it is intended.
        input_ids, token_type_ids, mask_tokens, pos_ids = encode_truncate(
            self.tokenizer,
            first, second,
            ctx_token_len=self.ctx_token_len,
            res_token_len=self.res_token_len,
        )

        return input_ids, token_type_ids, mask_tokens, pos_ids, label


