from data_utils import *
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer, RobertaTokenizer

class NUFDataset(Dataset):
    """Map-style dataset that tokenizes raw sentences for a BERT-style model.

    Each item is ``(input_ids, token_type_ids, attention_mask, label)``.
    The three encodings come from the tokenizer's ``encode_plus``, padded and
    truncated to ``maxlen`` and returned as PyTorch tensors. ``label`` is the
    stored label minus one.

    Args:
        instances: Sequence of raw text strings, indexed in parallel with
            ``labels``.
        labels: Sequence of integer labels. ``labels[i] - 1`` is returned, so
            these are presumably 1-based class ids (confirm against the data
            source).
        maxlen: Length every encoding is padded / truncated to.
        tokenizer: Optional pre-built Hugging Face tokenizer. Defaults to
            ``BertTokenizer.from_pretrained('bert-base-uncased')`` (the
            previous hard-coded choice); pass e.g. a ``RobertaTokenizer`` to
            switch models without editing this class.
    """

    def __init__(self, instances, labels, maxlen, tokenizer=None):
        super().__init__()
        self.maxlen = maxlen
        self.labels = labels
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self.instances = instances

    def __len__(self):
        """Return the number of text instances."""
        return len(self.instances)

    def __getitem__(self, index):
        """Encode ``instances[index]`` and return it with its shifted label."""
        text = self.instances[index]
        # Shift labels down by one: presumably 1-based ids -> 0-based class
        # indices (verify against the label source).
        label = self.labels[index] - 1

        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.maxlen,
            # Replaces the deprecated ``pad_to_max_length=True``; truncation
            # to ``maxlen`` is now requested explicitly instead of implicitly.
            padding='max_length',
            truncation=True,
            # Some tokenizers (e.g. RoBERTa) omit token_type_ids by default;
            # request them so the key lookup below never fails.
            return_token_type_ids=True,
            return_tensors="pt",
        )
        # NOTE(review): with return_tensors="pt" each tensor presumably carries
        # a leading batch dimension of 1 — confirm the training loop squeezes it.
        return (
            encoded['input_ids'],
            encoded['token_type_ids'],
            encoded['attention_mask'],
            label,
        )


if __name__ == "__main__":
    # Smoke test: encode two sentences and print each dataset item.
    sents = [
        "i go to school.",
        "really? you don't like burger?"
    ]
    # NUFDataset requires one label per sentence and a max sequence length.
    # __getitem__ subtracts 1 from each label, so supply 1-based values.
    labels = [1, 2]
    dataset = NUFDataset(sents, labels, 32)
    for item in dataset:
        print(item)
