# coding:utf-8
#%%
import os
import copy
import torch
import numpy as np
from collections import Counter
from torchtext import data, vocab
from torch.nn import functional as F
#%%
class TextPairProcessor():
    """Load a sentence-pair dataset with torchtext and build word/char vocabularies.

    Each JSON record is read through the keys ``s`` and ``t`` (the two texts)
    and ``label``. Word embeddings are loaded from a ``.pt`` file (see
    ``load_w2v``), and a per-word character-index table is built so that a
    batch of word ids can be expanded into character ids (see ``characterize``).
    """

    def __init__(self, data_dir, data_path_list, w2v_path, mode, batch_size, gpu="-1"):
        """
        :param data_dir: directory containing the train/dev/test files
        :param data_path_list: ``(train_path, dev_path, test_path)``, relative to ``data_dir``
        :param w2v_path: path of the embedding file consumed by ``load_w2v``
        :param mode: ``"regression"`` gives float labels; any other value keeps
            the label field's default tensor type
        :param batch_size: batch size for the train, dev and test iterators
        :param gpu: currently unused (the device arguments below are commented out)
        """
        super(TextPairProcessor, self).__init__()
        self.w2v_path = w2v_path
        self.TEXT = data.Field(batch_first=True, include_lengths=True)
        if mode == "regression":
            self.LABEL = data.Field(sequential=False, unk_token=None, use_vocab=False, tensor_type=torch.FloatTensor)
        else:
            self.LABEL = data.Field(sequential=False, unk_token=None, use_vocab=False)

        train_path, dev_path, test_path = data_path_list
        self.train, self.dev, self.test = data.TabularDataset.splits(
            path=data_dir,
            train=train_path,
            validation=dev_path,
            test=test_path,
            format='json',
            fields={'s': ('s', self.TEXT), 't': ('t', self.TEXT), 'label': ('label', self.LABEL)}
        )
        self.TEXT.build_vocab(self.train, self.dev, self.test)
        # Force <pad> to index 0 and <unk> to index 1. Index 0 is then the
        # zeroed row in the embedding matrix (see load_w2v) and the all-zero
        # row in ``characterized_words``.
        # NOTE(review): torchtext's stoi is a defaultdict whose fallback index
        # was fixed when the vocab was built. Confirm that unseen tokens still
        # map to <unk> (and not <pad>) after this swap in your torchtext version.
        self.TEXT.vocab.stoi['<unk>'] = 1
        self.TEXT.vocab.stoi['<pad>'] = 0
        self.TEXT.vocab.itos[0] = '<pad>'
        self.TEXT.vocab.itos[1] = '<unk>'
        self.load_w2v()
        self.embed_num = len(self.TEXT.vocab)

        # Examples expose their fields under the names declared in ``fields``
        # above (``.s`` / ``.t``). Bucket by both lengths so that pairs of
        # similar size land in the same batch.
        self.train_iter = data.BucketIterator(
            self.train,
            batch_size=batch_size,
            shuffle=True,
            # device="cuda:"+gpu if gpu!='-1' else "cpu",
            sort_key=lambda x: data.interleave_keys(len(x.s), len(x.t)))

        # Only the dev/test iterators are kept. The train iterator built here
        # is discarded in favour of the bucketed one above.
        _, self.dev_iter, self.test_iter = data.BucketIterator.splits(
            (self.train, self.dev, self.test),
            batch_sizes=[batch_size] * 3,
            # device="cuda:"+gpu if gpu!='-1' else "cpu",
            shuffle=False,
            sort=False)

        # Every word's character list is right-padded to this length.
        self.max_word_len = max([len(w) for w in self.TEXT.vocab.itos])
        # Character id 0 is reserved for padding.
        self.char_vocab = {'': 0}
        # All-zero character rows for <pad> (index 0) and <unk> (index 1).
        self.characterized_words = [[0] * self.max_word_len, [0] * self.max_word_len]
        self.build_char_vocab()

    def get_label_map(self, task, num_labels=2):
        """Return an identity map over the string labels ``"0"`` .. ``str(num_labels - 1)``.

        :param task: unused; kept for interface compatibility
        :param num_labels: number of classes (default 2, the original binary setup)
        """
        labels = [str(i) for i in range(num_labels)]
        return {label: label for label in labels}

    def load_w2v(self):
        """Fill ``self.TEXT.vocab.vectors`` from the embedding file at ``self.w2v_path``.

        The file must hold a ``torch.save``-d tuple ``(stoi, w2v, dim)``:
        ``stoi`` maps a token to an index into ``w2v``, and ``dim`` is the
        embedding width.

        Row assignment:
        - tokens found in ``stoi`` copy their pretrained row;
        - ``<pad>`` gets a zero row;
        - every other token (including ``<unk>``, unless it is in ``stoi``) is
          drawn uniformly from [-0.25, 0.25] and counted as OOV.

        Prints an error and exits the process with status 1 if
        ``self.w2v_path`` is not a file.
        """
        if not os.path.isfile(self.w2v_path):
            print("Error: Need word embedding pt file")
            # Raise SystemExit directly: the ``exit`` builtin is injected by
            # ``site`` and is not guaranteed to exist.
            raise SystemExit(1)

        stoi, w2v, dim = torch.load(self.w2v_path)
        print("Load Embeddings From {}".format(self.w2v_path))
        itos = self.TEXT.vocab.itos
        # Uninitialised storage is fine: every row is written in the loop below.
        vectors = torch.Tensor(len(self.TEXT.vocab), dim)
        oov_words = 0
        for i, token in enumerate(itos):
            wv_index = stoi.get(token, None)
            if wv_index is not None:
                vectors[i] = w2v[wv_index]
            elif token == "<pad>":
                vectors[i].zero_()
            else:
                # In-place init on the row view writes straight into ``vectors``.
                torch.nn.init.uniform_(vectors[i], -0.25, 0.25)
                oov_words += 1
        self.TEXT.vocab.vectors = vectors
        print("Total {} Words, with {} OOVs .... Loading Done.".format(len(itos), oov_words))

    def characterize(self, batch):
        """Expand a batch of word indices into per-word character-index lists.

        :param batch: tensor (or Variable) of word indices, shape (batch, seq_len)
        :return: nested Python list of shape (batch, seq_len, max_word_len).
            Each innermost list is the ``self.characterized_words`` entry for
            that word index. Entries are shared, not copied, so do not mutate them.
        """
        rows = batch.data.cpu().numpy().astype(int).tolist()
        table = self.characterized_words
        return [[table[w] for w in words] for words in rows]

    def build_char_vocab(self):
        """Append a padded character-index list for every non-special word.

        Walks ``self.TEXT.vocab.itos`` from index 2 onward. Indices 0 and 1
        (``<pad>``/``<unk>``) already have all-zero rows in
        ``self.characterized_words``.

        For each word:
        - every unseen character gets the next free id in ``self.char_vocab``;
        - the id list is right-padded with 0 up to ``self.max_word_len``.

        Row ``k`` of ``characterized_words`` therefore corresponds to word index ``k``.
        """
        for word in self.TEXT.vocab.itos[2:]:
            # setdefault evaluates len() before inserting, so a new char gets the next id.
            chars = [self.char_vocab.setdefault(c, len(self.char_vocab)) for c in word]
            chars.extend([0] * (self.max_word_len - len(word)))
            self.characterized_words.append(chars)



# %%
