import csv
import os
import pickle
from collections import Counter
from datetime import datetime
from typing import Dict

import numpy as np
from pytorch_pretrained_bert import BertTokenizer

from Dependency_paring import task
from SSDM.decorators import auto_init_args, lazy_execute
import string
# Index of the unknown-word entry; also the fallback id used when a word is
# missing from a dict vocabulary.
UNK_IDX = 0
# Surface form stored in the vocabulary for the unknown-word entry.
UNK_WORD = "UUUNKKK"
# Key of the evaluation split that is handed to the model as dev data; the
# remaining keys of pickled evaluation data are treated as test data.
EVAL_YEAR = "WNT2020_QE_si_en"
# ASCII punctuation characters as provided by the standard library.
punc = string.punctuation


def get_pos_data(pos, tokenizer, tag2idx):
    """Turn tagged sentences into aligned sub-token id and tag id sequences.

    Every word is lower-cased and split with ``tokenizer.tokenize``; the
    literal markers ``"[CLS]"`` and ``"[SEP]"`` bypass the tokenizer and are
    kept as single tokens. The first sub-token of a word receives the word's
    tag, every additional sub-token is labelled ``"<pad>"``. Words for which
    the tokenizer returns nothing are dropped.

    Args:
        pos: iterable of sentences, each a sequence of ``(word, tag)`` pairs.
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``.
        tag2idx: mapping from tag to integer id; must contain ``"<pad>"``
            whenever a word is split into more than one sub-token.

    Returns:
        Tuple ``(X, Y)`` of two lists with one entry per sentence: the
        sub-token ids and the tag ids, position-aligned.
    """
    passthrough = ("[CLS]", "[SEP]")

    # Separate words and tags for all sentences before any tokenization.
    split_sentences = [
        ([pair[0] for pair in sent], [pair[1] for pair in sent])
        for sent in pos
    ]

    X, Y = [], []
    for words, tags in split_sentences:
        token_ids, tag_ids, head_flags = [], [], []
        for word, tag in zip(words, tags):
            if word in passthrough:
                pieces = [word]
            else:
                pieces = tokenizer.tokenize(word.lower())
            if not pieces:
                continue

            piece_ids = tokenizer.convert_tokens_to_ids(pieces)
            n_continuations = len(pieces) - 1
            # Only the head sub-token carries a real tag.
            piece_labels = [tag] + ["<pad>"] * n_continuations
            piece_tag_ids = [tag2idx[label] for label in piece_labels]

            token_ids.extend(piece_ids)
            head_flags.extend([1] + [0] * n_continuations)
            tag_ids.extend(piece_tag_ids)

        assert len(token_ids) == len(tag_ids) == len(head_flags), \
            "len(x)={}, len(y)={}, len(is_heads)={}".format(
                len(token_ids), len(tag_ids), len(head_flags))
        Y.append(tag_ids)
        X.append(token_ids)
    return X, Y


def read_annotated_file(path, index="index"):
    """Read a tab-separated file of scored sentence pairs.

    Two column layouts are supported, selected by whether the substring
    ``"QE"`` occurs in ``path``:

    * with ``"QE"``: columns ``original``, ``translation`` and ``mean``;
    * otherwise: columns ``text_a``, ``text_b`` and ``score``; a row whose
      score cannot be parsed as a float is printed and skipped.

    Args:
        path: location of the TSV file (read as UTF-8, a BOM is tolerated).
        index: accepted for signature compatibility; not used here.

    Returns:
        Dict with the fixed ``'year'`` label plus the parallel lists
        ``'original'``, ``'translation'`` and ``'z_mean'``.
    """
    sources, targets, scores = [], [], []
    with open(path, mode="r", encoding="utf-8-sig") as handle:
        rows = csv.DictReader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
        if "QE" in path:
            for record in rows:
                sources.append(record["original"])
                targets.append(record["translation"])
                scores.append(float(record['mean']))
        else:
            for record in rows:
                try:
                    # Parse the score first so a bad row leaves no partial entry.
                    scores.append(float(record['score'].strip()))
                    sources.append(record["text_a"])
                    targets.append(record["text_b"])
                except ValueError:
                    print(record)
    result = {'year': 'WNT2020_QE_si_en'}
    result['original'] = sources
    result['translation'] = targets
    result['z_mean'] = scores
    return result


def read_annotated_file_mean(path, index="index"):
    """Read a TSV of scored sentence pairs and rescale the mean scores.

    The file must provide the columns named by ``index``, ``original``,
    ``translation`` and ``mean``. Each ``mean`` value is parsed as a float and
    multiplied by 0.01.

    Args:
        path: location of the TSV file (read as UTF-8, a BOM is tolerated).
        index: name of the column holding the row identifier.

    Returns:
        Dict with the fixed ``'year'`` label plus the parallel lists
        ``'index'``, ``'original'``, ``'translation'`` and ``'z_mean'``.
    """
    row_ids, sources, targets, raw_means = [], [], [], []
    with open(path, mode="r", encoding="utf-8-sig") as handle:
        rows = csv.DictReader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
        for record in rows:
            row_ids.append(record[index])
            sources.append(record["original"])
            targets.append(record["translation"])
            raw_means.append(float(record["mean"]))
        # Rescale every score by a factor of 0.01.
        scaled_means = [float(value * 0.01) for value in raw_means]

    result = {'year': 'WNT2020_QE_si_en', 'index': row_ids}
    result['original'] = sources
    result['translation'] = targets
    result['z_mean'] = scaled_means
    return result


class data_holder:
    """Bundle of the train/dev/test splits together with the vocabulary.

    NOTE(review): ``auto_init_args`` presumably stores the constructor
    arguments as same-named attributes - confirm in ``SSDM.decorators``.
    """

    @auto_init_args
    def __init__(self, train_data, dev_data, test_data, vocab):
        # Reverse lookup table: index -> word.
        self.inv_vocab = dict(
            (position, word) for word, position in vocab.items())


class data_processor:
    """Load the training pairs, build the vocabulary and index every split.

    ``process`` is the entry point: it loads (or restores from a pickle cache)
    the training sentences, builds either a frequency-based vocabulary, a
    vocabulary backed by pre-trained embeddings, or a BERT tokenizer, and
    converts the training and evaluation data to id sequences.

    NOTE(review): ``auto_init_args`` and ``lazy_execute`` come from
    ``SSDM.decorators``; the latter presumably caches the wrapped method's
    result under the ``file_name`` keyword and reloads it through the named
    loader method - confirm there.
    """

    # Token ids that do not count as content when deciding whether a
    # tokenizer-encoded sentence is kept.
    # NOTE(review): presumably the ids of [UNK], "." and [PAD] in the BERT
    # vocabulary in use - confirm against the tokenizer.
    _IGNORED_TOKEN_IDS = (100, 119, 0)

    @auto_init_args
    def __init__(self, train_path, eval_path, dp_1, dp_2, experiment):
        self.expe = experiment
        self.train_path = train_path
        self.eval_path = eval_path
        self.dp1 = dp_1
        self.dp2 = dp_2

    def process(self):
        """Build the vocabulary and index the training and evaluation data.

        ``eval_path`` may be ``None`` (no dev/test data), a list of annotated
        TSV files (first entry is the dev file, the rest are test files) or
        the path of a pickle holding ``{year: {domain: data}}``.

        Returns:
            Tuple ``(data, vocab, W)``: a ``data_holder``, the vocabulary
            (dict or tokenizer) and the embedding matrix (``None``, an array,
            or the marker ``'lm'``).
        """
        config = self.expe.config
        prefix = "pre_vocab_" if config.pre_train_emb else "vocab_"
        vocab_file = os.path.join(
            config.vocab_file, prefix + str(config.vocab_size))

        train_data = self._load_sent(
            self.train_path, file_name=self.train_path + ".pkl")

        if config.pre_train_emb:
            W, vocab = self._build_pretrain_vocab(
                train_data, file_name=vocab_file)
        else:
            W, vocab = self._build_vocab(train_data, file_name=vocab_file)

        # NOTE(review): for non-'lm' embeddings ``v`` stays empty, so the
        # data_holder built for the evaluation branches gets an empty
        # inverse vocabulary - confirm this is intended.
        v = {}
        if config.embed_type == 'lm':
            v = vocab.vocab
            self.expe.log.info("vocab size: {}".format(len(v)))
        else:
            self.expe.log.info("vocab size: {}".format(len(vocab)))

        train_data = self._data_to_idx_dp(
            train_data, vocab, dp1=self.dp1, dp2=self.dp2)

        if self.eval_path is None:
            data = data_holder(
                train_data=train_data,
                dev_data=None,
                test_data=None,
                vocab=vocab)
            return data, vocab, W

        if isinstance(self.eval_path, list):
            if len(self.eval_path) == 2:
                dev_year, dev_entry = self._index_annotated_file(
                    self.eval_path[0], vocab, "dev year: {}")
                dev_data = {dev_year: {'1': dev_entry}}

                test_data = {'2020': {}}
                _, test_entry = self._index_annotated_file(
                    self.eval_path[1], vocab, "test year: {}")
                test_data['2020']['1'] = test_entry
            else:
                dev_year, dev_entry = self._index_annotated_file(
                    self.eval_path[0], vocab, "dev year: {}")
                dev_data = {dev_year: {'1': dev_entry}}

                # Test files after the dev file are assigned to these
                # language pairs in order.
                name = ["ar-en", "es-en", "tr-en"]
                test_data = {'ar-en': {}, 'es-en': {}, 'tr-en': {}}
                for i in range(1, len(self.eval_path)):
                    _, test_entry = self._index_annotated_file(
                        self.eval_path[i], vocab, "test year: {}")
                    test_data[name[i - 1]]['1'] = test_entry

            data = data_holder(
                train_data=train_data,
                dev_data={EVAL_YEAR: dev_data[EVAL_YEAR]},
                test_data={y: test_data[y] for y in test_data},
                vocab=v)
        else:
            eval_data = self._load_from_pickle(self.eval_path)
            new_data = dict()
            for year, domains in sorted(eval_data.items()):
                self.expe.log.info(
                    "year: {}, #domain: {}".format(year, len(domains)))
                new_data[year] = dict()
                for n, d in domains.items():
                    if config.embed_type == 'lm':
                        data_idx = self._data_to_idx_bert_dev(
                            [d[0], d[1], d[2]], vocab)
                        new_data[year][n] = [
                            data_idx[0], data_idx[1], data_idx[2]]
                    else:
                        data_idx = self._data_to_idx([d[0], d[1]], vocab)
                        new_data[year][n] = [data_idx[0], data_idx[1], d[2]]

            data = data_holder(
                train_data=train_data,
                dev_data={EVAL_YEAR: new_data[EVAL_YEAR]},
                test_data={y: new_data[y]
                           for y in new_data if y != EVAL_YEAR},
                vocab=v)

        return data, vocab, W

    def _index_annotated_file(self, path, vocab, log_template):
        """Read one annotated TSV file, log its year label and index it.

        Returns:
            Tuple ``(year, [translation_ids, original_ids, scores])``.
        """
        annotated = read_annotated_file(path)
        self.expe.log.info(log_template.format(annotated['year']))
        data_idx = self._data_to_idx(
            [annotated['original'], annotated['translation']], vocab)
        # The translation side comes first in the stored entry.
        return annotated['year'], [
            data_idx[1], data_idx[0], annotated['z_mean']]

    @lazy_execute("_load_from_pickle")
    def _load_sent(self, path):
        """Read tab-separated sentence pairs and split them on spaces.

        Lines are stripped and lower-cased; empty lines are ignored. A line
        that does not consist of exactly two tab-separated fields is printed,
        logged as a warning and skipped.

        Returns:
            Tuple ``(data_pair1, data_pair2)`` of parallel lists of word lists.
        """
        data_pair1 = []
        data_pair2 = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip().lower()
                if not line:
                    continue
                fields = line.split('\t')
                if len(fields) == 2:
                    data_pair1.append(fields[0].split(" "))
                    data_pair2.append(fields[1].split(" "))
                else:
                    print(fields)
                    # Log the raw text: concatenating the split list to a
                    # str would raise TypeError instead of warning.
                    self.expe.log.warning("unexpected data: " + line)
        assert len(data_pair1) == len(data_pair2)
        return data_pair1, data_pair2

    @staticmethod
    def _encode_text(text, vocab):
        """Encode a sentence (string or word sequence) with the tokenizer."""
        if not isinstance(text, str):
            text = ' '.join(text)
        return vocab.convert_tokens_to_ids(vocab.tokenize(text))

    def _has_content(self, token_ids):
        """Return True if at least one id is outside the ignored id set."""
        return any(token_id not in self._IGNORED_TOKEN_IDS
                   for token_id in token_ids)

    def _data_to_idx_bert_dev(self, data, vocab):
        """Encode scored sentence pairs with a tokenizer.

        A triple is dropped entirely when its second sentence encodes to
        nothing but ignored ids.

        Args:
            data: three parallel sequences (first sentences, second
                sentences, scores).
            vocab: tokenizer exposing ``tokenize`` / ``convert_tokens_to_ids``.

        Returns:
            Tuple ``(ids1, ids2, scores)``; the id collections are wrapped
            with ``np.array``, the scores stay a plain list.
        """
        idx_pair1 = []
        idx_pair2 = []
        score = []
        for d1, d2, d3 in zip(*data):
            s2 = self._encode_text(d2, vocab)
            if not self._has_content(s2):
                continue
            idx_pair2.append(s2)
            idx_pair1.append(self._encode_text(d1, vocab))
            score.append(d3)
        # NOTE(review): sentences of differing lengths make these ragged;
        # recent NumPy releases reject ragged input unless dtype=object is
        # given - verify against the NumPy version in use.
        return np.array(idx_pair1), np.array(idx_pair2), score

    def _data_to_idx(self, data, vocab):
        """Convert sentence pairs to id sequences.

        With a dict vocabulary each element of a sentence is looked up
        directly (unknown entries map to ``UNK_IDX``). Otherwise ``vocab`` is
        used as a tokenizer and pairs whose second sentence encodes to
        nothing but ignored ids are dropped.

        Args:
            data: two parallel sequences of sentences.
            vocab: dict vocabulary or tokenizer.

        Returns:
            Tuple ``(ids1, ids2)``, each wrapped with ``np.array``.
        """
        idx_pair1 = []
        idx_pair2 = []
        if isinstance(vocab, dict):
            for d1, d2 in zip(*data):
                idx_pair1.append([vocab.get(w, UNK_IDX) for w in d1])
                idx_pair2.append([vocab.get(w, UNK_IDX) for w in d2])
            return np.array(idx_pair1), np.array(idx_pair2)

        for d1, d2 in zip(*data):
            s2 = self._encode_text(d2, vocab)
            if not self._has_content(s2):
                continue
            idx_pair2.append(s2)
            idx_pair1.append(self._encode_text(d1, vocab))
        return np.array(idx_pair1), np.array(idx_pair2)

    def _index_parses(self, parses, vocab):
        """Index one side of the dependency-parsed data.

        Returns:
            Tuple ``(token_ids, depth_labels, distance_labels)`` of lists
            with one entry per parse.
        """
        depth_get = task.ParseDepthTask
        distance_get = task.ParseDistanceTask
        token_ids, depth_labels, distance_labels = [], [], []
        for parse in parses:
            depths, sentence = depth_get.labels(parse)
            ids, depths = self.load_depth_tag(sentence, depths, vocab)
            depth_labels.append(depths)
            distances = distance_get.labels(parse)
            distances = self.load_distance_tag(sentence, distances, vocab)
            distance_labels.append(distances)
            token_ids.append(ids)
        return token_ids, depth_labels, distance_labels

    def _data_to_idx_dp(self, data, vocab, dp1, dp2):
        """Index the dependency-parsed training data of both sides.

        Args:
            data: not used; the sentences are taken from the parses.
            vocab: tokenizer exposing ``tokenize`` / ``convert_tokens_to_ids``.
            dp1: parses of the first side.
            dp2: parses of the second side.

        Returns:
            ``(ids1, ids2, (depth1, distance1), (depth2, distance2))``, every
            collection wrapped with ``np.array``.
        """
        idx_pair1, dp1_depth_list, dp1_distance_list = \
            self._index_parses(dp1, vocab)
        idx_pair2, dp2_depth_list, dp2_distance_list = \
            self._index_parses(dp2, vocab)
        return np.array(idx_pair1), np.array(idx_pair2), \
            (np.array(dp1_depth_list), np.array(dp1_distance_list)), \
            (np.array(dp2_depth_list), np.array(dp2_distance_list))

    def load_depth_tag(self, words, tags, vocab):
        """Align per-word depth labels with the tokenizer's sub-tokens.

        The first sub-token of a word keeps the word's label, every further
        sub-token is labelled ``-1``. Words that tokenize to nothing are
        dropped.

        Returns:
            Tuple ``(x, y)``: sub-token ids and the aligned labels.
        """
        assert len(words) == len(tags)
        x, y = [], []
        for w, t in zip(words, tags):
            if w in ("[CLS]", "[SEP]"):
                tokens = [w]
            else:
                tokens = vocab.tokenize(w.lower())
            if tokens:
                x.extend(vocab.convert_tokens_to_ids(tokens))
                y.extend([t] + [-1] * (len(tokens) - 1))
            assert len(x) == len(y), \
                "len(x)={}, len(y)={}".format(len(x), len(y))

        return x, y

    def load_distance_tag(self, words, tags, vocab):
        """Align per-word distance labels with the tokenizer's sub-tokens.

        The first sub-token of a word keeps the word's label array, every
        further sub-token gets an array of the same shape filled with ``-1``.

        Returns:
            The list of aligned labels (the sub-token ids are not returned).
        """
        assert len(words) == len(tags)
        x, y = [], []
        for w, t in zip(words, tags):
            if w in ("[CLS]", "[SEP]"):
                tokens = [w]
            else:
                tokens = vocab.tokenize(w.lower())
            if tokens:
                xx = vocab.convert_tokens_to_ids(tokens)
                filler = np.ones(t.shape) * (-1)
                x.extend(xx)
                y.extend([t] + [filler] * (len(tokens) - 1))
            assert len(x) == len(y), \
                "len(x)={}, len(y)={}".format(len(x), len(y))

        return y

    def _load_paragram_embedding(self, path):
        """Load space-separated ``word v1 v2 ...`` vectors (latin-1 file).

        Returns:
            Tuple ``(word_vectors, vocab_embed, embed_dim)``: the word to
            vector dict, its key view and the vector length.
        """
        with open(path, encoding="latin-1") as fp:
            # word_vectors: word --> vector
            word_vectors = {}
            for line in fp:
                line = line.strip("\n").split(" ")
                word_vectors[line[0]] = np.array(
                    list(map(float, line[1:])), dtype='float32')
        vocab_embed = word_vectors.keys()
        embed_dim = word_vectors[next(iter(vocab_embed))].shape[0]
        return word_vectors, vocab_embed, embed_dim

    def _load_glove_embedding(self, path):
        """Load space-separated ``word v1 v2 ...`` vectors (UTF-8 file).

        Returns:
            Tuple ``(word_vectors, vocab_embed, embed_dim)``: the word to
            vector dict, its key view and the vector length.
        """
        with open(path, 'r', encoding='utf8') as fp:
            # word_vectors: word --> vector
            word_vectors = {}
            for line in fp:
                line = line.strip("\n").split(" ")
                word_vectors[line[0]] = np.array(
                    list(line[1:]), dtype='float32')
        vocab_embed = word_vectors.keys()
        embed_dim = word_vectors[next(iter(vocab_embed))].shape[0]

        return word_vectors, vocab_embed, embed_dim

    def _create_vocab_from_data(self, data):
        """Build a word to index dict from the most frequent training words.

        Args:
            data: two parallel sequences of word lists.

        Returns:
            Dict mapping words to indices starting at 3; indices 0, 1 and 2
            are assigned to ``UNK_WORD``, ``"<bos>"`` and ``"<eos>"``.
        """
        counts = Counter()
        for sent1, sent2 in zip(*data):
            counts.update(sent1 + sent2)

        ls = counts.most_common(self.expe.config.vocab_size)
        self.expe.log.info(
            '#Words: %d -> %d' % (len(counts), len(ls)))
        # Log the most and least frequent kept entries as a sanity check.
        for key in ls[:5]:
            self.expe.log.info(key)
        self.expe.log.info('...')
        for key in ls[-5:]:
            self.expe.log.info(key)

        # 0: unk, 1: bos, 2: eos
        vocab = {w: index
                 for index, w in enumerate((x[0] for x in ls), start=3)}
        vocab[UNK_WORD] = UNK_IDX
        vocab["<bos>"] = 1
        vocab["<eos>"] = 2

        return vocab

    @lazy_execute("_load_from_pickle")
    def _build_vocab(self, train_data):
        """Build a frequency-based vocabulary without an embedding matrix."""
        vocab = self._create_vocab_from_data(train_data)
        return None, vocab

    @lazy_execute("_load_from_pickle")
    def _build_pretrain_vocab(self, train_data, tokenizer_args: Dict = {}):
        """Build the vocabulary together with pre-trained embeddings.

        For the ``glove`` and ``paragram`` embedding types the vocabulary is
        built from the training data and an embedding matrix is initialised
        uniformly at random, with the rows of words found in the embedding
        file overwritten by their pre-trained vectors. For ``lm`` a
        ``BertTokenizer`` is returned as the vocabulary.

        Args:
            train_data: two parallel sequences of word lists.
            tokenizer_args: not used by this method (and never mutated).

        Returns:
            Tuple ``(W, vocab)``; ``W`` is the marker ``'lm'`` for the ``lm``
            type, otherwise a float32 array of shape
            ``(len(vocab), embed_dim)``.

        Raises:
            NotImplementedError: for any other embedding type.
        """
        config = self.expe.config
        self.expe.log.info("loading embedding from: {}"
                           .format(config.embed_file))
        embed_type = config.embed_type.lower()
        if embed_type == "glove":
            word_vectors, vocab_embed, embed_dim = \
                self._load_glove_embedding(config.embed_file)
        elif embed_type == "paragram":
            word_vectors, vocab_embed, embed_dim = \
                self._load_paragram_embedding(config.embed_file)
        elif embed_type == "lm":
            vocab = BertTokenizer.from_pretrained(config.ml_token)
            W = 'lm'
            return W, vocab
        else:
            raise NotImplementedError(
                "invalid embedding type: {}".format(config.embed_type))

        vocab = self._create_vocab_from_data(train_data)

        bound = np.sqrt(3.0 / embed_dim)
        W = np.random.uniform(
            -bound, bound,
            size=(len(vocab), embed_dim)).astype('float32')
        n = 0
        for w, i in vocab.items():
            if w in vocab_embed:
                W[i, :] = word_vectors[w]
                n += 1
        self.expe.log.info(
            "{}/{} vocabs are initialized with {} embeddings."
            .format(n, len(vocab), config.embed_type))

        return W, vocab

    def _load_from_pickle(self, file_name):
        """Unpickle and return the object stored in ``file_name``."""
        self.expe.log.info("loading from {}".format(file_name))
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load files produced by this project.
        with open(file_name, "rb") as fp:
            data = pickle.load(fp)
        return data


class batch_accumulator:
    """Fixed-size FIFO pool holding the most recent batches of both sides.

    The pool keeps ``mega_batch`` batches per side; ``get_batch`` flattens
    them into padded arrays.
    """

    def __init__(self, mega_batch, p_scramble, init_batch1, init_batch2):
        assert len(init_batch1) == len(init_batch2) == mega_batch
        self.p_scramble = p_scramble
        self.mega_batch = mega_batch
        self.pool = [init_batch1, init_batch2]

    def update(self, new_batch1, new_batch2):
        """Drop the oldest batch of each side and append the new ones."""
        side1, side2 = self.pool[0], self.pool[1]
        side1.pop(0)
        side2.pop(0)

        side1.append(new_batch1)
        side2.append(new_batch2)

        assert len(side1) == len(side2) == self.mega_batch

    @staticmethod
    def _write_row(row, sent, inputs, input_mask, targets, target_mask):
        """Copy one sentence into row ``row`` of the padded arrays.

        The target row holds the sentence wrapped in the ids 1 and 2.
        """
        length = len(sent)
        tokens = list(sent)
        inputs[row, :length] = np.asarray(tokens).astype("float32")
        input_mask[row, :length] = 1.

        targets[row, :length + 2] = \
            np.asarray([1] + tokens + [2]).astype("float32")
        target_mask[row, :length + 2] = 1.

    def get_batch(self):
        """Return all pooled sentences as zero-padded float32 arrays.

        With probability ``p_scramble`` the tokens of a sentence pair are
        randomly permuted before being written.

        Returns:
            ``(input1, mask1, target1, target_mask1,
            input2, mask2, target2, target_mask2)``; target arrays are two
            columns wider than the input arrays.
        """
        data1 = np.concatenate(self.pool[0])
        data2 = np.concatenate(self.pool[1])

        width1 = max([len(sent) for sent in data1])
        width2 = max([len(sent) for sent in data2])
        rows1, rows2 = len(data1), len(data2)

        input_data1 = np.zeros((rows1, width1), dtype="float32")
        input_mask1 = np.zeros((rows1, width1), dtype="float32")
        tgt_data1 = np.zeros((rows1, width1 + 2), dtype="float32")
        tgt_mask1 = np.zeros((rows1, width1 + 2), dtype="float32")

        input_data2 = np.zeros((rows2, width2), dtype="float32")
        input_mask2 = np.zeros((rows2, width2), dtype="float32")
        tgt_data2 = np.zeros((rows2, width2 + 2), dtype="float32")
        tgt_mask2 = np.zeros((rows2, width2 + 2), dtype="float32")

        for row, (sent1, sent2) in enumerate(zip(data1, data2)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                sent1 = np.random.permutation(sent1)
                sent2 = np.random.permutation(sent2)

            self._write_row(
                row, sent1, input_data1, input_mask1, tgt_data1, tgt_mask1)
            self._write_row(
                row, sent2, input_data2, input_mask2, tgt_data2, tgt_mask2)

        return input_data1, input_mask1, tgt_data1, tgt_mask1, \
            input_data2, input_mask2, tgt_data2, tgt_mask2


class bow_accumulator(batch_accumulator):
    """FIFO pool of recent batches that also carries precomputed targets.

    The pool holds four parallel lists: side-1 batches, side-1 targets,
    side-2 batches and side-2 targets.
    """

    def __init__(self, mega_batch, p_scramble, init_batch1, init_tgt1,
                 init_batch2, init_tgt2, vocab_size):
        assert len(init_batch1) == len(init_batch2) == mega_batch
        self.p_scramble = p_scramble
        self.mega_batch = mega_batch
        self.vocab_size = vocab_size
        self.pool = [init_batch1, init_tgt1, init_batch2, init_tgt2]

    def update(self, new_batch1, new_tgt1, new_batch2, new_tgt2):
        """Drop the oldest entry of every pool list and append the new ones."""
        batches1, targets1, batches2, targets2 = (
            self.pool[0], self.pool[1], self.pool[2], self.pool[3])
        batches1.pop(0)
        targets1.pop(0)
        batches2.pop(0)
        targets2.pop(0)

        batches1.append(new_batch1)
        targets1.append(new_tgt1)
        batches2.append(new_batch2)
        targets2.append(new_tgt2)

        assert len(batches1) == len(targets1) == self.mega_batch

    def get_batch(self):
        """Return the pooled sentences padded, plus the pooled targets.

        With probability ``p_scramble`` the tokens of a sentence pair are
        randomly permuted before being written; the targets are concatenated
        unchanged.

        Returns:
            ``(input1, mask1, target1, input2, mask2, target2)``.
        """
        data1 = np.concatenate(self.pool[0])
        data2 = np.concatenate(self.pool[2])

        tgt_data1 = np.concatenate(self.pool[1])
        tgt_data2 = np.concatenate(self.pool[3])

        width1 = max([len(sent) for sent in data1])
        width2 = max([len(sent) for sent in data2])

        input_data1 = np.zeros((len(data1), width1), dtype="float32")
        input_mask1 = np.zeros((len(data1), width1), dtype="float32")

        input_data2 = np.zeros((len(data2), width2), dtype="float32")
        input_mask2 = np.zeros((len(data2), width2), dtype="float32")

        for row, (sent1, sent2) in enumerate(zip(data1, data2)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                sent1 = np.random.permutation(sent1)
                sent2 = np.random.permutation(sent2)

            length1, length2 = len(sent1), len(sent2)
            input_data1[row, :length1] = \
                np.asarray(list(sent1)).astype("float32")
            input_mask1[row, :length1] = 1.

            input_data2[row, :length2] = \
                np.asarray(list(sent2)).astype("float32")
            input_mask2[row, :length2] = 1.

        return input_data1, input_mask1, tgt_data1, \
            input_data2, input_mask2, tgt_data2


class minibatcher:
    """Iterator yielding padded mini-batches of sentence pairs.

    When ``mega_batch`` is greater than 1, each batch additionally carries
    the highest-scoring candidate from a pool of recent batches for every
    sentence.

    NOTE(review): ``auto_init_args`` presumably stores the constructor
    arguments as same-named attributes - confirm in ``SSDM.decorators``.
    """

    @auto_init_args
    def __init__(self, data1, data2, batch_size, score_func,
                 shuffle, mega_batch, p_scramble, *args, **kwargs):
        self._reset()

    def __len__(self):
        """Number of mini-batches per pass."""
        return len(self.idx_pool)

    def _reset(self):
        """Rewind the iterator, re-partition the data and reseed the pool."""
        self.pointer = 0
        order = np.arange(len(self.data1))
        if self.shuffle:
            np.random.shuffle(order)
        starts = range(0, len(self.data1), self.batch_size)
        self.idx_pool = [order[start: start + self.batch_size]
                         for start in starts]

        if not self.mega_batch > 1:
            return

        # Seed the candidate pool with the last ``mega_batch`` batches.
        seed1, seed2 = [], []
        for batch_ids in self.idx_pool[-self.mega_batch:]:
            part1, part2 = self.data1[batch_ids], self.data2[batch_ids]
            seed1.append(part1)
            seed2.append(part2)
        self.mega_batcher = batch_accumulator(
            self.mega_batch, self.p_scramble, seed1, seed2)

    def _select_neg_sample(self, data, data_mask,
                           cand, cand_mask, ctgt, ctgt_mask, no_diag):
        """Pick, for each row of ``data``, the best-scoring candidate.

        Args:
            data, data_mask: padded query sentences and their mask.
            cand, cand_mask: padded candidate sentences and their mask.
            ctgt, ctgt_mask: candidate targets and their mask.
            no_diag: if true, the diagonal of the score matrix is set to
                ``-inf`` so row ``i`` cannot select candidate ``i``.

        Returns:
            ``(score_matrix, neg_data, neg_mask, tgt_data, tgt_mask)`` with
            the selected arrays trimmed to their longest selected row.
        """
        score_matrix = self.score_func(
            data, data_mask, cand, cand_mask)

        if no_diag:
            diagonal = np.arange(len(score_matrix))
            score_matrix[diagonal, diagonal] = -np.inf
        best = np.argmax(score_matrix, 1)

        neg_data = cand[best]
        neg_mask = cand_mask[best]

        tgt_data = ctgt[best]
        tgt_mask = ctgt_mask[best]

        # Trim trailing padding no selected row uses.
        neg_width = int(neg_mask.sum(-1).max())
        neg_data = neg_data[:, :neg_width]
        neg_mask = neg_mask[:, :neg_width]

        tgt_width = int(tgt_mask.sum(-1).max())
        tgt_data = tgt_data[:, :tgt_width]
        tgt_mask = tgt_mask[:, :tgt_width]

        assert neg_mask.sum(-1).max() == neg_width
        return score_matrix, neg_data, neg_mask, tgt_data, tgt_mask

    @staticmethod
    def _write_row(row, sent, inputs, input_mask, targets, target_mask):
        """Copy one sentence into row ``row`` of the padded arrays.

        The target row holds the sentence wrapped in the ids 1 and 2.
        """
        length = len(sent)
        tokens = list(sent)
        inputs[row, :length] = np.asarray(tokens).astype("float32")
        input_mask[row, :length] = 1.

        targets[row, :length + 2] = \
            np.asarray([1] + tokens + [2]).astype("float32")
        target_mask[row, :length + 2] = 1.

    def _pad(self, data1, data2):
        """Pad one batch of sentence pairs to a common width.

        Both sides share the width of the longest sentence in either side.
        With probability ``p_scramble`` the tokens of a pair are randomly
        permuted. When ``mega_batch`` is greater than 1, negative samples are
        drawn from the candidate pool, which is then updated with this batch.

        Returns:
            A 16-element list: inputs and masks of both sides, targets and
            target masks of both sides, then the negative sample arrays for
            both sides (``None`` placeholders when ``mega_batch`` is not
            greater than 1).
        """
        assert len(data1) == len(data2)
        longest1 = max([len(sent) for sent in data1])
        longest2 = max([len(sent) for sent in data2])
        width = max([longest1, longest2])
        rows1, rows2 = len(data1), len(data2)

        input_data1 = np.zeros((rows1, width), dtype="float32")
        input_mask1 = np.zeros((rows1, width), dtype="float32")
        tgt_data1 = np.zeros((rows1, width + 2), dtype="float32")
        tgt_mask1 = np.zeros((rows1, width + 2), dtype="float32")

        input_data2 = np.zeros((rows2, width), dtype="float32")
        input_mask2 = np.zeros((rows2, width), dtype="float32")
        tgt_data2 = np.zeros((rows2, width + 2), dtype="float32")
        tgt_mask2 = np.zeros((rows2, width + 2), dtype="float32")

        for row, (sent1, sent2) in enumerate(zip(data1, data2)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                sent1 = np.random.permutation(sent1)
                sent2 = np.random.permutation(sent2)

            self._write_row(
                row, sent1, input_data1, input_mask1, tgt_data1, tgt_mask1)
            self._write_row(
                row, sent2, input_data2, input_mask2, tgt_data2, tgt_mask2)

        batch = [input_data1, input_mask1, input_data2, input_mask2,
                 tgt_data1, tgt_mask1, tgt_data2, tgt_mask2]

        if not self.mega_batch > 1:
            return batch + [None] * 8

        cand1, cand_mask1, ctgt1, ctgt_mask1, \
            cand2, cand_mask2, ctgt2, ctgt_mask2 = \
            self.mega_batcher.get_batch()
        # Side-1 sentences are scored against side-2 candidates and
        # vice versa.
        _, neg_data1, neg_mask1, ntgt1, ntgt_mask1 = \
            self._select_neg_sample(
                input_data1, input_mask1, cand2,
                cand_mask2, ctgt2, ctgt_mask2, False)
        _, neg_data2, neg_mask2, ntgt2, ntgt_mask2 = \
            self._select_neg_sample(
                input_data2, input_mask2, cand1,
                cand_mask1, ctgt1, ctgt_mask1, False)
        self.mega_batcher.update(data1, data2)

        return batch + [neg_data1, neg_mask1, ntgt1, ntgt_mask1,
                        neg_data2, neg_mask2, ntgt2, ntgt_mask2]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next padded batch with its example indices appended."""
        if self.pointer == len(self.idx_pool):
            # End of pass: prepare the next pass before signalling the end.
            self._reset()
            raise StopIteration()

        batch_ids = self.idx_pool[self.pointer]
        batch1, batch2 = self.data1[batch_ids], self.data2[batch_ids]
        self.pointer += 1
        return self._pad(batch1, batch2) + [batch_ids]


class bow_minibatcher:
    """Iterator yielding padded mini-batches with bag-of-words targets.

    The target of a sentence is its token-count vector of length
    ``vocab_size``. When ``mega_batch`` is greater than 1, each batch
    additionally carries the highest-scoring candidate from a pool of recent
    batches for every sentence.

    NOTE(review): ``auto_init_args`` presumably stores the constructor
    arguments as same-named attributes - confirm in ``SSDM.decorators``.
    """

    @auto_init_args
    def __init__(self, data1, data2, vocab_size, batch_size,
                 score_func, shuffle, mega_batch, p_scramble,
                 *args, **kwargs):
        self._reset()

    def __len__(self):
        """Number of mini-batches per pass."""
        return len(self.idx_pool)

    def _reset(self):
        """Rewind the iterator, re-partition the data and reseed the pool."""
        self.pointer = 0
        order = np.arange(len(self.data1))
        if self.shuffle:
            np.random.shuffle(order)
        starts = range(0, len(self.data1), self.batch_size)
        self.idx_pool = [order[start: start + self.batch_size]
                         for start in starts]

        if not self.mega_batch > 1:
            return

        # Seed the candidate pool with the last ``mega_batch`` batches and
        # their bag-of-words targets.
        seed1, seed2, seed_tgt1, seed_tgt2 = [], [], [], []
        for batch_ids in self.idx_pool[-self.mega_batch:]:
            part1, part2 = self.data1[batch_ids], self.data2[batch_ids]
            seed1.append(part1)
            seed2.append(part2)
            counts1 = np.zeros(
                (len(part1), self.vocab_size), dtype="float32")
            counts2 = np.zeros(
                (len(part2), self.vocab_size), dtype="float32")
            for row, (sent1, sent2) in enumerate(zip(part1, part2)):
                counts1[row, :] = np.bincount(
                    sent1, minlength=self.vocab_size)
                counts2[row, :] = np.bincount(
                    sent2, minlength=self.vocab_size)
            seed_tgt1.append(counts1)
            seed_tgt2.append(counts2)
        self.mega_batcher = bow_accumulator(
            self.mega_batch, self.p_scramble, seed1, seed_tgt1,
            seed2, seed_tgt2, self.vocab_size)

    def _select_neg_sample(self, data, data_mask,
                           cand, cand_mask, ctgt, no_diag):
        """Pick, for each row of ``data``, the best-scoring candidate.

        Args:
            data, data_mask: padded query sentences and their mask.
            cand, cand_mask: padded candidate sentences and their mask.
            ctgt: candidate targets.
            no_diag: if true, the diagonal of the score matrix is set to
                ``-inf`` so row ``i`` cannot select candidate ``i``.

        Returns:
            ``(score_matrix, neg_data, neg_mask, tgt_data)`` with the
            selected sentences trimmed to their longest selected row.
        """
        score_matrix = self.score_func(
            data, data_mask, cand, cand_mask)

        if no_diag:
            diagonal = np.arange(len(score_matrix))
            score_matrix[diagonal, diagonal] = -np.inf
        best = np.argmax(score_matrix, 1)

        neg_data = cand[best]
        neg_mask = cand_mask[best]

        tgt_data = ctgt[best]

        # Trim trailing padding no selected row uses.
        neg_width = int(neg_mask.sum(-1).max())
        neg_data = neg_data[:, :neg_width]
        neg_mask = neg_mask[:, :neg_width]

        assert neg_mask.sum(-1).max() == neg_width
        return score_matrix, neg_data, neg_mask, tgt_data

    def _pad(self, data1, data2):
        """Pad one batch of sentence pairs and build its count targets.

        Each side is padded to its own longest sentence. With probability
        ``p_scramble`` the tokens of a pair are randomly permuted. When
        ``mega_batch`` is greater than 1, negative samples are drawn from the
        candidate pool, which is then updated with this batch.

        Returns:
            A 16-element list: inputs and masks of both sides, the targets
            of each side (each listed twice), then the negative sample
            arrays for both sides (``None`` placeholders when ``mega_batch``
            is not greater than 1).
        """
        assert len(data1) == len(data2)
        width1 = max([len(sent) for sent in data1])
        width2 = max([len(sent) for sent in data2])
        rows1, rows2 = len(data1), len(data2)

        input_data1 = np.zeros((rows1, width1), dtype="float32")
        input_mask1 = np.zeros((rows1, width1), dtype="float32")
        tgt_data1 = np.zeros((rows1, self.vocab_size), dtype="float32")

        input_data2 = np.zeros((rows2, width2), dtype="float32")
        input_mask2 = np.zeros((rows2, width2), dtype="float32")
        tgt_data2 = np.zeros((rows2, self.vocab_size), dtype="float32")

        for row, (sent1, sent2) in enumerate(zip(data1, data2)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                sent1 = np.random.permutation(sent1)
                sent2 = np.random.permutation(sent2)

            length1 = len(sent1)
            input_data1[row, :length1] = \
                np.asarray(list(sent1)).astype("float32")
            input_mask1[row, :length1] = 1.

            tgt_data1[row, :] = np.bincount(
                sent1, minlength=self.vocab_size)

            length2 = len(sent2)
            input_data2[row, :length2] = \
                np.asarray(list(sent2)).astype("float32")
            input_mask2[row, :length2] = 1.

            tgt_data2[row, :] = np.bincount(
                sent2, minlength=self.vocab_size)

        batch = [input_data1, input_mask1, input_data2, input_mask2,
                 tgt_data1, tgt_data1, tgt_data2, tgt_data2]

        if not self.mega_batch > 1:
            return batch + [None] * 8

        cand1, cand_mask1, ctgt1, \
            cand2, cand_mask2, ctgt2 = \
            self.mega_batcher.get_batch()
        # Side-1 sentences are scored against side-2 candidates and
        # vice versa.
        _, neg_data1, neg_mask1, ntgt1 = \
            self._select_neg_sample(
                input_data1, input_mask1, cand2,
                cand_mask2, ctgt2, False)
        _, neg_data2, neg_mask2, ntgt2 = \
            self._select_neg_sample(
                input_data2, input_mask2, cand1,
                cand_mask1, ctgt1, False)
        self.mega_batcher.update(data1, tgt_data1, data2, tgt_data2)

        return batch + [neg_data1, neg_mask1, ntgt1, ntgt1,
                        neg_data2, neg_mask2, ntgt2, ntgt2]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next padded batch with its example indices appended."""
        if self.pointer == len(self.idx_pool):
            # End of pass: prepare the next pass before signalling the end.
            self._reset()
            raise StopIteration()

        batch_ids = self.idx_pool[self.pointer]
        batch1, batch2 = self.data1[batch_ids], self.data2[batch_ids]
        self.pointer += 1
        return self._pad(batch1, batch2) + [batch_ids]


class tree_bow_minibatcher:
    """Mini-batch iterator over sentence pairs with dependency annotations.

    Iterating yields, for every batch, the list built by ``_pad`` with the
    batch's index array appended. At the end of a pass the index pool is
    rebuilt (and reshuffled when ``shuffle`` is set) before ``StopIteration``
    is raised, so the same object can be iterated again.

    NOTE(review): relies on ``auto_init_args`` (imported from
    ``SSDM.decorators``) storing every constructor argument on ``self`` under
    its own name (``self.data1``, ``self.shuffle``, ...) -- confirm.
    """

    # Fill value for positions of the dependency arrays that lie beyond a
    # sentence's own annotation.
    _DP_PAD = -1

    @auto_init_args
    def __init__(self, data1, data2, dp1, dp2, vocab_size, batch_size,
                 score_func, shuffle, mega_batch, p_scramble,
                 *args, **kwargs):
        """Store the dependency annotations and build the first index pool.

        Args:
            data1, data2: per-sentence token-id sequences of the two sides.
                They are indexed with an index array, so presumably NumPy
                (object) arrays -- TODO confirm against the caller.
            dp1, dp2: ``(dep, dis)`` pairs for the respective side. ``_pad``
                consumes ``dep[k]`` as a 1-D sequence and ``dis[k]`` as a 2-D
                matrix for sentence ``k``.
            vocab_size: length of the bag-of-words count vectors.
            batch_size: number of sentences per batch.
            score_func: callable ``(data, data_mask, cand, cand_mask)``
                returning the score matrix used to pick negatives.
            shuffle: reshuffle the example order on every reset.
            mega_batch: negatives are mined only when this is ``> 1``; it is
                also the number of trailing index batches used to seed the
                accumulator.
            p_scramble: probability of randomly permuting both sentences of
                a pair.
        """
        self.dp1_dep = dp1[0]
        self.dp1_dis = dp1[1]
        self.dp2_dep = dp2[0]
        self.dp2_dis = dp2[1]
        self._reset()
        self.mega_batch = mega_batch

    def __len__(self):
        """Return the number of batches in one pass."""
        return len(self.idx_pool)

    def _reset(self):
        """Rewind the iterator and rebuild the index pool.

        When ``mega_batch > 1`` a fresh ``bow_accumulator`` is seeded with the
        last ``mega_batch`` index batches and their bag-of-words counts.
        """
        self.pointer = 0
        idx_list = np.arange(len(self.dp1_dep))
        if self.shuffle:
            np.random.shuffle(idx_list)
        self.idx_pool = [idx_list[i: i + self.batch_size]
                         for i in range(0, len(self.dp1_dep), self.batch_size)]

        if self.mega_batch > 1:
            init_mega_ids = self.idx_pool[-self.mega_batch:]
            init_mega1, init_mega2, init_tgt1, init_tgt2 = [], [], [], []
            for idx in init_mega_ids:
                d1, d2 = self.data1[idx], self.data2[idx]
                init_mega1.append(d1)
                init_mega2.append(d2)
                t1 = np.zeros((len(d1), self.vocab_size)).astype("float32")
                t2 = np.zeros((len(d2), self.vocab_size)).astype("float32")
                for i, (s1, s2) in enumerate(zip(d1, d2)):
                    t1[i, :] = np.bincount(s1, minlength=self.vocab_size)
                    t2[i, :] = np.bincount(s2, minlength=self.vocab_size)
                init_tgt1.append(t1)
                init_tgt2.append(t2)
            self.mega_batcher = bow_accumulator(
                self.mega_batch, self.p_scramble, init_mega1, init_tgt1,
                init_mega2, init_tgt2, self.vocab_size)

    def _select_neg_sample(self, data, data_mask,
                           cand, cand_mask, ctgt, no_diag):
        """Pick, for every row of ``data``, the best-scoring candidate.

        Args:
            data, data_mask: padded batch and its mask.
            cand, cand_mask, ctgt: candidate rows, their masks and their
                bag-of-words targets.
            no_diag: when true, the diagonal of the score matrix is set to
                ``-inf`` so that row ``i`` cannot select candidate ``i``.

        Returns:
            ``(score_matrix, neg_data, neg_mask, tgt_data)`` where
            ``neg_data`` / ``neg_mask`` have exactly ``data.shape[1]``
            columns: wider candidates are cropped, narrower ones are padded
            with zeros on the right.
        """
        score_matrix = self.score_func(
            data, data_mask, cand, cand_mask)

        if no_diag:
            diag_idx = np.arange(len(score_matrix))
            score_matrix[diag_idx, diag_idx] = -np.inf
        neg_idx = np.argmax(score_matrix, 1)

        neg_data = cand[neg_idx]
        neg_mask = cand_mask[neg_idx]
        tgt_data = ctgt[neg_idx]

        max_len = data.shape[1]
        if neg_data.shape[1] >= max_len:
            # Crop to the batch width; adding to a float32 zero array keeps
            # the dtype promotion of the original implementation.
            base = np.zeros((neg_data.shape[0], max_len)).astype("float32")
            neg_data = base + neg_data[:, :max_len]
            neg_mask = base + neg_mask[:, :max_len]
        else:
            extra = ((0, 0), (0, max_len - neg_data.shape[1]))
            neg_data = np.pad(neg_data, extra, 'constant', constant_values=0)
            neg_mask = np.pad(neg_mask, extra, 'constant', constant_values=0)

        return score_matrix, neg_data, neg_mask, tgt_data

    @classmethod
    def _pad_square(cls, matrix, size):
        """Return ``matrix`` as an int64 array grown to ``(size, size)``.

        New cells are filled with ``_DP_PAD``. A matrix larger than ``size``
        yields a negative pad width, which ``np.pad`` rejects.
        """
        matrix = np.array(matrix)
        padded = np.pad(
            matrix,
            pad_width=((0, size - matrix.shape[0]),
                       (0, size - matrix.shape[1])),
            mode="constant", constant_values=cls._DP_PAD)
        return padded.astype("int64")

    def _pad(self, data1, data2, tags_1_dep, tags_1_dis, tags_2_dep, tags_2_dis):
        """Assemble one batch of padded arrays.

        Args:
            data1, data2: token-id sequences of the two sides.
            tags_1_dep, tags_2_dep: one 1-D sequence per sentence.
            tags_1_dis, tags_2_dis: one 2-D matrix per sentence.

        Returns:
            A 17-element list: ``input_data1, input_mask1, input_data2,
            input_mask2``, the bag-of-words targets (each side twice), eight
            negative-sample slots (``None`` unless ``mega_batch > 1``) and
            finally the tuple ``(input_d11, input_d21, input_d12,
            input_d22)`` of dependency arrays.

        Both sides share one ``max_len``. Token arrays are zero-padded
        float32; all four dependency arrays are int64 padded with
        ``_DP_PAD``.
        """
        assert len(data1) == len(data2) == len(tags_1_dep) == len(tags_2_dep)
        max_len1 = max(len(sent) for sent in data1)
        max_len2 = max(len(sent) for sent in data2)
        max_len = max(max_len1, max_len2)

        input_data1 = \
            np.zeros((len(data1), max_len)).astype("float32")
        input_mask1 = \
            np.zeros((len(data1), max_len)).astype("float32")
        tgt_data1 = \
            np.zeros((len(data1), self.vocab_size)).astype("float32")

        input_data2 = \
            np.zeros((len(data2), max_len)).astype("float32")
        input_mask2 = \
            np.zeros((len(data2), max_len)).astype("float32")
        tgt_data2 = \
            np.zeros((len(data2), self.vocab_size)).astype("float32")

        # Every dependency buffer starts at the padding value. The previous
        # code built ``input_d21`` as ``np.zeros(...) * (-1)``, i.e. all
        # zeros, so side-2 rows were padded with 0 while side-1 rows were
        # padded with -1.
        input_d11 = np.full(
            (len(tags_1_dep), max_len), self._DP_PAD, dtype="int64")
        input_d21 = np.full(
            (len(tags_2_dep), max_len), self._DP_PAD, dtype="int64")
        input_d12 = np.full(
            (len(tags_1_dis), max_len, max_len), self._DP_PAD, dtype="int64")
        input_d22 = np.full(
            (len(tags_2_dis), max_len, max_len), self._DP_PAD, dtype="int64")

        for i, (sent1, sent2, dp11, dp12, dp21, dp22) in enumerate(
                zip(data1, data2, tags_1_dep, tags_1_dis, tags_2_dep, tags_2_dis)):
            if np.random.choice(
                    [True, False],
                    p=[self.p_scramble, 1 - self.p_scramble]).item():
                # Only the token order is scrambled; the dependency
                # annotations are left in their original order.
                sent1 = np.random.permutation(sent1)
                sent2 = np.random.permutation(sent2)

            input_data1[i, :len(sent1)] = \
                np.asarray(list(sent1)).astype("float32")
            input_mask1[i, :len(sent1)] = 1.
            tgt_data1[i, :] = np.bincount(sent1, minlength=self.vocab_size)

            input_data2[i, :len(sent2)] = \
                np.asarray(list(sent2)).astype("float32")
            input_mask2[i, :len(sent2)] = 1.
            tgt_data2[i, :] = np.bincount(sent2, minlength=self.vocab_size)

            input_d11[i, :len(dp11)] = \
                np.asarray(list(dp11)).astype("int64")
            input_d21[i, :len(dp21)] = \
                np.asarray(list(dp21)).astype("int64")
            input_d12[i] = self._pad_square(dp12, max_len)
            input_d22[i] = self._pad_square(dp22, max_len)

        input_data3 = (input_d11, input_d21, input_d12, input_d22)

        if self.mega_batch > 1:
            cand1, cand_mask1, ctgt1, \
                cand2, cand_mask2, ctgt2 = \
                self.mega_batcher.get_batch()

            # Negatives for side 1 are drawn from side-2 candidates and
            # vice versa.
            _, neg_data1, neg_mask1, ntgt1 = \
                self._select_neg_sample(
                    input_data1, input_mask1, cand2,
                    cand_mask2, ctgt2, False)
            _, neg_data2, neg_mask2, ntgt2 = \
                self._select_neg_sample(
                    input_data2, input_mask2, cand1,
                    cand_mask1, ctgt1, False)
            self.mega_batcher.update(data1, tgt_data1, data2, tgt_data2)

            return [input_data1, input_mask1, input_data2, input_mask2,
                    tgt_data1, tgt_data1, tgt_data2, tgt_data2,
                    neg_data1, neg_mask1, ntgt1, ntgt1,
                    neg_data2, neg_mask2, ntgt2, ntgt2, input_data3]

        return [input_data1, input_mask1, input_data2, input_mask2,
                tgt_data1, tgt_data1, tgt_data2, tgt_data2,
                None, None, None, None,
                None, None, None, None, input_data3]

    def __iter__(self):
        """Return the batcher itself; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, or reset and raise ``StopIteration``."""
        if self.pointer == len(self.idx_pool):
            self._reset()
            raise StopIteration()

        idx = self.idx_pool[self.pointer]
        tags_1_dep, tags_1_dis = self.dp1_dep[idx], self.dp1_dis[idx]
        tags_2_dep, tags_2_dis = self.dp2_dep[idx], self.dp2_dis[idx]
        data1, data2 = self.data1[idx], self.data2[idx]
        self.pointer += 1
        return self._pad(data1, data2, tags_1_dep, tags_1_dis, tags_2_dep, tags_2_dis) + [idx]


class pos_bow_minibatcher:
    """Mini-batch iterator over sentence pairs with one tag sequence each.

    Every batch is the list produced by ``_pad`` followed by the batch's
    index array. The index pool is rebuilt at the end of each pass.

    NOTE(review): relies on ``auto_init_args`` (imported from
    ``SSDM.decorators``) storing every constructor argument on ``self`` under
    its own name -- confirm.
    """

    @auto_init_args
    def __init__(self, data1, data2, pos, vocab_size, batch_size,
                 score_func, shuffle, mega_batch, p_scramble,
                 *args, **kwargs):
        """Convert ``pos`` to an array and build the first index pool."""
        self.y = np.array(pos)
        self._reset()
        self.mega_batch = mega_batch

    def __len__(self):
        """Return the number of batches in one pass."""
        return len(self.idx_pool)

    def _bow_counts(self, sents):
        """Return a float32 ``(len(sents), vocab_size)`` matrix of token counts."""
        counts = np.zeros((len(sents), self.vocab_size)).astype("float32")
        for row, sent in enumerate(sents):
            counts[row, :] = np.bincount(sent, minlength=self.vocab_size)
        return counts

    def _reset(self):
        """Rewind the iterator, rebuild the index pool and reseed negatives.

        The accumulator is (re)created only when ``mega_batch > 1``; it is
        seeded with the last ``mega_batch`` index batches of the new pool.
        """
        self.pointer = 0
        n_examples = len(self.y)
        order = np.arange(n_examples)
        if self.shuffle:
            np.random.shuffle(order)
        step = self.batch_size
        self.idx_pool = [order[start: start + step]
                         for start in range(0, n_examples, step)]

        if not self.mega_batch > 1:
            return

        seed_sents1, seed_sents2 = [], []
        seed_tgts1, seed_tgts2 = [], []
        for batch_idx in self.idx_pool[-self.mega_batch:]:
            side1, side2 = self.data1[batch_idx], self.data2[batch_idx]
            seed_sents1.append(side1)
            seed_sents2.append(side2)
            seed_tgts1.append(self._bow_counts(side1))
            seed_tgts2.append(self._bow_counts(side2))
        self.mega_batcher = bow_accumulator(
            self.mega_batch, self.p_scramble, seed_sents1, seed_tgts1,
            seed_sents2, seed_tgts2, self.vocab_size)

    def _select_neg_sample(self, data, data_mask,
                           cand, cand_mask, ctgt, no_diag):
        """Pick, for every row of ``data``, the best-scoring candidate.

        Returns ``(score_matrix, neg_data, neg_mask, tgt_data)``. The
        selected rows are cropped to the longest mask sum among them. With
        ``no_diag`` set, the diagonal of the score matrix is forced to
        ``-inf`` before the arg-max.
        """
        scores = self.score_func(data, data_mask, cand, cand_mask)

        if no_diag:
            diagonal = np.arange(len(scores))
            scores[diagonal, diagonal] = -np.inf
        winners = np.argmax(scores, 1)

        picked_data, picked_mask, picked_tgt = \
            cand[winners], cand_mask[winners], ctgt[winners]

        longest = int(picked_mask.sum(-1).max())
        picked_data = picked_data[:, :longest]
        picked_mask = picked_mask[:, :longest]

        assert picked_mask.sum(-1).max() == longest
        return scores, picked_data, picked_mask, picked_tgt

    def _pad(self, data1, data2, pos_):
        """Assemble one batch of padded arrays.

        Returns a 17-element list: padded token ids and masks for both
        sides, the bag-of-words targets (each side twice), eight
        negative-sample slots (``None`` unless ``mega_batch > 1``) and, last,
        the zero-padded int64 tag matrix. All padded matrices share the
        width of the longest sequence among ``data1``, ``data2`` and
        ``pos_``.
        """
        assert len(data1) == len(data2) == len(pos_)
        longest1 = max(len(sent) for sent in data1)
        longest2 = max(len(sent) for sent in data2)
        longest3 = max(len(tags) for tags in pos_)
        max_len = max(longest1, longest2, longest3)

        rows1, rows2 = len(data1), len(data2)
        input_data1 = np.zeros((rows1, max_len), dtype="float32")
        input_mask1 = np.zeros((rows1, max_len), dtype="float32")
        tgt_data1 = np.zeros((rows1, self.vocab_size), dtype="float32")

        input_data2 = np.zeros((rows2, max_len), dtype="float32")
        input_mask2 = np.zeros((rows2, max_len), dtype="float32")
        tgt_data2 = np.zeros((rows2, self.vocab_size), dtype="float32")

        input_data3 = np.zeros((len(pos_), max_len), dtype="int64")

        for row, (left, right, tags) in enumerate(zip(data1, data2, pos_)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                left = np.random.permutation(left)
                right = np.random.permutation(right)

            n_left, n_right = len(left), len(right)

            input_data1[row, :n_left] = np.asarray(list(left)).astype("float32")
            input_mask1[row, :n_left] = 1.
            tgt_data1[row, :] = np.bincount(left, minlength=self.vocab_size)

            input_data2[row, :n_right] = np.asarray(list(right)).astype("float32")
            input_mask2[row, :n_right] = 1.
            tgt_data2[row, :] = np.bincount(right, minlength=self.vocab_size)

            input_data3[row, :len(tags)] = np.asarray(list(tags)).astype("int64")

        negatives = [None] * 8
        if self.mega_batch > 1:
            cand1, cand_mask1, ctgt1, cand2, cand_mask2, ctgt2 = \
                self.mega_batcher.get_batch()

            # Side-1 negatives come from side-2 candidates and vice versa.
            _, neg_data1, neg_mask1, ntgt1 = self._select_neg_sample(
                input_data1, input_mask1, cand2, cand_mask2, ctgt2, False)
            _, neg_data2, neg_mask2, ntgt2 = self._select_neg_sample(
                input_data2, input_mask2, cand1, cand_mask1, ctgt1, False)
            self.mega_batcher.update(data1, tgt_data1, data2, tgt_data2)

            negatives = [neg_data1, neg_mask1, ntgt1, ntgt1,
                         neg_data2, neg_mask2, ntgt2, ntgt2]

        inputs = [input_data1, input_mask1, input_data2, input_mask2,
                  tgt_data1, tgt_data1, tgt_data2, tgt_data2]
        return inputs + negatives + [input_data3]

    def __iter__(self):
        """Return the batcher itself; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, or reset and raise ``StopIteration``."""
        exhausted = self.pointer == len(self.idx_pool)
        if exhausted:
            self._reset()
            raise StopIteration()

        batch_idx = self.idx_pool[self.pointer]
        batch_tags = self.y[batch_idx]
        side1, side2 = self.data1[batch_idx], self.data2[batch_idx]
        self.pointer += 1
        padded = self._pad(side1, side2, batch_tags)
        return padded + [batch_idx]


class single_tree_bow_minibatcher:
    """Mini-batch iterator over single sentences with dependency annotations.

    Every batch is the list produced by ``_pad`` followed by the batch's
    index array. The slots that the paired batchers in this file use for a
    second sentence and for negative samples are filled with ``None``.

    NOTE(review): relies on ``auto_init_args`` (imported from
    ``SSDM.decorators``) storing every constructor argument on ``self`` under
    its own name -- confirm.
    """

    @auto_init_args
    def __init__(self, data, dp, vocab_size, batch_size,
                 score_func, shuffle, mega_batch, p_scramble,
                 *args, **kwargs):
        """Split ``dp`` into its two components and build the index pool."""
        self.dp_dep, self.dp_dis = dp[0], dp[1]
        self._reset()
        self.mega_batch = mega_batch

    def __len__(self):
        """Return the number of batches in one pass."""
        return len(self.idx_pool)

    def _reset(self):
        """Rewind the iterator and rebuild (optionally shuffle) the index pool."""
        self.pointer = 0
        n_examples = len(self.dp_dep)
        order = np.arange(n_examples)
        if self.shuffle:
            np.random.shuffle(order)
        step = self.batch_size
        self.idx_pool = [order[start: start + step]
                         for start in range(0, n_examples, step)]

    def _pad(self, datas, tags_1_dep, tags_1_dis):
        """Assemble one batch of padded arrays.

        Returns a 17-element list whose populated entries are the zero-padded
        float32 token ids and mask, the bag-of-words target (twice) and, in
        the last position, the tuple ``(input_d11, input_d12)`` of int64
        dependency arrays padded with -1.
        """
        assert len(datas) == len(tags_1_dep)
        max_len = max(len(sent) for sent in datas)
        n_rows = len(datas)

        input_data1 = np.zeros((n_rows, max_len), dtype="float32")
        input_mask1 = np.zeros((n_rows, max_len), dtype="float32")
        tgt_data1 = np.zeros((n_rows, self.vocab_size), dtype="float32")

        input_d11 = np.full((len(tags_1_dep), max_len), -1, dtype="int64")
        input_d12 = np.ones((len(tags_1_dis), max_len, max_len), dtype="int64")

        for row, (tokens, dep_seq, dis_mat) in enumerate(
                zip(datas, tags_1_dep, tags_1_dis)):
            scramble = np.random.choice(
                [True, False],
                p=[self.p_scramble, 1 - self.p_scramble]).item()
            if scramble:
                tokens = np.random.permutation(tokens)

            n_tokens = len(tokens)
            input_data1[row, :n_tokens] = np.asarray(list(tokens)).astype("float32")
            input_mask1[row, :n_tokens] = 1.
            tgt_data1[row, :] = np.bincount(tokens, minlength=self.vocab_size)

            input_d11[row, :len(dep_seq)] = np.asarray(list(dep_seq)).astype("int64")

            # Grow the per-sentence matrix to (max_len, max_len) with -1.
            dis_mat = np.array(dis_mat)
            grow_rows = max_len - dis_mat.shape[0]
            grow_cols = max_len - dis_mat.shape[1]
            dis_mat = np.pad(dis_mat, pad_width=((0, grow_rows), (0, grow_cols)),
                             mode="constant", constant_values=-1)
            input_d12[row] = dis_mat.astype("int64")

        input_data3 = (input_d11, input_d12)

        batch = [input_data1, input_mask1, None, None,
                 tgt_data1, tgt_data1, None, None]
        batch += [None] * 8
        batch.append(input_data3)
        return batch

    def __iter__(self):
        """Return the batcher itself; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, or reset and raise ``StopIteration``."""
        exhausted = self.pointer == len(self.idx_pool)
        if exhausted:
            self._reset()
            raise StopIteration()

        batch_idx = self.idx_pool[self.pointer]
        dep_batch, dis_batch = self.dp_dep[batch_idx], self.dp_dis[batch_idx]
        sent_batch = self.data[batch_idx]
        self.pointer += 1
        padded = self._pad(sent_batch, dep_batch, dis_batch)
        return padded + [batch_idx]
