from data_utils import *
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

class MLMDataset(Dataset):
    """Dataset that yields one single-token masked-LM example per text.

    Each item is produced by encoding ``instances[index]`` with the BERT
    tokenizer (special tokens added, padded to ``maxlen``), picking one
    position uniformly at random among the non-special, non-padding
    positions, remembering the id found there as the label, and overwriting
    that position with the tokenizer's mask token id.

    Args:
        instances: Indexable collection of raw texts (anything the tokenizer's
            ``encode_plus`` accepts as its first argument).
        maxlen: Length passed to the tokenizer as ``max_length``.
    """

    def __init__(self, instances, maxlen):
        self.maxlen = maxlen
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.instances = instances

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.instances)

    def __getitem__(self, index):
        """Encode one text and mask a single randomly chosen token.

        The masked position is re-sampled on every call, so repeated access
        to the same index generally yields different examples.

        Returns:
            A tuple ``(input_ids, token_type_ids, attention_mask, mask_idx,
            label)``. The first three are the tensors returned by
            ``encode_plus`` (``input_ids`` with the chosen position replaced
            by the mask id); ``mask_idx`` and ``label`` are one-element
            ``LongTensor`` objects holding the masked position and the
            original token id at that position.

        Raises:
            ValueError: If there is no maskable position, i.e. the text
                produces no tokens or ``maxlen`` leaves no room between the
                two special tokens.
        """
        text = self.instances[index]

        # NOTE(review): `pad_to_max_length` is deprecated in recent
        # transformers releases in favour of padding="max_length" together
        # with truncation=True; the arguments are left untouched here because
        # the installed version is not visible from this file.
        encoded = self.tokenizer.encode_plus(text,
                                             add_special_tokens=True,
                                             max_length=self.maxlen,
                                             pad_to_max_length=True,
                                             return_tensors="pt")
        input_ids = encoded['input_ids']
        token_type_ids = encoded['token_type_ids']
        attention_mask = encoded['attention_mask']

        # Number of real (non-padding) positions, special tokens included.
        # Reading it from the attention mask avoids tokenizing the text a
        # second time; the cap keeps the sampled index inside `maxlen`.
        sampling_length = min(int(attention_mask.sum()), self.maxlen)
        if sampling_length < 3:
            raise ValueError(
                "cannot mask a token: instance {!r} has no maskable position "
                "(empty text or maxlen < 3)".format(index)
            )

        # Sample from 1 .. sampling_length - 2 (randint is inclusive on both
        # ends) so the first and last real positions, where the special
        # tokens added by add_special_tokens=True are expected to sit, are
        # never masked.
        # `random` is expected to reach this module through the star import
        # at the top of the file -- TODO confirm.
        position = random.randint(1, sampling_length - 2)
        mask_idx = torch.LongTensor([position])
        label = torch.LongTensor([input_ids[0, position].item()])

        # Use the tokenizer's own [MASK] id rather than a hard-coded 103 so
        # the dataset stays correct if the vocabulary ever changes.
        input_ids[0, position] = self.tokenizer.mask_token_id

        return input_ids, token_type_ids, attention_mask, mask_idx, label

