from torch.utils.data import Dataset, DataLoader
import numpy as np
from MassConfig.ConfigBase import *
from MassDataLoader.ConllLoader import *
import truecase
import re


class ConllIsentityLoader(ConllLoader):
    """CoNLL-style token-classification dataset that also emits a per-token
    "is entity" flag.

    Each sample is a sentence wrapped in ``[CLS]`` / ``[SEP]`` markers. Words
    are split into pieces with the module-level ``tokenizer`` and tags are
    mapped to ids with ``tag_to_ix``; the sequence length is capped by
    ``MAX_LEN_TOKEN_IDS``. These three names are presumably provided by the
    star imports at the top of the file — TODO confirm.
    """

    def read_data(self, fpath):
        """Read one CoNLL-formatted file.

        Sentences are separated by blank lines. Within a sentence every
        non-blank line holds one token: its first whitespace-separated field
        is the word and its last field is the tag.

        Args:
            fpath: path of the file to read.

        Returns:
            ``(sents, tags_li)``: parallel lists. Each sentence is
            ``["[CLS]", *truecased_words, "[SEP]"]`` and its tag list is
            ``["<PAD>", *tags, "<PAD>"]``.
        """
        # Paths containing 'CoNLL' are decoded with 'unicode_escape' rather
        # than UTF-8 (behavior kept from the original; presumably matches how
        # those files are stored — TODO confirm).
        encoding = 'unicode_escape' if 'CoNLL' in fpath else 'UTF-8'
        # Context manager guarantees the file handle is closed.
        with open(fpath, 'r', encoding=encoding) as fh:
            entries = fh.read().strip().split("\n\n")

        sents, tags_li = [], []  # list of lists
        for entry in entries:
            # Ignore whitespace-only lines. With 3+ newlines between sentences
            # they would otherwise crash ``split()[0]`` or yield an empty
            # [CLS]/[SEP]-only sample.
            fields = [line.split() for line in entry.splitlines() if line.strip()]
            if not fields:
                continue
            words = [cols[0] for cols in fields]
            tags = [cols[-1] for cols in fields]

            # The special tokens get the <PAD> tag, i.e. no decision is made on them.
            sents.append(["[CLS]"] + self.truecase_sentence(words) + ["[SEP]"])
            tags_li.append(["<PAD>"] + tags + ["<PAD>"])
        return sents, tags_li

    def __init__(self, fpath):
        """Load sentences and tags from one or more CoNLL files.

        Args:
            fpath: a path such as ``[train|valid|test].txt``, or several paths
                joined by commas; their sentences are concatenated in order.
        """
        self.sents, self.tags = [], []
        # str.split always returns at least one element, so a single path is
        # simply the one-element case of this loop.
        for fp in fpath.split(','):
            sentences, labels = self.read_data(fp)
            self.sents.extend(sentences)
            self.tags.extend(labels)

    def __len__(self):
        """Return the number of sentences loaded."""
        return len(self.sents)

    def __getitem__(self, idx):
        """Encode the sentence at ``idx``.

        Returns:
            ``(words, x, is_heads, tags, y, seqlen, is_entity)``:

            - ``words`` / ``tags``: the (possibly truncated) words and tags
              joined with single spaces.
            - ``x``: token ids of the word pieces.
            - ``is_heads``: 1 for the first piece of each word, 0 otherwise.
            - ``y``: tag ids; continuation pieces are tagged ``<PAD>``.
            - ``seqlen``: ``len(y)``.
            - ``is_entity``: 1 where the tag id is greater than 1, else 0.

            All id lists are cut to at most ``MAX_LEN_TOKEN_IDS`` entries.
        """
        words, tags = self.sents[idx], self.tags[idx]  # words, tags: string lists

        # We give credit only to the first piece of each word.
        x, y, is_heads = [], [], []
        for w, t in zip(words, tags):
            tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
            if not tokens:
                # The tokenizer may return no pieces for some words (e.g. words
                # made only of characters it strips). Without a piece, x would
                # fall one element behind y / is_heads and the length assertion
                # below would fire. Fall back to the unknown token; "[UNK]" is
                # assumed for BERT-style vocabularies, matching the [CLS]/[SEP]
                # literals used above.
                tokens = [getattr(tokenizer, "unk_token", None) or "[UNK]"]
            xx = tokenizer.convert_tokens_to_ids(tokens)

            is_head = [1] + [0] * (len(tokens) - 1)
            piece_tags = [t] + ["<PAD>"] * (len(tokens) - 1)  # <PAD>: no decision
            yy = [tag_to_ix[each] for each in piece_tags]  # (T,)

            x.extend(xx)
            is_heads.extend(is_head)
            y.extend(yy)

        assert len(x) == len(y) == len(is_heads), f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}"

        # Tag ids greater than 1 count as entities. Presumably ids 0 and 1 are
        # '<PAD>' and 'O' in tag_to_ix — TODO confirm against its definition.
        is_entity = [1 if m > 1 else 0 for m in y]

        if len(y) > MAX_LEN_TOKEN_IDS:
            x = x[:MAX_LEN_TOKEN_IDS]
            y = y[:MAX_LEN_TOKEN_IDS]
            is_heads = is_heads[:MAX_LEN_TOKEN_IDS]
            is_entity = is_entity[:MAX_LEN_TOKEN_IDS]
            # Keep only the words whose first piece survived the cut.
            n_words = sum(is_heads)
            words = words[:n_words]
            tags = tags[:n_words]

        seqlen = len(y)

        # to string
        words = " ".join(words)
        tags = " ".join(tags)

        return words, x, is_heads, tags, y, seqlen, is_entity

    def pad(self, batch):
        """Collate ``__getitem__`` samples, padding to the longest one.

        The id fields (``x``, ``is_heads``, ``y``, ``is_entity``) are
        right-padded with 0 up to the batch's largest ``seqlen`` and converted
        to ``torch.LongTensor``. ``words``, ``tags`` and ``seqlens`` are
        returned as plain lists.

        Args:
            batch: list of tuples as returned by ``__getitem__``.

        Returns:
            ``(words, x, is_heads, tags, y, seqlens, is_entity)``.
        """
        def column(i):
            # Gather field ``i`` from every sample in the batch.
            return [sample[i] for sample in batch]

        words = column(0)
        tags = column(3)
        seqlens = column(-2)
        maxlen = max(seqlens)

        def padded(i):
            # Right-pad field ``i`` with 0 (noted in the original as '0: <pad>').
            return [sample[i] + [0] * (maxlen - len(sample[i])) for sample in batch]

        x = padded(1)
        is_heads = padded(2)
        y = padded(-3)
        is_entity = padded(-1)

        assert len(x) == len(is_heads) == len(y), f'len(x)={len(x)}, len(is_heads)={len(is_heads)}, len(y)={len(y)}'
        to_tensor = torch.LongTensor
        return words, to_tensor(x), to_tensor(is_heads), tags, to_tensor(y), seqlens, to_tensor(is_entity)

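# Sanity check (uncomment to run). `pad` is an instance method, so pass the
# bound method `test_dataset.pad` as the collate function.
# testset = "../data/CoNLL2003/test.txt"
# test_dataset = ConllIsentityLoader(testset)
#
# testloader = DataLoader(dataset=test_dataset,
#                         batch_size=16,
#                         shuffle=False,
#                         num_workers=4,
#                         collate_fn=test_dataset.pad)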

# for idx, batch in enumerate(testloader):
#     words, x, is_heads, tags, y, seqlens, is_entity = batch
#     print("=====sanity check======")
#     print("words:", words[0])
#     print("x:", x.cpu().numpy()[0][:seqlens[0]])
#     print("tokens:", tokenizer.convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]])
#     print("is_heads:", is_heads[0])
#     print("y:", y.cpu().numpy()[0][:seqlens[0]])
#     print("tags:", tags[0])
#     print("is_entity:", is_entity[0])
#     print("seqlen:", seqlens[0])
#     print("=======================")

