from torch.utils.data import Dataset, DataLoader
import numpy as np
from MassConfig.ConfigBase import ALL_TAGS, ALL_TAGS_, tokenizer, tag_to_ix, ix_to_tag, MAX_LEN_TOKEN_IDS
import torch
import random
from MassUtils.MassUtil import truecase_sentence, split_dict, product_dict, transpose_list
import re
# Each pair is (sentence_1, sentence_2, label): label 1 for a same-tag pair, 0 for a different-tag pair.
class SiameseLoader():

    def construct_other_part_pair(self, tags_ctx_words, tag_1, ctx_word_idx_1, pair_kind):
        """Pick the second member of a Siamese pair for a given anchor example.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            tag_1: tag of the anchor example.
            ctx_word_idx_1: index of the anchor inside ``tags_ctx_words[tag_1]``.
            pair_kind: truthy for a positive (same tag) pair, falsy for a
                negative (different tag) pair.

        Returns:
            ``(ctx_word_2, y, tag_2)`` where ``y`` is 1 for a positive pair and
            0 for a negative pair.

        Raises:
            ValueError: if a negative pair is requested but no tag other than
                ``tag_1`` has at least one example (the previous rejection
                loop would spin forever in that situation).
        """
        if pair_kind:
            tag_1_ctx_words = tags_ctx_words.get(tag_1)
            # Deterministic partner: the next example of the same tag, wrapping
            # around. With a single example the anchor is paired with itself.
            ctx_word_idx_2 = (ctx_word_idx_1 + 1) % len(tag_1_ctx_words)
            return tag_1_ctx_words[ctx_word_idx_2], 1, tag_1

        # Enumerate the usable tags up front instead of rejection sampling, so
        # an impossible request fails loudly rather than looping forever.
        candidate_tags = [
            tag for tag in ALL_TAGS_
            if tag != tag_1 and tags_ctx_words.get(tag)
        ]
        if not candidate_tags:
            raise ValueError(
                f"cannot build a negative pair for tag {tag_1!r}: "
                "no other tag has any example"
            )
        tag_2 = random.choice(candidate_tags)
        ctx_word_2 = random.choice(tags_ctx_words.get(tag_2))
        return ctx_word_2, 0, tag_2

    def construct_pair(self, ctx_word_idx_1, tags_ctx_words, tag_1, pair_kind):
        """Assemble one pair record from an anchor example and a chosen partner.

        Args:
            ctx_word_idx_1: index of the anchor inside ``tags_ctx_words[tag_1]``.
            tags_ctx_words: mapping of tag -> list of context-word examples.
            tag_1: tag of the anchor example.
            pair_kind: truthy for a positive pair, falsy for a negative pair.

        Returns:
            The value of ``product_dict('pair', ...)`` built from both examples,
            the label and both tags.
        """
        anchor = tags_ctx_words.get(tag_1)[ctx_word_idx_1]
        words_1, w_idx_1, tags_1 = split_dict('ctx_word', anchor)

        partner, y, tag_2 = self.construct_other_part_pair(
            tags_ctx_words, tag_1, ctx_word_idx_1, pair_kind
        )
        words_2, w_idx_2, tags_2 = split_dict('ctx_word', partner)

        return product_dict(
            'pair',
            [words_1, w_idx_1, words_2, w_idx_2, y, tag_1, tag_2, tags_1, tags_2],
        )

    def construct_posi_pairs(self, tags_ctx_words, isequal, maxSamples):
        """Build positive (same-tag) pairs for every tag with at least two examples.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            isequal: when truthy each anchor index is drawn at random; otherwise
                anchors are taken sequentially from index 0.
            maxSamples: upper bound on the number of pairs per tag (also capped
                by the number of examples of that tag).

        Returns:
            List of pair records; the per-tag pair count is logged.
        """
        collected = []
        for tag_1 in ALL_TAGS_:
            examples = tags_ctx_words.get(tag_1)
            # A positive pair needs two distinct examples of the same tag.
            if len(examples) >= 2:
                limit = np.min([maxSamples, len(examples)])
                made = 0
                while made < limit:
                    if isequal:
                        anchor_idx = random.randint(0, len(examples) - 1)
                    else:
                        anchor_idx = made
                    collected.append(
                        self.construct_pair(anchor_idx, tags_ctx_words, tag_1, pair_kind=True)
                    )
                    made += 1
                self.logger.info(f"{tag_1}:{made}")
        return collected

    def construct_nega_pairs(self, tags_ctx_words, count_nega_pair):
        """Build ``count_nega_pair`` negative (different-tag) pairs.

        For each pair an anchor tag is drawn uniformly from ``ALL_TAGS_`` and
        redrawn until it is a non-empty string with at least one example; a
        random example of that tag is then used as the anchor.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            count_nega_pair: number of negative pairs to generate.

        Returns:
            List of pair records.
        """
        negatives = []
        for _ in range(count_nega_pair):
            # Rejection-sample an anchor tag that actually has examples.
            while True:
                tag_1 = ALL_TAGS_[random.randint(0, len(ALL_TAGS_) - 1)]
                if tag_1 != '' and len(tags_ctx_words.get(tag_1)) >= 1:
                    break
            examples = tags_ctx_words.get(tag_1)
            anchor_idx = random.randint(0, len(examples) - 1)
            negatives.append(
                self.construct_pair(anchor_idx, tags_ctx_words, tag_1, pair_kind=False)
            )
        return negatives

    def construct_nega_pairs_(self, tags_ctx_words, isequal, maxSamples):
        """Build negative pairs tag by tag, mirroring ``construct_posi_pairs``.

        Tags with fewer than two examples are skipped. For every other tag up
        to ``min(maxSamples, number of examples)`` negative pairs are created.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            isequal: when truthy each anchor index is drawn at random; otherwise
                anchors are taken sequentially from index 0.
            maxSamples: upper bound on the number of pairs per tag.

        Returns:
            List of pair records; the per-tag pair count is logged.
        """
        result = []
        for tag_1 in ALL_TAGS_:
            pool = tags_ctx_words.get(tag_1)
            if len(pool) < 2:
                continue
            quota = np.min([maxSamples, len(pool)])
            produced = 0
            while produced < quota:
                anchor_idx = produced
                if isequal:
                    anchor_idx = random.randint(0, len(pool) - 1)
                negative = self.construct_pair(
                    anchor_idx, tags_ctx_words, tag_1, pair_kind=False
                )
                result.append(negative)
                produced += 1
            self.logger.info(f"{tag_1}:{produced}")
        return result

    def get_test_pairs(self, tags_ctx_words, maxSamples):
        """Build self-pairs (an example paired with itself, label 1) for evaluation.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            maxSamples: per-tag cap on the number of examples used.

        Returns:
            List of pair records in which both sides are the same example.

        Raises:
            AssertionError: if the tag stored at the example's word position
                does not match the tag bucket it was filed under.
        """
        test_pairs = []
        for tag in ALL_TAGS_:
            for count, ctx_word in enumerate(tags_ctx_words.get(tag)):
                # NOTE(review): ``>`` lets maxSamples + 1 examples through per
                # tag. Kept unchanged so dataset sizes stay the same; confirm
                # whether the extra example is intended.
                if count > maxSamples:
                    break
                words, w_idx, tags = split_dict('ctx_word', ctx_word)
                # The examples built in this file wrap ``tags`` in a leading and
                # trailing "<PAD>", while ``w_idx`` indexes the unwrapped list.
                word_tag = tags[1:-1][w_idx]
                # Report the element that is actually compared; the previous
                # message printed tags[w_idx], i.e. the neighbouring position.
                assert word_tag == tag, f"tag={tag}, tags[1:-1][w_idx]={word_tag}"
                pair_values = [words, w_idx, words, w_idx, 1, word_tag, word_tag, tags, tags]
                test_pairs.append(product_dict('pair', pair_values))
        return test_pairs

    def get_train_pairs(self, tags_ctx_words, isequal, maxSamples):
        """Return the training pairs: positives followed by as many negatives.

        Args:
            tags_ctx_words: mapping of tag -> list of context-word examples.
            isequal: forwarded to ``construct_posi_pairs``.
            maxSamples: forwarded to ``construct_posi_pairs``.

        Returns:
            A new list holding all positive pairs and then an equal number of
            negative pairs.
        """
        positives = self.construct_posi_pairs(tags_ctx_words, isequal, maxSamples)
        # Balance the classes: one negative pair per positive pair.
        negatives = self.construct_nega_pairs(tags_ctx_words, len(positives))
        return positives + negatives

    def get_pairs(self, entries, params_list):
        """Turn raw sentence entries into the tag index and the pair list.

        Args:
            entries: sentence blocks as read from the data file.
            params_list: settings unpacked by ``split_dict('params_list', ...)``
                into ``maxSamples``, ``istest`` and ``isequal``.

        Returns:
            ``(tags_ctx_words, pairs)``.
        """
        maxSamples, istest, isequal = split_dict('params_list', params_list)

        if not istest:
            tags_ctx_words = self.entry_2_tags_ctx_words(entries)
            train_pairs = self.get_train_pairs(tags_ctx_words, isequal, maxSamples)
            return tags_ctx_words, train_pairs

        # Test mode: examples and their self-pairs are produced in one pass.
        tags_ctx_words, test_pairs = self.entry_2_tags_ctx_words_pairs(entries, maxSamples)
        return tags_ctx_words, test_pairs

    def get_wp_idx_list(self, token_list):
        """Return, for each token, the position of its first word piece.

        Positions refer to the flattened word-piece sequence obtained by
        tokenizing every token in order. A token that tokenizes to nothing
        still occupies one position.

        Args:
            token_list: sequence of word-level tokens.

        Returns:
            List with one start position per token.
        """
        starts = []
        offset = 0
        for token in token_list:
            starts.append(offset)
            # Each token advances the offset by its piece count, minimum one.
            offset += max(len(tokenizer.tokenize(token)), 1)
        return starts

    def chk_wp_idx(self, token_list, w_idx):
        """Check that the word-piece position derived from ``w_idx`` fits the length limit.

        Starting from ``w_idx``, the extra pieces (piece count minus one) of
        ``token_list[0] .. token_list[w_idx]`` are accumulated; the check fails
        as soon as the running position plus one exceeds
        ``MAX_LEN_TOKEN_IDS - 1``.

        Args:
            token_list: word-level tokens of the sentence.
            w_idx: word index whose word-piece position is being validated.

        Returns:
            ``False`` if the limit is exceeded at any step, ``True`` otherwise.
        """
        position = w_idx
        for token_idx in range(w_idx + 1):
            pieces = tokenizer.tokenize(token_list[token_idx])
            position += len(pieces) - 1
            if position + 1 > MAX_LEN_TOKEN_IDS - 1:
                return False
        return True

    def remove_u00(self, words):
        """Strip every character that is neither a word character nor whitespace.

        The list is modified in place and also returned. A word made up only of
        such characters becomes the empty string.

        Args:
            words: mutable list of strings.

        Returns:
            The same list object, with each entry cleaned.
        """
        for w_idx, word in enumerate(words):
            # Raw string: '\w' / '\s' in a normal literal are invalid escape
            # sequences and trigger a SyntaxWarning on recent Python versions.
            words[w_idx] = re.sub(r'[^\w\s]', '', word)
        return words


    def entry_2_tags_ctx_words(self, entries):
        """Index every usable word of every sentence by its tag.

        Each entry is one sentence with one token per line; the first
        whitespace-separated column is the word and the last is the tag.
        Whitespace-only lines are skipped. The sentence is wrapped in
        "[CLS]"/"[SEP]" (tags in "<PAD>") and every word that passes
        ``chk_wp_idx`` is filed under its tag.

        Args:
            entries: iterable of sentence blocks (strings).

        Returns:
            Dict mapping each tag in ``ALL_TAGS_`` to its list of ``ctx_word``
            records. All records of one sentence share the same word and tag
            list objects.
        """
        tags_ctx_words = {tag: [] for tag in ALL_TAGS_}
        for entry in entries:  # each sentence
            # Split each line once; drop blank lines, which have no columns.
            rows = [cols for cols in (line.split() for line in entry.splitlines()) if cols]
            words = [cols[0] for cols in rows]
            tags = [cols[-1] for cols in rows]

            words_close = ["[CLS]"] + truecase_sentence(words) + ["[SEP]"]
            tags_close = ["<PAD>"] + tags + ["<PAD>"]

            for t_idx, tag in enumerate(tags):
                # Keep only words whose word-piece position fits the length limit.
                if self.chk_wp_idx(words_close, t_idx):
                    ctx_word = product_dict('ctx_word', [words_close, t_idx, tags_close])
                    tags_ctx_words.get(tag).append(ctx_word)

        return tags_ctx_words

    def entry_2_tags_ctx_words_pairs(self, entries, maxSamples):
        """Index words by tag and build the matching evaluation self-pairs in one pass.

        Each entry is one sentence with one token per line; the first
        whitespace-separated column is the word and the last is the tag.
        Whitespace-only lines are skipped. For every word that passes
        ``chk_wp_idx`` and whose tag is still under its cap, a ``ctx_word``
        record is stored and a pair of the example with itself (label 1) is
        emitted.

        Args:
            entries: iterable of sentence blocks (strings).
            maxSamples: per-tag cap on the number of examples kept.

        Returns:
            ``(tags_ctx_words, pairs)``.
        """
        tags_ctx_words = {tag: [] for tag in ALL_TAGS_}
        tags_count = {tag: 0 for tag in ALL_TAGS_}

        pairs = []
        for entry in entries:  # each sentence
            # Split each line once; drop blank lines, which have no columns.
            rows = [cols for cols in (line.split() for line in entry.splitlines()) if cols]
            words = [cols[0] for cols in rows]
            tags = [cols[-1] for cols in rows]

            words_close = ["[CLS]"] + truecase_sentence(words) + ["[SEP]"]
            tags_close = ["<PAD>"] + tags + ["<PAD>"]

            for t_idx, tag in enumerate(tags):
                if not self.chk_wp_idx(words_close, t_idx):
                    continue
                ctx_word = product_dict('ctx_word', [words_close, t_idx, tags_close])
                # NOTE(review): ``<=`` admits maxSamples + 1 examples per tag.
                # Kept unchanged so dataset sizes stay the same; confirm intent.
                if tags_count.get(tag) <= maxSamples:
                    tags_count[tag] += 1
                    tags_ctx_words.get(tag).append(ctx_word)
                    pair_values = [words_close, t_idx, words_close, t_idx, 1, tag, tag, tags_close, tags_close]
                    pairs.append(product_dict('pair', pair_values))
        return tags_ctx_words, pairs


    def read_data(self, fpath, params_list):
        """Load one data file and convert it into Siamese pairs.

        The file is read whole, stripped, and split into sentence entries on
        blank lines. Files whose path contains 'CoNLL2002' are decoded with
        'unicode_escape'; everything else is read as UTF-8.

        Shapes of the intermediate structures, as sketched by the original
        author:
            tags_ctx_words = {B-LOC: [ctx_word], I-LOC: [ctx_word], ...}
            ctx_word       = {w_idx: 2, words: [w1, w2, w3], tags: [t1, t2, t3]}
            pair           = {words1, word_idx1, words2, word_idx2, y, tag1,
                              tag2, tags1, tags2}

        Args:
            fpath: path of the data file.
            params_list: settings mapping, logged and forwarded to ``get_pairs``.

        Returns:
            The list of pairs built from the file.
        """
        self.logger.info(f'Load Data {fpath} ......')
        self.logger.info({params: params_list.get(params) for params in params_list})

        encoding = 'unicode_escape' if 'CoNLL2002' in fpath else 'UTF-8'
        # Context manager so the handle is closed even if reading fails.
        with open(fpath, 'r', encoding=encoding) as data_file:
            entries = data_file.read().strip().split("\n\n")

        tags_ctx_words, pairs = self.get_pairs(entries, params_list)

        self.logger.info({tag: len(tags_ctx_words.get(tag)) for tag in ALL_TAGS_})
        self.logger.info(f"number of pairs : {len(pairs)}")
        return pairs

    def __init__(self, fpath, params_dict, logger):
        """Build the pair list from one data file or from a list of files.

        Args:
            fpath: a single path (``str``) or a ``list`` of paths, e.g.
                [train|valid|test].txt. Any other type leaves the dataset empty.
            params_dict: settings forwarded to ``read_data``.
            logger: object used for ``info`` logging while loading.
        """
        self.logger = logger
        self.fpath = fpath
        self.pairs = []

        if isinstance(fpath, str):
            self.pairs = self.read_data(fpath, params_dict)
        elif isinstance(fpath, list):
            # Concatenate the pairs of every file, in the order given.
            for single_path in fpath:
                self.pairs.extend(self.read_data(single_path, params_dict))

    def __len__(self):
        """Return the number of pairs held by the dataset."""
        pair_count = len(self.pairs)
        return pair_count

    def datas_2_inputs(self, words, tags, w_idx):
        """Convert one word-level sentence into word-piece level model inputs.

        Args:
            words: word-level tokens; "[CLS]" and "[SEP]" are passed through
                without tokenization.
            tags: one tag per entry of ``words``.
            w_idx: word index of the target word; the assertion at the end
                treats ``words[w_idx + 1]`` as the target word.

        Returns:
            ``(word_piece_ids, word_piece_tag_ids, wp_idx, seqlen, att_mask)``:
            the word-piece ids, the tag id repeated for every piece of a word,
            the word-piece offset such that ``wp_idx + 1`` is the position of
            the target word's first piece, the (possibly truncated) sequence
            length, and a mask holding 1 on the first piece of each word and 0
            on continuation pieces.

        Raises:
            AssertionError: if the three per-piece lists differ in length, or
                if the piece at ``wp_idx + 1`` is not the first piece of
                ``words[w_idx + 1]``.
        """
        word_piece_ids, att_mask, word_piece_tag_ids = [], [], []
        wp_idx = w_idx
        for word_idx, word in enumerate(words):
            # Special markers are kept as single pieces instead of being tokenized.
            word_piece = tokenizer.tokenize(word) if word not in ("[CLS]", "[SEP]") else [word]
            # NOTE(review): if the tokenizer returns an empty list for a word,
            # no id and no tag id are added below but att_mask still receives
            # one entry ([1] + [0] * -1 == [1]), so the length assertion after
            # the loop fails. Presumably such words do not occur - confirm.
            word_piece_id = tokenizer.convert_tokens_to_ids(word_piece)

            word_piece_ids.extend(word_piece_id)
            # Every word up to and including index w_idx shifts the target
            # position by its number of extra pieces.
            if(word_idx < w_idx+1):
                wp_idx += len(word_piece) - 1

            tag = tags[word_idx]
            # The word's tag id is repeated for each of its pieces.
            word_piece_tag_ids.extend([tag_to_ix[tag]] * len(word_piece))
            # 1 marks the first piece of a word, 0 its continuation pieces.
            att_mask.extend([1] + [0] * (len(word_piece) - 1))

        assert len(word_piece_ids) == len(att_mask) == len(word_piece_tag_ids), f"len(word_piece_ids)={len(word_piece_ids)}, len(att_mask)={len(att_mask)}, len(word_piece_tag_ids)={len(word_piece_tag_ids)}"
        seqlen = len(word_piece_ids)
        # Truncate all three lists to the maximum number of word-piece ids.
        if (seqlen > MAX_LEN_TOKEN_IDS):
            word_piece_ids = word_piece_ids[:MAX_LEN_TOKEN_IDS]
            word_piece_tag_ids = word_piece_tag_ids[:MAX_LEN_TOKEN_IDS]
            att_mask = att_mask[:MAX_LEN_TOKEN_IDS]
            seqlen = len(word_piece_ids)

        # Sanity check: the piece at wp_idx + 1 must be the first piece of the
        # target word.
        # NOTE(review): this indexes the truncated list and raises IndexError
        # when wp_idx + 1 >= MAX_LEN_TOKEN_IDS; chk_wp_idx appears to filter
        # such examples out beforehand - confirm for all callers.
        assert word_piece_ids[wp_idx+1] == tokenizer.convert_tokens_to_ids(tokenizer.tokenize(words[w_idx+1]))[0], f"word_piece_id = " \
           f"{wp_idx+1}, {w_idx+1}, {word_piece_ids[wp_idx+1]}, {tokenizer.convert_tokens_to_ids(words[w_idx+1])},  " \
           f"{tokenizer.convert_ids_to_tokens(word_piece_ids[wp_idx+1])},{words[w_idx+1]}, {words}, {word_piece_ids}"

        return word_piece_ids, word_piece_tag_ids, wp_idx, seqlen, att_mask

    def __getitem__(self, idx):
        """Return the model inputs for the pair stored at position ``idx``.

        Returns:
            An 8-tuple whose entries are, in order: both word lists, both
            word-piece offsets, both tag ids, both word-piece id lists, the
            label ``y``, both masks, both sequence lengths and both word
            indices. Every entry except ``y`` is a two-element list
            ``[side_1, side_2]``.
        """
        (words_1, word_idx_1, words_2, word_idx_2, y,
         tag_1, tag_2, tags_1, tags_2) = split_dict('pair', self.pairs[idx])

        side_1 = self.datas_2_inputs(words_1, tags_1, word_idx_1)
        side_2 = self.datas_2_inputs(words_2, tags_2, word_idx_2)
        # The per-piece tag ids (second field) are not part of the sample.
        ids_1, _, wp_idx_1, seqlen_1, mask_1 = side_1
        ids_2, _, wp_idx_2, seqlen_2, mask_2 = side_2

        return (
            [words_1, words_2],
            [wp_idx_1, wp_idx_2],
            [tag_to_ix[tag_1], tag_to_ix[tag_2]],
            [ids_1, ids_2],
            y,
            [mask_1, mask_2],
            [seqlen_1, seqlen_2],
            [word_idx_1, word_idx_2],
        )

    def batch_align(self, seqlens, data_list):
        """Right-pad every sequence with 0 up to the largest value in ``seqlens``.

        Args:
            seqlens: lengths of the samples in the batch; only the maximum is used.
            data_list: list of batches, each a list of sequences.

        Returns:
            A list parallel to ``data_list`` where each sequence is a new list
            padded with 0 (the pad value). Sequences already longer than the
            maximum are copied unchanged, not truncated.
        """
        target_len = np.max(seqlens)
        aligned = []
        for sequences in data_list:
            padded_batch = []
            for seq in sequences:
                row = list(seq)
                # A non-positive repeat count yields no padding.
                row.extend([0] * (target_len - len(row)))
                padded_batch.append(row)
            aligned.append(padded_batch)
        return aligned

    def pad(self, batch):
        """Collate ``__getitem__`` samples into one batch, padding to the longest sample.

        Each sample is the 8-tuple produced by ``__getitem__``. Every field
        except the label is regrouped with ``transpose_list`` (presumably from
        per-sample ``[side_1, side_2]`` lists into per-side lists - confirm
        against the helper). Word-piece ids and masks are zero-padded per side,
        and tag ids, word-piece ids, masks and labels become ``LongTensor``s.

        Returns:
            ``(words_1_2, wp_idx_1_2, wp_tags_1_2, wp_ids_1_2, y, att_mask_1_2,
            seqlens_1_2, word_idx_1_2)``.
        """
        def field(position):
            # Collect one field of the sample tuple across the whole batch.
            return [sample[position] for sample in batch]

        words_1_2 = transpose_list(field(0))
        wp_idx_1_2 = transpose_list(field(1))
        wp_tags_1_2 = transpose_list(field(2))
        wp_ids_1_2 = transpose_list(field(3))
        att_mask_1_2 = transpose_list(field(5))
        seqlens_1_2 = transpose_list(field(6))
        word_idx_1_2 = transpose_list(field(7))
        y = field(4)

        # Each side of the pair is padded to its own longest sequence.
        padded_sides = [
            self.batch_align(seqlens_1_2[side], [wp_ids_1_2[side], att_mask_1_2[side]])
            for side in (0, 1)
        ]
        (wp_ids_1, att_mask_1), (wp_ids_2, att_mask_2) = padded_sides

        wp_tags_1_2 = [torch.LongTensor(wp_tags_1_2[0]), torch.LongTensor(wp_tags_1_2[1])]
        wp_ids_1_2 = [torch.LongTensor(wp_ids_1), torch.LongTensor(wp_ids_2)]
        att_mask_1_2 = [torch.LongTensor(att_mask_1), torch.LongTensor(att_mask_2)]
        labels = torch.LongTensor(y)

        return words_1_2, wp_idx_1_2, wp_tags_1_2, wp_ids_1_2, labels, att_mask_1_2, seqlens_1_2, word_idx_1_2






# testset = "../data/CoNLL2003/test.txt"
# #
# keys = ['maxSamples', 'isequal', 'shuffle', 'istest']
# values = [200, True, True, True]
# #
# test_dataset = SiameseLoader(testset, dict(zip(keys, values)))
# #
# testloader = DataLoader(dataset=test_dataset,
#                         batch_size=16,
#                         shuffle=False,
#                         num_workers=4,
#                         collate_fn=test_dataset.pad)
# #
# for idx, batch in enumerate(testloader):
#     words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch
#
#     words_1, words_2 = words_1_2[0], words_1_2[1]
#     word_idx_1, word_idx_2 = word_idx_1_2[0], word_idx_1_2[1]
#     tag_1, tag_2 = tag_1_2[0], tag_1_2[1]
#     x_1, x_2 = x_1_2[0], x_1_2[1]
#     att_mask_1, att_mask_2 = att_mask_1_2[0], att_mask_1_2[1]
#     seqlen_1, seqlen_2 = seqlen_1_2[0], seqlen_1_2[1]
#
#     print("=====sanity check======")
#     print("words_1:", words_1_2[0][0])
#     print("words_2:", words_1_2[1][0])
#
#     print("word_idx_1:", word_idx_1_2[0][0])
#     print("word_idx_2:", word_idx_1_2[1][0])
#
#     print("tag_1:", tag_1_2[0][0])
#     print("tag_2:", tag_1_2[1][0])
#
#     print("x_1:", x_1_2[0].cpu().numpy()[0][:seqlen_1_2[0][0]])
#     print("x_2:", x_1_2[1].cpu().numpy()[0][:seqlen_1_2[1][0]])
#
#     print("y:", y.cpu().numpy()[0])
#
#     print("att_mask_1:", att_mask_1_2[0][0])
#     print("att_mask_2:", att_mask_1_2[1][0])
#
#     print("seqlen_1:", seqlen_1_2[0][0])
#     print("seqlen_2:", seqlen_1_2[1][0])
#     print("=======================")
