from data_utils import *
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class ABBADataset(Dataset):
    """Order-classification dataset built from aligned context/response lists.

    Each item starts from the index-aligned pair ``contexts[index]`` /
    ``responses[index]``. With probability 0.5 the two texts are swapped
    (AB becomes BA) and the label is 0; otherwise the original order is kept
    and the label is 1. Two parallel lists are therefore enough to produce
    labelled positive and negative samples.

    NOTE(review): ``random`` and ``encode_truncate`` are not imported
    explicitly in this module; they are presumably provided by
    ``from data_utils import *`` -- confirm.
    """

    def __init__(self, contexts, responses, ctx_token_len=25, res_token_len=25,
                 tokenizer=None):
        """Store the data and prepare the tokenizer.

        Args:
            contexts: Indexable sequence of context texts.
            responses: Indexable sequence of response texts; ``responses[i]``
                is paired with ``contexts[i]``.
            ctx_token_len: Passed to ``encode_truncate`` as ``ctx_token_len``.
            res_token_len: Passed to ``encode_truncate`` as ``res_token_len``.
            tokenizer: Optional ready-made tokenizer. When ``None`` (the
                default) the ``'bert-base-uncased'`` ``BertTokenizer`` is
                loaded, as before. Passing one lets several datasets share a
                single tokenizer instead of reloading it per instance.
        """
        self.ctx_token_len = ctx_token_len
        self.res_token_len = res_token_len
        self.contexts = contexts
        self.responses = responses
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self._build_response_pool()

    def _build_response_pool(self):
        """Expose all responses as the pool used for random sampling."""
        self.res_pool = self.responses

    def _get_fake_response(self):
        """Return a uniformly random response from the pool.

        This is the most basic form of random sampling. Nothing excludes the
        true response of a given context from being drawn. ``__getitem__``
        does not currently call this helper.
        """
        idx = random.randint(0, len(self.res_pool) - 1)
        return self.res_pool[idx]

    def __len__(self):
        """Return the number of contexts."""
        return len(self.contexts)

    def __getitem__(self, index):
        """Generate one encoded, labelled sample.

        Args:
            index: Position of the context/response pair to use.

        Returns:
            tuple: ``(input_ids, token_type_ids, mask_tokens, pos_ids, label)``
            where the first four values are whatever ``encode_truncate``
            returns and ``label`` is 1 for the original order, 0 for the
            swapped order.
        """
        first, second = self.contexts[index], self.responses[index]
        label = 1

        # Negative sampling: half of the time swap the pair (AB -> BA).
        if random.random() < 0.5:
            first, second = second, first
            label = 0

        # Encode the (possibly swapped) pair.
        input_ids, token_type_ids, mask_tokens, pos_ids = encode_truncate(
            self.tokenizer,
            first, second,
            ctx_token_len=self.ctx_token_len,
            res_token_len=self.res_token_len
        )

        return input_ids, token_type_ids, mask_tokens, pos_ids, label

if __name__ == "__main__":
    # Smoke test: build a one-pair dataset with this module's own class and
    # print each encoded sample. The previous code instantiated `NUPDataset`,
    # which is not defined in this file.
    contexts = ["i go to school."]
    responses = ["really? you don't like burger?"]
    dataset = ABBADataset(contexts, responses)
    for sample in dataset:
        print(sample)
