"""
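Contains the Encoder/Decoder base classes, their RAE- and autoencoder-backed implementations, and helpers for loading tokenizers and pretrained autoencoders.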
"""
import torch
from torch import nn
from sklearn.neighbors import NearestNeighbors
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from autoencoders.rae import RAE
from autoencoders.autoencoder import AutoEncoder
from autoencoders.rnn_encoder import RNNEncoder
from autoencoders.rnn_decoder import RNNDecoder
from autoencoders.transformer_encoder import TransformerEncoder
from autoencoders.transformer_decoder import TransformerDecoder
from autoencoders.transformer_decoder_simple import SimpleTransformerDecoder
from autoencoders.bert_encoder import BERTEncoder
from autoencoders.pretrained_encoder import PretrainedEncoder
from transformers import BertTokenizer
from tokenizers import CharBPETokenizer, SentencePieceBPETokenizer
# from autoencoders.rnnae import RNNAE
from emb2emb.utils import Namespace, word_index_mapping
import pickle
import os
import json
import copy

# Names of the `tokenizers`-library tokenizer classes that get_tokenizer accepts
# (besides "BERT").
HUGGINGFACE_TOKENIZERS = [
    "CharBPETokenizer",
    "SentencePieceBPETokenizer",
]


class Encoder(nn.Module):
    """
    Base class for all encoders.

    An encoder always takes a list of length 'b' as input and outputs a batch of
    embeddings of size 'b'. Subclasses implement `encode`.
    """

    def __init__(self, config):
        """
        Args:
            config: mapping providing 'use_lookup', which controls whether
                `lookup` consults `lookup_table`.
        """
        super(Encoder, self).__init__()
        self.use_lookup = config['use_lookup']
        # Cache from input to embedding; starts empty and is not filled by
        # this base class.
        self.lookup_table = {}

    def lookup(self, S_batch):
        """
        Partition a batch into inputs with and without a cached embedding.

        Args:
            S_batch: list of inputs (must be hashable when lookup is enabled).

        Returns:
            A pair (existing, non_existing). `existing` holds
            (batch_index, cached_embedding) pairs and `non_existing` holds
            (batch_index, input) pairs. If lookup is disabled, `existing` is
            empty and `non_existing` is an iterator over every
            (batch_index, input) pair.
        """
        if not self.use_lookup:
            return [], zip(range(len(S_batch)), S_batch)

        existing = []
        non_existing = []
        for i, S in enumerate(S_batch):
            if S in self.lookup_table:
                # lookup_table is a dict: index it, do not call it.
                existing.append((i, self.lookup_table[S]))
            else:
                non_existing.append((i, S))
        return existing, non_existing

    def encode(self, S_list):
        """
        To be implemented by subclasses. Takes a list of strings and returns
        the pair that `forward` unpacks as (embeddings, lengths).
        """
        pass

    def forward(self, S_batch):
        """
        Encode a list of strings by delegating to `encode`.

        No cache lookup is performed here.

        Returns:
            The (embeddings, lengths) pair returned by `encode`.
        """
        embeddings, length = self.encode(S_batch)
        return embeddings, length


class Decoder(nn.Module):
    """
    Base class for all decoders.

    A decoder takes a batch of embeddings of size 'b' as input and outputs a
    batch of predictions (training time) or a list of texts (test time).
    Subclasses implement `predict` and `prediction_to_text`.
    """

    def __init__(self):
        super(Decoder, self).__init__()

    def predict(self, S_batch, target_batch=None):
        """
        To be implemented. Takes a batch of embeddings and returns a batch of
        predictions. At training time, target_batch contains a list of target
        sentences; otherwise `forward` passes None.
        """
        pass

    def prediction_to_text(self, predictions):
        """
        To be implemented. Takes the output of `predict` for a batch of size b
        and returns a list of b texts.
        """
        pass

    def forward(self, embeddings, target_batch=None):
        """
        Decode a batch of embeddings.

        In training mode, `predict` receives `target_batch` and its raw output
        is returned. Otherwise `predict` is called with target_batch=None and
        its output is converted with `prediction_to_text`.
        """
        if self.training:
            return self.predict(embeddings, target_batch=target_batch)
        predictions = self.predict(embeddings, target_batch=None)
        return self.prediction_to_text(predictions)


def tokenize(s):
    """Split *s* on runs of whitespace and return the resulting tokens."""
    # TODO: more sophisticated tokenization
    tokens = s.split()
    return tokens


def get_tokenizer(tokenizer, location='bert-base-uncased'):
    """
    Instantiate the tokenizer named by `tokenizer`.

    Args:
        tokenizer: "BERT" or one of the names in HUGGINGFACE_TOKENIZERS.
        location: for "BERT", the name/path passed to
            `BertTokenizer.from_pretrained`; otherwise the path prefix of the
            vocabulary ('<location>-vocab.json') and merges
            ('<location>-merges.txt') files.

    Returns:
        The tokenizer. `tokenizers`-library tokenizers additionally get the
        special tokens "[PAD]", "<unk>", "<SOS>" and "<EOS>" registered.

    Raises:
        ValueError: if `tokenizer` is not a supported tokenizer name.
    """
    if tokenizer == "BERT":
        return BertTokenizer.from_pretrained(location)
    if tokenizer in HUGGINGFACE_TOKENIZERS:
        # Explicit name -> class mapping instead of eval() on a config string.
        tokenizer_classes = {
            "CharBPETokenizer": CharBPETokenizer,
            "SentencePieceBPETokenizer": SentencePieceBPETokenizer,
        }
        # TODO: do we need to pass more options to the file?
        tok = tokenizer_classes[tokenizer](vocab_file=location + '-vocab.json',
                                           merges_file=location + '-merges.txt')
        tok.add_special_tokens(["[PAD]", "<unk>", "<SOS>", "<EOS>"])
        return tok
    # Previously fell through and returned None, which failed later with an
    # unrelated AttributeError.
    raise ValueError("Unknown tokenizer: %r" % (tokenizer,))


def get_autoencoder(config):
    """
    Build an AutoEncoder from a model directory and load its trained weights.

    Args:
        config: mapping with keys
            - "default_config": path to a JSON file of default model settings
              (used only if the file exists),
            - "modeldir": directory containing "config.json" and the
              checkpoint file named by the model config's `model_file`,
            - "device": stored in the model config and used as
              `map_location` when loading the checkpoint.

    Returns:
        The AutoEncoder, loaded with the checkpoint's "best_model_state_dict"
        when present and not None, otherwise with "model_state_dict".

    Raises:
        ValueError: if the configured encoder or decoder name is unknown.
    """
    # Start from the optional defaults; settings saved with the model win.
    if os.path.exists(config["default_config"]):
        with open(config["default_config"]) as f:
            model_config_dict = json.load(f)
    else:
        model_config_dict = {}
    with open(os.path.join(config["modeldir"], "config.json")) as f:
        model_config_dict.update(json.load(f))
    model_config = Namespace()
    model_config.__dict__.update(model_config_dict)

    encoder_classes = {
        "RNNEncoder": RNNEncoder,
        "BERTEncoder": BERTEncoder,
        "PretrainedEncoder": PretrainedEncoder,
        "TransformerEncoder": TransformerEncoder,
    }
    decoder_classes = {
        "RNNDecoder": RNNDecoder,
        "TransformerDecoder": TransformerDecoder,
        "SimpleTransformerDecoder": SimpleTransformerDecoder,
    }
    # Fail early with a clear message instead of an UnboundLocalError when the
    # AutoEncoder is assembled.
    if model_config.encoder not in encoder_classes:
        raise ValueError("Unknown encoder: %r" % (model_config.encoder,))
    if model_config.decoder not in decoder_classes:
        raise ValueError("Unknown decoder: %r" % (model_config.decoder,))

    if model_config.encoder == "BERTEncoder":
        tokenizer = get_tokenizer(
            model_config.tokenizer, model_config.BERTEncoder["bert_location"])
        model_config.__dict__["vocab_size"] = tokenizer.vocab_size
        model_config.__dict__["sos_idx"] = tokenizer.cls_token_id
        model_config.__dict__["eos_idx"] = tokenizer.sep_token_id
        model_config.__dict__["unk_idx"] = tokenizer.unk_token_id
        model_config.__dict__["pad_idx"] = tokenizer.pad_token_id
    else:
        tokenizer = get_tokenizer(
            model_config.tokenizer, model_config.tokenizer_location)
        model_config.__dict__["vocab_size"] = tokenizer.get_vocab_size()
        model_config.__dict__["sos_idx"] = tokenizer.token_to_id("<SOS>")
        model_config.__dict__["eos_idx"] = tokenizer.token_to_id("<EOS>")
        model_config.__dict__["unk_idx"] = tokenizer.token_to_id("<unk>")
        model_config.__dict__["pad_idx"] = tokenizer.token_to_id("[PAD]")

    model_config.__dict__["device"] = config["device"]

    # Each sub-model sees the shared settings overlaid with its own config
    # section (keyed by its class name); the encoder also gets the tokenizer.
    encoder_config = copy.deepcopy(model_config)
    decoder_config = copy.deepcopy(model_config)
    encoder_config.__dict__.update(model_config.__dict__[model_config.encoder])
    encoder_config.__dict__["tokenizer"] = tokenizer
    decoder_config.__dict__.update(model_config.__dict__[model_config.decoder])

    encoder = encoder_classes[model_config.encoder](encoder_config)
    decoder = decoder_classes[model_config.decoder](decoder_config)
    model = AutoEncoder(encoder, decoder, tokenizer, model_config)

    checkpoint = torch.load(
        os.path.join(config["modeldir"], model_config.model_file),
        map_location=config["device"])
    # Prefer the best saved weights; fall back to the latest ones.
    state_dict = checkpoint.get("best_model_state_dict")
    if state_dict is None:
        state_dict = checkpoint["model_state_dict"]
    model.load_state_dict(state_dict)

    return model


class RAEEncoder(Encoder):
    """
    Encoder backed by a pretrained recursive autoencoder (RAE) over GloVe
    word vectors.

    `config` must provide "use_lookup" (see Encoder), "model_path" (the RAE
    state dict), "device" and "glove" (token -> vector mapping; the vectors
    are expected to match the RAE's input_size of 300).
    """

    def __init__(self, config):
        super(RAEEncoder, self).__init__(config)
        self.device = config['device']
        self.max_sequence_len = 100
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be loaded on the configured one.
        state_dict = torch.load(config["model_path"], map_location=self.device)
        params = Namespace(embedding_size=2048, input_size=300,
                           max_sequence_len=self.max_sequence_len, device=self.device)
        self.rae = RAE(params)
        self.rae.load_state_dict(state_dict)
        self.glove = config["glove"]

    def _to_glove(self, tokenized):
        """
        Look up GloVe vectors for each tokenized sentence and pad to a batch.

        Tokens without a GloVe vector are dropped; a sentence with no known
        token is represented by the single vector of ".".

        Returns:
            Float tensor zero-padded along the sequence dimension; assuming
            1-D GloVe vectors its shape is (batch, max_len, dim).
        """
        tensor_list = []
        for token_list in tokenized:
            vectors = [self.glove[t] for t in token_list if t in self.glove]
            if not vectors:
                vectors = [self.glove["."]]
            # Stack first: building a tensor from a list of ndarrays is slow.
            tensor_list.append(torch.tensor(
                np.stack(vectors), device=self.device).float())
        return pad_sequence(tensor_list, batch_first=True, padding_value=0.)

    def encode(self, S_list):
        """
        Encode a list of strings with the RAE.

        Returns:
            Tensor with one row per string: the RAE embedding with the
            recursion count, divided by max_sequence_len, appended as the last
            column.
        """
        # NOTE(review): Encoder.forward unpacks `embeddings, length =
        # self.encode(...)`, but this method returns a single tensor —
        # confirm how RAEEncoder is meant to be called.
        tokenized = [tokenize(s) for s in S_list]

        # lookup GloVe embeddings and scale them as the RAE expects
        glovenized = self.rae.scale_embedding(self._to_glove(tokenized))
        embedding, recursion_count = self.rae.encode(glovenized)

        batch_size = len(S_list)
        # Flatten to one row per input. A plain squeeze() would also remove
        # the batch dimension when batch_size == 1 and break the cat below.
        embedding = embedding.reshape(batch_size, -1)

        # Append the (batch-wide) recursion count as an extra column, normalized
        # so it doesn't dominate the MLE.
        recursion_count = torch.tensor(
            recursion_count, device=self.device).float()
        recursion_count = recursion_count.view(1, 1).expand(embedding.size(0), 1)
        recursion_count = recursion_count / self.max_sequence_len
        return torch.cat([embedding, recursion_count], dim=1)


class RAEDecoder(Decoder):
    """
    Decoder backed by a pretrained RAE: decodes embeddings into sequences of
    300-dimensional GloVe-space vectors and maps every vector to the word with
    the nearest (cosine) GloVe vector.

    `config` must provide "model_path" (the RAE state dict), "device" and
    "glove" (word -> 300-dimensional vector mapping).
    """

    def __init__(self, config):
        super(RAEDecoder, self).__init__()
        self.device = config['device']
        self.max_sequence_len = 100
        # map_location lets a checkpoint saved on another device be loaded.
        state_dict = torch.load(config["model_path"], map_location=self.device)
        params = Namespace(embedding_size=2048, input_size=300,
                           max_sequence_len=self.max_sequence_len, device=self.device)
        self.rae = RAE(params)
        self.rae.load_state_dict(state_dict)
        self.glove = config["glove"]

        # Row i of glove_matrix holds the vector of word id_to_word[i].
        glove_matrix = np.zeros((len(self.glove), 300))
        id_to_word = {}
        for idx, (word, vector) in enumerate(self.glove.items()):
            glove_matrix[idx, :] = vector
            id_to_word[idx] = word

        self.id_to_word = id_to_word
        self.nearestneighbors = NearestNeighbors(
            n_neighbors=1, algorithm="auto", metric="cosine")
        self.nearestneighbors.fit(glove_matrix)

    def _from_glove(self):
        pass

    def predict(self, S_batch, target_batch=None):
        """
        Decode a batch of RAE embeddings into GloVe-space vectors.

        Args:
            S_batch: 2-D tensor whose last column holds the recursion count
                divided by max_sequence_len (as appended while encoding).
            target_batch: unused. Accepted because Decoder.forward always
                passes it; without it every forward call raised TypeError.
        """
        # The recursion count is shared by the whole batch: read it from the
        # first row, undo the normalization, round, and cap at the max length.
        recursion_count = S_batch[0, -1].item() * float(self.max_sequence_len)
        recursion_count = min(self.max_sequence_len, round(recursion_count))

        output_embedding = S_batch[:, :-1].unsqueeze(1)
        output = self.rae.decode(output_embedding, recursion_count)
        return self.rae.descale_embedding(output)

    def prediction_to_text(self, predictions):
        """
        Map decoded vectors back to words.

        Every 300-value slice of a row of `predictions` is replaced by its
        nearest GloVe word; the words of a row are joined with spaces.

        Returns:
            A list with one string per row of `predictions`.
        """
        # detach(): .numpy() refuses tensors that still require grad.
        predictions = predictions.detach().cpu().numpy()
        batch_size = predictions.shape[0]
        flat = np.reshape(predictions, (-1, 300))
        nns = self.nearestneighbors.kneighbors(
            flat, n_neighbors=1, return_distance=False)
        nns = np.reshape(nns, (batch_size, -1))

        return [" ".join(self.id_to_word[word_id] for word_id in row)
                for row in nns]


class AEEncoder(Encoder):
    """
    Encoder that embeds sentences with a pretrained AutoEncoder loaded via
    `get_autoencoder(config)`.

    Besides the keys needed by Encoder and get_autoencoder, `config` must
    provide "device", "max_sequence_len" (a negative value disables
    truncation) and "remove_sos_and_eos".
    """

    def __init__(self, config):
        super(AEEncoder, self).__init__(config)
        self.device = config["device"]
        self.model = get_autoencoder(config)
        # NOTE(review): overrides the 'use_lookup' value set by
        # Encoder.__init__ with the encoder's `variational` flag — confirm.
        self.use_lookup = self.model.encoder.variational
        self.max_sequence_len = config["max_sequence_len"]
        self.remove_sos_and_eos = config["remove_sos_and_eos"]

    def _prepare_batch(self, indexed, lengths):
        """
        Pad token-id lists into one tensor, truncate, and sort by length.

        Args:
            indexed: list of token-id lists.
            lengths: length of each list in `indexed`.

        Returns:
            (X, lengths, idx): the padded batch (pad value 0) reordered by
            descending length, the sorted (truncated) lengths as a long
            tensor, and the permutation `idx` such that row c of X is input
            idx[c].
        """
        X = pad_sequence([torch.tensor(index_list, device=self.device)
                          for index_list in indexed], batch_first=True, padding_value=0)

        if self.max_sequence_len > -1:
            X = X[:, :self.max_sequence_len]
            lengths = [min(l, self.max_sequence_len) for l in lengths]
        lengths, idx = torch.sort(torch.tensor(
            lengths, device=self.device).long(), descending=True)
        return X[idx], lengths, idx

    def _undo_batch(self, encoded, lengths, sort_idx, l0loss=None):
        """
        Restore the input order after `_prepare_batch` sorted the batch.

        Row c of `encoded` (and entry c of `lengths` / `l0loss`) belongs to
        input sort_idx[c] and is moved back to that position.

        Returns:
            (embeddings, lengths), or (embeddings, lengths, l0losses) when
            `l0loss` is given. The embeddings are stacked with torch.stack;
            lengths and l0losses are new tensors built with torch.tensor.
        """
        ret = [[] for _ in range(encoded.shape[0])]
        lens = [0] * len(lengths)
        if l0loss is not None:
            l0losses = [0] * len(lengths)
        for i, c in zip(sort_idx, range(encoded.shape[0])):
            ret[i] = encoded[c]
            lens[i] = lengths[c]

            if l0loss is not None:
                l0losses[i] = l0loss[c]

        if l0loss is not None:
            return torch.stack(ret), torch.tensor(lens), torch.tensor(l0losses)
        else:
            return torch.stack(ret), torch.tensor(lens)

    def encode(self, S_list, **kwargs):
        """
        Tokenize and encode a list of strings.

        Keyword arguments are forwarded to `self.model.encode`. When
        `return_l0loss` is passed with a truthy value, returns
        ((embeddings, lengths), l0loss, token_counts), where token_counts are
        the per-input token counts before truncation; otherwise returns
        (embeddings, lengths). Results follow the order of `S_list`.
        """
        tokenizer = self.model.tokenizer
        if isinstance(tokenizer, BertTokenizer):
            indexed = [tokenizer.encode(s, add_special_tokens=True)
                       for s in S_list]
        elif self.remove_sos_and_eos:
            indexed = [tokenizer.encode(s).ids for s in S_list]
        else:
            indexed = [tokenizer.encode("<SOS>" + s + "<EOS>").ids
                       for s in S_list]

        lengths = [len(i) for i in indexed]
        # _prepare_batch sorts by length; _undo_batch restores the input order.
        X, X_lens, sort_idx = self._prepare_batch(indexed, lengths)
        # Test the flag's value, not its presence, so that
        # return_l0loss=False takes the plain path.
        if kwargs.get('return_l0loss', False):
            (encoded, out_lens), l0loss = self.model.encode(X, X_lens, **kwargs)
            outs, lens, l0loss = self._undo_batch(
                encoded, out_lens, sort_idx, l0loss)
            return (outs, lens), l0loss, lengths

        encoded, out_lens = self.model.encode(X, X_lens, **kwargs)
        return self._undo_batch(encoded, out_lens, sort_idx)


class AEDecoder(Decoder):
    """
    Decoder that turns embeddings into text with a pretrained AutoEncoder
    loaded via `get_autoencoder(config)`.

    Besides the keys needed by get_autoencoder, `config` must provide
    "device" and "beam_width".
    """

    def __init__(self, config):
        super(AEDecoder, self).__init__()
        self.device = config["device"]
        self.model = get_autoencoder(config)
        self.beam_width = config["beam_width"]

    def _prepare_batch(self, indexed, lengths):
        """
        Pad token-id lists into one tensor (pad value 0).

        Unlike AEEncoder._prepare_batch, the batch keeps its original order.

        Returns:
            (X, lengths) with `lengths` as a long tensor on self.device.
        """
        X = pad_sequence([torch.tensor(index_list, device=self.device)
                          for index_list in indexed], batch_first=True, padding_value=0)
        return X, torch.tensor(lengths, device=self.device).long()

    def _encode(self, S_list):
        """
        Tokenize target sentences into a padded id tensor plus lengths.

        Non-BERT tokenizers get explicit "<SOS>"/"<EOS>" markers.
        """
        tokenizer = self.model.tokenizer
        if isinstance(tokenizer, BertTokenizer):
            indexed = [tokenizer.encode(s, add_special_tokens=True)
                       for s in S_list]
        else:
            indexed = [tokenizer.encode("<SOS>" + s + "<EOS>").ids
                       for s in S_list]
        return self._prepare_batch(indexed, [len(i) for i in indexed])

    def predict(self, S_batch, target_batch=None):
        """
        Decode a batch of embeddings.

        Training mode: run `self.model.decode_training` against the tokenized
        `target_batch` and return (model output, target token ids).
        Eval mode: return `self.model.decode` with beam_width=self.beam_width.

        Raises:
            ValueError: in training mode when `target_batch` is None.
        """
        if not self.training:
            return self.model.decode(S_batch, beam_width=self.beam_width)
        if target_batch is None:
            raise ValueError("AEDecoder.predict needs target_batch in training mode")
        target_ids, target_length = self._encode(target_batch)
        out = self.model.decode_training(S_batch, target_ids, target_length)
        return out, target_ids

    def prediction_to_text(self, predictions):
        """Decode each token-id sequence into text, skipping special tokens."""
        return [self.model.tokenizer.decode(p, skip_special_tokens=True)
                for p in predictions]
