import torch
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
import os
import random


class InputExample:
    """One sentence, stored as parallel word and label sequences.

    Attributes are assigned exactly as given:
        guid:  identifier of the example.
        text:  the words of the sentence.
        label: the labels that go with ``text``.
    """

    def __init__(self, guid, text, label):
        self.guid, self.text, self.label = guid, text, label

class InputFeatures:
    """Plain container for the model inputs built from one example.

    Every constructor argument is stored unchanged under the attribute of
    the same name; no validation or copying is performed.
    """

    def __init__(self, input_ids, input_mask, label_ids, masked_ids, entity_mask, nce_pos_ids, nce_neg_ids):
        # Encoded sentence and its labels.
        self.input_ids, self.input_mask, self.label_ids = input_ids, input_mask, label_ids
        # Masked variants of the sentence and the matching entity mask.
        self.masked_ids, self.entity_mask = masked_ids, entity_mask
        # Positive / negative samples for the NCE objective.
        self.nce_pos_ids, self.nce_neg_ids = nce_pos_ids, nce_neg_ids

class MultiLabelTextProcessor():
    """Reads token-per-line, tab-separated label files and groups them into sentences.

    Each split lives in ``<data_dir>/<split>.txt``. The file has no header;
    the first column of a row is the word and the last column is its label.
    A row whose last column is empty marks a sentence boundary.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.labels = None

    def _create_examples(self, df):
        """Group the rows of ``df`` into one ``InputExample`` per sentence.

        Every word whose label is not ``'O'`` is wrapped in a leading and a
        trailing special token ``<LABEL>``; the special tokens themselves are
        labelled ``'O'``.

        Args:
            df: DataFrame whose first column holds words and whose last
                column holds labels, with NaN in the last column on
                sentence-boundary rows.

        Returns:
            List of ``InputExample`` with ``guid`` numbered from 0 in file order.
        """
        examples = []
        text = []
        label = []

        for row in df.values:
            tag = row[-1]

            if pd.isna(tag):
                # Sentence boundary: emit the sentence collected so far, if any.
                if text:
                    examples.append(
                        InputExample(guid=len(examples), text=text, label=label))
                    text = []
                    label = []
                continue

            if tag != 'O':
                # Use O as pseudo label for the entity special tokens; if another
                # label is ever used here, code that tests `label != "O"` must be
                # revisited.
                marker = '<' + tag + '>'
                text.extend([marker, row[0], marker])
                label.extend(['O', tag, 'O'])
            else:
                text.append(row[0])
                label.append(tag)

        # A file that does not end with a blank line would otherwise lose its
        # final sentence, because sentences are only emitted on boundary rows.
        if text:
            examples.append(
                InputExample(guid=len(examples), text=text, label=label))
        return examples

    def get_examples(self, dsplit):
        """Read ``<data_dir>/<dsplit>.txt`` and return its sentences as examples.

        Only the empty string is treated as a missing value
        (``keep_default_na=False``), so words such as ``"null"`` or ``"NA"``
        stay ordinary tokens, and quoting is disabled (``quoting=3``).

        Args:
            dsplit: split name without extension, e.g. ``"train"``.

        Returns:
            List of ``InputExample`` as produced by ``_create_examples``.
        """
        path = os.path.join(self.data_dir, dsplit + ".txt")
        read_kwargs = dict(sep="\t", header=None, skip_blank_lines=False,
                           engine='python', quoting=3,
                           keep_default_na=False,
                           na_values=[''])
        try:
            # `error_bad_lines` was deprecated in pandas 1.3 and removed in 2.0.
            # 'warn' matches what `error_bad_lines=False` did with the default
            # `warn_bad_lines`: drop malformed lines and report them.
            data_df = pd.read_csv(path, on_bad_lines='warn', **read_kwargs)
        except TypeError:
            # pandas < 1.3 does not know `on_bad_lines`.
            data_df = pd.read_csv(path, error_bad_lines=False, **read_kwargs)
        return self._create_examples(data_df)

class Data():
    def __init__(self, tokenizer, b_size, label_map, file_dir, mask_rate):

        #self.device = device
        #self.pretrain_model = pretrain_model
        self.tokenizer = tokenizer
        self.b_size = b_size
        self.label_map = label_map
        self.mask_rate = mask_rate

        #if not os.path.isfile('../../data/nermlm/test.pt'):
        self.datasets = self.create_dataset_files(file_dir)
        # if not os.path.isfile('../data/nermlm/test.pt'):
        #     self.create_features_files()

    def create_dataset_files(self, file_dir):
        """Build one ``TensorDataset`` per split from the files in ``file_dir``.

        For every split the examples are read, converted to features, and each
        feature field is stacked into a ``torch.long`` tensor. The shapes of
        ``masked_ids``, ``entity_mask`` and ``nce_neg_ids`` are printed just
        before the corresponding tensors are created.

        Args:
            file_dir: directory containing the ``<split>.txt`` files.

        Returns:
            List of ``TensorDataset`` objects in split order; each holds the
            tensors (input_ids, input_mask, label_ids, masked_ids,
            entity_mask, nce_pos_ids, nce_neg_ids).
        """
        # Order of the tensors inside each TensorDataset.
        tensor_fields = ('input_ids', 'input_mask', 'label_ids', 'masked_ids',
                         'entity_mask', 'nce_pos_ids', 'nce_neg_ids')
        # Fields whose array shape is printed right before the keyed tensor is built.
        shapes_printed_before = {
            'masked_ids': ('masked_ids', 'entity_mask'),
            'nce_neg_ids': ('nce_neg_ids',),
        }

        def column(feats, field):
            # One entry per feature object, in feature order.
            return [getattr(feat, field) for feat in feats]

        reader = MultiLabelTextProcessor(file_dir)
        built = []
        # NOTE(review): 'dev' appears twice, so the third dataset is a second,
        # independently converted pass over dev.txt — confirm whether 'test'
        # was intended.
        for dsplit in ('train', 'dev', 'dev'):
            print(f"Generating dataloader for {dsplit}")
            split_examples = reader.get_examples(dsplit)
            feats = self.convert_examples_to_features(split_examples, self.tokenizer, is_train=(dsplit == 'train'))

            tensors = []
            for field in tensor_fields:
                for logged in shapes_printed_before.get(field, ()):
                    print(np.array(column(feats, logged)).shape)
                tensors.append(torch.tensor(column(feats, field), dtype=torch.long))

            built.append(TensorDataset(*tensors))
        return built

    def convert_examples_to_features(self, examples, tokenizer, max_seq_length=128, is_train=True):
        """Convert sentence examples into fixed-length id, mask and label lists.

        For every example the word list is encoded once, and from that encoding
        this method derives:

        * ``label_ids``: the ``label_map`` id on the first sub-word of each
          word and 0 on position 0, on the remaining sub-words and on padding;
        * ``masked_ids``: 30 copies of ``input_ids`` in which sub-words of
          non-``O`` words are replaced by the ``<mask>`` id, each with
          probability ``self.mask_rate``;
        * ``entity_mask``: 30 lists holding 1 where the matching
          ``masked_ids`` copy was masked and 0 elsewhere;
        * ``nce_neg_ids``: 30 x 7 copies in which, for a masked word, the
          surrounding label special tokens are replaced by one of the 7
          labels of a different class and the masked position by ``<mask>``;
        * ``nce_pos_ids``: 30 copies in which, for a masked word, the
          surrounding label special tokens are replaced by the B-/I-
          counterpart of the word's label.

        The results depend on the state of the global ``random`` module.

        Args:
            examples: iterable of objects with ``text`` (list of words) and
                ``label`` (list of label strings, one per word).
            tokenizer: used for encoding and for ``tokenize``. Token-id lookups
                go through ``self.tokenizer`` instead.
            max_seq_length: length every produced list is padded/truncated to.
            is_train: only read in the ``O`` branch, where the guarded
                condition ``random.random() < 0`` can never hold; its only
                effect is that a random number is drawn when it is true.

        Returns:
            List with one ``InputFeatures`` per example.

        Raises:
            AssertionError: if the token before (or after) a masked word is
                not the expected label special token.
        """

        features = []

        for example in examples:
            encoded = tokenizer(example.text,
                                padding = 'max_length',
                                truncation = True,
                                max_length=max_seq_length,
                                is_split_into_words = True)
            input_ids = encoded["input_ids"]
            input_mask = encoded["attention_mask"]

            # Number of sub-word pieces of every word, used to map word
            # positions to token positions.
            subword_len = []
            for word in example.text:
                subword_len.append(len(tokenizer.tokenize(word)))

            # subword_start[i] is the index in input_ids of the first piece of
            # word i; the extra final entry is the index just past the last word.
            # The +1 shift presumably accounts for one leading special token
            # added by the tokenizer — TODO confirm.
            subword_start = [0]
            subword_start.extend(np.cumsum(subword_len))
            subword_start = [x+1 for x in subword_start]
            # 30 independent variants (per the original author: a different
            # masking for each epoch). Each must be its own copy; `[lst] * n`
            # would alias a single list.
            masked_ids = [input_ids.copy() for i in range(30)]
            nce_pos_ids = [input_ids.copy() for i in range(30)]           
            nce_neg_ids = [[input_ids.copy() for i in range(7)] for j in range(30)]

            # Both start with one 0 entry for position 0, matching the +1 shift above.
            entity_mask = [[0] for i in range(30)]
            label_ids = [0]

            for i, label in enumerate(example.label):
                # Real label on the first piece of the word, 0 on the others.
                label_ids.append(self.label_map[label])
                label_ids.extend([0] * (subword_len[i]-1))

                # Mask named entities in sentence, and generate entity mask
                if label != "O":
                    # Labels of the same entity type (B- and I-) versus the 7
                    # remaining labels, which serve as NCE negatives.
                    same_class_labels = ['B-'+label[2:], 'I-'+label[2:]]
                    diff_class_labels = [l for l in ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-MISC', 'I-MISC'] if l not in same_class_labels]
                    assert len(diff_class_labels) == 7

                    for count in range(subword_len[i]):
                        # Positions at or beyond max_seq_length do not exist in
                        # the encoded lists; stop processing this word.
                        if subword_start[i]+count >= max_seq_length:
                            break
                        for k in range(30):
                            random.shuffle(diff_class_labels) # Different negative order for each variant

                            if random.random() < self.mask_rate:
                                # Earlier variant, kept for reference: substitute the
                                # label token instead of <mask>:
                                #masked_ids[k][subword_start[i]+count] = self.label_to_token_id(label)
                                masked_ids[k][subword_start[i]+count] = self.tokenizer.convert_tokens_to_ids('<mask>')
                                entity_mask[k].append(1)
                                # NCE samples
                                for j in range(7):
                                    # The token right before the word must be its
                                    # leading label special token.
                                    assert input_ids[subword_start[i]-1] == self.label_to_token_id(label)
                                    nce_neg_ids[k][j][subword_start[i]-1] = self.label_to_token_id(diff_class_labels[j])
                                    try: # if the sentence is truncated, the trailing label token will be lost
                                        # 2 is presumably the tokenizer's end-of-sequence
                                        # id — TODO confirm.
                                        # NOTE(review): a failing assert here is not caught
                                        # by the `except IndexError` below and propagates.
                                        assert input_ids[subword_start[i]+subword_len[i]] == self.label_to_token_id(label) or input_ids[subword_start[i]+subword_len[i]] == 2 
                                        nce_neg_ids[k][j][subword_start[i]+subword_len[i]] = self.label_to_token_id(diff_class_labels[j])
                                    except IndexError :
                                        pass
                                    nce_neg_ids[k][j][subword_start[i]+count] = self.tokenizer.convert_tokens_to_ids('<mask>')

                                    # Positive sample: same entity type, B-/I- prefix swapped.
                                    # NOTE(review): this does not depend on j, so the same
                                    # values are written on every iteration of the j loop.
                                    alter_label = 'B'+label[1:] if label[0] == 'I' else 'I'+label[1:]
                                    nce_pos_ids[k][subword_start[i]-1] = self.label_to_token_id(alter_label)
                                    try: # if the sentence is truncated, the trailing label token will be lost
                                        nce_pos_ids[k][subword_start[i]+subword_len[i]] = self.label_to_token_id(alter_label)
                                    except IndexError :
                                        pass
                            else:
                                entity_mask[k].append(0)
                else:
                    for count in range(subword_len[i]):
                        if subword_start[i]+count >= max_seq_length:
                            break
                        for k in range(30):
                            # random.random() returns values in [0.0, 1.0), so this
                            # condition is never true and 'O' words are never masked.
                            # The call still draws a random number when is_train is true.
                            if is_train and random.random() < 0:
                                masked_ids[k][subword_start[i]+count] = self.tokenizer.convert_tokens_to_ids('<mask>') 
                            entity_mask[k].append(0)

            # Pad short sentence and truncate long sentence
            if len(label_ids) > max_seq_length:
                label_ids = label_ids[:max_seq_length]
                for k in range(30):
                    masked_ids[k] = masked_ids[k][:max_seq_length]
                    entity_mask[k] = entity_mask[k][:max_seq_length]
                
                    for j in range(7):
                        nce_neg_ids[k][j] = nce_neg_ids[k][j][:max_seq_length]
                    nce_pos_ids[k] = nce_pos_ids[k][:max_seq_length]
            else:
                # masked_ids / nce_* are copies of input_ids, which the tokenizer
                # call above was asked to pad to max_length, so presumably only
                # label_ids and entity_mask actually grow here — TODO confirm.
                label_ids.extend([0] * (max_seq_length - len(label_ids)))
                for k in range(30):
                    masked_ids[k].extend([0] * (max_seq_length - len(masked_ids[k])))
                    entity_mask[k].extend([0] * (max_seq_length - len(entity_mask[k])))
                    for j in range(7):
                        nce_neg_ids[k][j].extend([0] * (max_seq_length - len(nce_neg_ids[k][j])))
                    nce_pos_ids[k].extend([0] * (max_seq_length - len(nce_pos_ids[k])))

            features.append(
                    InputFeatures(
                        input_ids=input_ids,
                        input_mask=input_mask,
                        label_ids=label_ids,
                        masked_ids=masked_ids,
                        entity_mask=entity_mask,
                        nce_pos_ids=nce_pos_ids,
                        nce_neg_ids=nce_neg_ids
                        ))

        return features

    def label_to_token_id(self,label):
        """Return the tokenizer id of the special token ``<label>``.

        Args:
            label: one of ``O``, ``B-/I-PER``, ``B-/I-ORG``, ``B-/I-LOC``,
                ``B-/I-MISC``.

        Returns:
            Whatever ``self.tokenizer.convert_tokens_to_ids`` returns for the
            bracketed label.

        Raises:
            AssertionError: if the bracketed label is not a known special token.
        """
        known_tokens = (
            '<O>',
            '<B-PER>', '<I-PER>',
            '<B-ORG>', '<I-ORG>',
            '<B-LOC>', '<I-LOC>',
            '<B-MISC>', '<I-MISC>',
        )
        token = ''.join(('<', label, '>'))
        assert token in known_tokens
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        return token_id








    # def create_features_files(self):
    #     for lang in self.lang_map.keys():
    #         for dsplit in ['train', 'dev', 'test']:
    #             dataloader = DataLoader(torch.load('../data/dataset/' + lang + '.' + dsplit + '.pt'), batch_size=self.b_size)
    #             features = self.extract_features(dataloader)
    #             torch.save(features, '../data/features/' + lang + '.' + dsplit + '.pt')
    #
    #
    # def extract_features(self, dataloader):
    #
    #     hid_states_list = []
    #     pooled_list = []
    #     print("Extracting features...")
    #
    #     for i, batch in enumerate(dataloader):
    #         batch = tuple(t.to(self.device) for t in batch)
    #         input_ids, input_mask, _, _ = batch
    #         hid_states, _ = self.pretrain_model(input_ids, input_mask)
    #         pooled = torch.sum(hid_states, axis=1) / torch.sum(input_mask, axis=1, keepdim=True)
    #         hid_states_list.append(hid_states)
    #         pooled_list.append(pooled)
    #
    #     features = (torch.cat(hid_states_list, dim=0), torch.cat(pooled_list, dim=0))
    #
    #     print("hid states dim is {}".format(features[0].size()))
    #     print("pooled dim is {}".format(features[1].size()))
    #
    #     return features
