import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random
from random import choices
from noise import noisy
from torch.nn.functional import pdist
import numpy as np
from torch.nn.modules.distance import CosineSimilarity
import sys
sys.path.append("../")
from emb2emb.hausdorff import hausdorff_similarity
from emb2emb.gmm import prep_gmm_input, make_gmm
from autoencoders.torch_utils import make_mask


class Encoder(nn.Module):
    """Minimal encoder base class that only keeps the configuration object."""

    def __init__(self, config):
        """Store ``config`` on the module as ``self.config``."""
        super().__init__()
        self.config = config

    def encode(self, x, lengths, train=False):
        """Placeholder hook: performs no work and returns ``None``.

        Presumably overridden by the concrete encoders of the project.
        """
        return None


class Decoder(nn.Module):
    """Minimal decoder base class that only keeps the configuration object."""

    def __init__(self, config):
        """Store ``config`` on the module as ``self.config``."""
        super().__init__()
        self.config = config

    def decode(self, x, train=False, actual=None, lengths=None, beam_width=1):
        """Placeholder hook: performs no work and returns ``None``.

        Presumably overridden by the concrete decoders of the project.
        """
        return None

    def decode_teacher_forcing(self, x, actual, lengths):
        """Placeholder hook: performs no work and returns ``None``."""
        return None


def _hausdorff_sim(encoded, similarity_f="euclidean", differentiable=True, softmax_temp=1.0):

    enc, enc_len = encoded

    # we stored X is in each ith entry, Y is in each (i+1)th entry
    #num_pairs = enc.size(0) / 2
    #assert num_pairs.is_integer()
    #num_pairs = int(num_pairs)
    #X_indices = [i for i in range(0, num_pairs, 2)]
    #Y_indices = [i + 1 for i in range(0, num_pairs, 2)]

    #X = enc[X_indices]
    #Y = enc[Y_indices]

    #X_len = enc_len[X_indices]
    #Y_len = enc_len[Y_indices]

    # print(X.size())
    # print(Y.size())

    # we have to compare each sample in X's batch to every other sample

    def make_mask(A, A_len):
        mask_A = torch.arange(A.size(1), device=A_len.device).unsqueeze(
            0)
        mask_A = mask_A.expand(A_len.size(0), -1)
        mask_A = mask_A < A_len.unsqueeze(1)
        return mask_A

    Y = enc
    Y_len = enc_len
    mask_Y = make_mask(Y, Y_len)

    similarities = torch.zeros((enc.size(0), enc.size(0)), device=enc.device)
    for i in range(enc.size(0)):
        X_sample = enc[i]
        X_sample_len = enc_len[i]
        X_sample = X_sample.unsqueeze(0).expand(
            enc.size(0), -1, -1).contiguous()
        X_sample_len = X_sample_len.unsqueeze(0).expand(enc.size(0))

        mask_X = make_mask(X_sample, X_sample_len)

        result = hausdorff_similarity(X_sample, Y, mask_X=mask_X, mask_Y=mask_Y,
                                      similarity_function=similarity_f, naive=False,
                                      differentiable=differentiable,
                                      softmax_temp=softmax_temp)
        similarities[i, :] = result

    return similarities


class AutoEncoder(nn.Module):
    def __init__(self, encoder, decoder, tokenizer, config):
        super(AutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer
        self.config = config
        self.adversarial = config.adversarial
        self.use_l0drop = config.use_l0drop
        self.l0_lambda = config.l0_lambda
        self.act = config.act
        self.act_lambda = config.act_lambda
        self.variational = config.variational
        self.denoising = config.denoising
        self.softmax_temp = config.softmax_temp
        self.rae_regularization = config.rae_regularization
        self.close_pairs_regularization = config.close_pairs_regularization
        # self.similarity_f = batch_pairwise_squared_distances if config.simclr_sim == "l2" else lambda x: batch_pairwise_similarity(
        #    x, nn.CosineSimilarity(dim=-1))
        self.similarity_f = lambda x: _hausdorff_sim(
            x, "euclidean" if config.simclr_sim == "l2" else "cosine", differentiable=config.simclr_differentiable,
            softmax_temp=config.simclr_softmax_temp)
        if self.close_pairs_regularization > 0.:
            self.logsoftmax = nn.LogSoftmax(dim=1)
        if self.config.share_embedding:
            self.decoder.embedding = self.encoder.embedding
            self.decoder.input_projection = self.encoder.input_projection
        if self.adversarial:
            self.discriminator = nn.Sequential(
                torch.nn.Linear(config.hidden_size,
                                config.hidden_size, bias=False),
                torch.nn.SELU(),
                torch.nn.Linear(config.hidden_size,
                                config.hidden_size, bias=False),
                torch.nn.SELU(),
                torch.nn.Linear(config.hidden_size, 1, bias=True),
                torch.nn.Sigmoid()
            )
            self.optimD = torch.optim.Adam(
                self.discriminator.parameters(), lr=config.discriminator_lr)

        if self.close_pairs_regularization > 0.:
            if config.simclr_head == "nonlinear":
                self.simclr_head = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                                 nn.ReLU(),
                                                 nn.Linear(config.hidden_size, config.hidden_size))
            elif config.simclr_head == "linear":
                self.simclr_head = nn.Linear(
                    config.hidden_size, config.hidden_size)
            elif config.simclr_head == "none":
                self.simclr_head = nn.Identity()

    def forward(self, x, lengths):

        # sort the sequence by length
        lengths, idx = torch.sort(lengths, descending=True)
        x = x[idx]
        # remember original sorting
        original_indices = torch.zeros(lengths.size(0))
        for i in range(len(lengths)):
            original_indices[idx[i]] = i
        #original_indices[idx] = torch.arange(0, lengths.size(0))
        original_indices = original_indices.long()

        # denoising
        if self.training and self.denoising:
            x_, lengths_, orig_indices = noisy([self.encoder.config.pad_idx,
                                                self.encoder.config.unk_idx,
                                                self.encoder.config.sos_idx,
                                                self.encoder.config.eos_idx], x, self.config.p_drop)
        else:
            x_, lengths_ = x, lengths

        # x shape:                  (batch, seq_len)
        encoded = self.encoder.encode(x_, lengths_, train=True)

        if self.training and self.denoising:

            if self.use_l0drop:
                (enc, enc_len), l0_loss = encoded
            elif self.act:
                (enc, enc_len), act_cost = encoded
            else:
                enc, enc_len = encoded
            enc = enc[orig_indices]
            enc_len = enc_len[orig_indices]
            if self.use_l0drop:
                l0_loss = l0_loss[orig_indices]
            elif self.act:
                act_cost = act_cost[orig_indices]

            if self.use_l0drop:
                encoded = (enc, enc_len), l0_loss
            elif self.act:
                encoded = (enc, enc_len), act_cost
            else:
                encoded = (enc, enc_len)

        if self.training and self.denoising and self.close_pairs_regularization > 0:
            # if we do both denoising and SimCLR objective, we treat the normal and noisy
            # pairs as positive examples
            encoded_original = self.encoder.encode(x, lengths, train=True)

            enc, enc_len = encoded

            # create vectors for merged outputs
            encoded_new = torch.zeros(enc.size(
                0) * 2, *enc.size()[1:], device=enc.device)
            enc_len_new = torch.zeros(
                enc_len.size(0) * 2, device=enc_len.device)

            # merge original inputs and noisy inputs
            for i in range(0, enc.size(0)):
                # copy embeddings
                encoded_new[2 * i, :] = encoded_original[0][i, :]
                encoded_new[(2 * i) + 1, :] = enc[i, :]

                # copy length
                enc_len_new[2 * i] = encoded_original[1][i]
                enc_len_new[(2 * i) + 1] = enc_len[i]
            encoded = encoded_original  # decode from the original embeddings
            encoded_new = (encoded_new, enc_len_new)

        if self.variational:
            encoded, mean, logv = encoded
        # encoded shape:            (batch, hidden_size)
        elif self.use_l0drop:
            encoded, l0_loss = encoded
        elif self.act:
            encoded, act_cost = encoded

        # add small gaussian noise during training
        if self.training and self.config.gaussian_noise_std > 0.:
            emb, lens = encoded
            emb = emb + \
                torch.randn_like(emb) * self.config.gaussian_noise_std
            encoded = (emb, lens)

        if self.training and self.config.vector_dropout > 0:
            emb, lens = encoded
            drop_probs = torch.ones(emb.size(0), emb.size(1), device=emb.device) * \
                (1. - self.config.vector_dropout)
            mask = torch.bernoulli(drop_probs).unsqueeze(-1).expand_as(emb)
            emb = emb * mask
            encoded = (emb, lens)

        if self.training and self.config.remove_vectors > 0:
            encoded = remove_vectors(self.config.remove_vectors, encoded)

        if self.training and self.config.add_vectors > 0:
            encoded = add_vectors(self.config.add_vectors, encoded)

        if self.config.gmm_vae:
            emb, lens = encoded
            emb_mask = make_mask(emb.size(0), emb.size(1), lens)
            gmm_means, gmm_sigma, gmm_weight = prep_gmm_input(
                emb, emb_mask, False, self.config.gmm_sigma)
            gmm_model = make_gmm(gmm_means, gmm_sigma, gmm_weight)
            emb = sample_gmm(gmm_model, lens)
            encoded = (emb, lens)

        if self.training and self.config.teacher_forcing_batchwise and self.config.teacher_forcing_ratio > random.random():
            decoded_pred = self.decoder.decode_teacher_forcing(
                encoded, x, lengths)

            #decoded_pred_greedy = self.decoder.decode(encoded, train=True, actual=x, lengths=lengths)
            # iseq = #torch.equal(decoded_pred_greedy, decoded_pred)
            #iseq = decoded_pred.isclose(decoded_pred_greedy)
            # print(iseq.float().mean())
            #iseq = iseq.byte().all()
            #assert iseq
        else:
            decoded_pred = self.decoder.decode(
                encoded, train=True, actual=x, lengths=lengths)

        # restore original ordering
        decoded_pred = decoded_pred[original_indices]
        encoded = (encoded[0][original_indices], encoded[1][original_indices])

        # ret:                      (batch, seq_len, classes)
        if self.variational:
            return decoded_pred, mean, logv, encoded
        elif self.use_l0drop:
            return decoded_pred, encoded, l0_loss
        elif self.act:
            return decoded_pred, encoded, act_cost
        if self.adversarial:
            # it's important to detach the encoded embedding before feeding into the
            # discriminator so that when updating the discriminator, it doesn't
            # backprop through the generator
            encoded_det = encoded.detach().clone()
            prior_data = torch.randn_like(encoded)
            return decoded_pred, self.discriminator(encoded), self.discriminator(encoded_det), self.discriminator(prior_data), encoded
        elif self.training and self.denoising and self.close_pairs_regularization > 0:
            # no need to revert the indices of encoded_new, because they are
            # only used for computing the loss
            encoded_new = encoded_new
            return decoded_pred, encoded_new
        else:
            return decoded_pred, encoded

    def encode(self, x, lengths, return_l0loss=False):
        result = self.encoder.encode(x, lengths, reparameterize=False)
        if self.use_l0drop or self.act:
            if not return_l0loss:
                result = result[0]
        return result

    def decode(self, x, beam_width=1):
        return self.decoder.decode(x, beam_width=beam_width)

    def decode_training(self, h, actual, lengths):
        """
        Decoding step to be used for downstream training
        """
        return self.decoder.decode(h, train=True, actual=actual, lengths=lengths)

    def loss(self, predictions, embeddings, labels, reduction="mean"):
        # predictions:  (batch, seq_len, classes)
        # labels:       (batch, seq_len)

        l_rec = F.cross_entropy(
            predictions.reshape(-1, predictions.shape[2]), labels.reshape(-1), ignore_index=0, reduction=reduction)

        outputs = {"l_rec": l_rec}

        # regularize embeddings
        if self.rae_regularization > 0.:
            l_reg = ((embeddings.norm(dim=-1) ** 2) / 2.).mean()
            l_reg = l_reg * self.rae_regularization
            outputs["l_reg"] = l_reg
            l = l_reg + l_rec
            return l
        else:
            l = l_rec

        if self.close_pairs_regularization > 0:

            if self.denoising and not self.training:
                #outputs["l_simclr"] = 0.
                pass
            else:
                l_simclr = self.close_pairs_regularization * \
                    self._loss_for_pairs(embeddings)
                outputs["l_simclr"] = l_simclr
                l = l + l_simclr

        outputs["loss"] = l

        return outputs

    def loss_l0drop(self,
                    output, embeddings, labels, l0_loss):
        losses = self.loss(output, embeddings, labels)
        l0_loss = l0_loss.mean()
        total_loss = losses['loss'] + self.l0_lambda * l0_loss
        losses['loss'] = total_loss
        losses['l0_loss'] = l0_loss
        return losses

    def loss_act(self,
                 output, embeddings, labels, act_cost):
        losses = self.loss(output, embeddings, labels)
        act_cost = act_cost.mean()
        total_loss = losses['loss'] + self.act_lambda * act_cost
        losses['loss'] = total_loss
        losses['act_cost'] = act_cost
        return losses

    def _loss_for_pairs(self, embedded):
        # subsequent examples should be close
        #even = np.array(list(range(0, embedded.size(0), 2)))
        #odd = even + 1

        bsize = embedded[0].size(0)
        device = embedded[0].device

        embedded = self.simclr_head(embedded)
        similarities = self.similarity_f(embedded)
        similarities = similarities.squeeze(0)
        # set the similarity to itself to very small value

        # test code
        #sim_test = torch.zeros((embedded.size(0), embedded.size(0)))
        # for i in range(embedded.size(0)):
        #    for j in range(embedded.size(0)):
        #        cossim = CosineSimilarity(dim = -1)
        #        sim_test[i, j] = cossim(embedded[i,:].unsqueeze(0), embedded[j,:].unsqueeze(0))
        #assert sim_test.equal(similarities)

        diag_indices = np.diag_indices(similarities.size(0), ndim=2)
        similarities[diag_indices] = -10e9

        logprobs = self.logsoftmax((similarities) / self.softmax_temp) * -1
        indices = []
        for i in range(0, bsize, 2):
            indices.append([i + 1])
            indices.append([i])
        indices = torch.tensor(
            np.array(indices), device=device).long()
        logprobs = torch.gather(logprobs, 1, indices)
        return logprobs.mean()

    def loss_variational(self, predictions, embeddings, labels, mu, z_var, lambda_r=1, lambda_kl=1, reduction="mean"):
        recon_loss = F.cross_entropy(
            predictions.reshape(-1, predictions.shape[2]), labels.reshape(-1), ignore_index=0, reduction=reduction)
        raw_kl_loss = torch.exp(z_var) + mu**2 - 1.0 - z_var
        if reduction == "mean":
            kl_loss = 0.5 * torch.mean(raw_kl_loss)
        elif reduction == "sum":
            kl_loss = 0.5 * torch.sum(raw_kl_loss)

        outputs = {}
        outputs["loss"] = lambda_r * recon_loss + lambda_kl * kl_loss
        outputs["l_rec"] = recon_loss
        outputs["kl_loss"] = kl_loss
        return outputs

    def loss_adversarial(self, predictions, embeddings, labels, fake_z_g, fake_z_d, true_z, lambda_a=1):
        r_loss = F.cross_entropy(
            predictions.reshape(-1, predictions.shape[2]), labels.reshape(-1), ignore_index=0, reduction="mean")
        d_loss = (F.binary_cross_entropy(true_z, torch.ones_like(true_z)) +
                  F.binary_cross_entropy(fake_z_d, torch.zeros_like(fake_z_d))) / 2
        g_loss = F.binary_cross_entropy(fake_z_g, torch.ones_like(fake_z_g))
        # we need to update discriminator and generator independently, otherwise
        # we will update the generator to produce better distinguishable embeddings,
        # which we do not want
        outputs = {}
        outputs["loss"] = (r_loss + lambda_a * g_loss)
        outputs["l_rec"] = r_loss
        outputs["d_loss"] = d_loss
        outputs["g_loss"] = g_loss
        return outputs

    def eval(self, x, lengths, teacher_forcing=False, beam_width=1):
        encoded = self.encoder.encode(x, lengths)
        # encoded shape:            (batch, hidden_size)
        if teacher_forcing:
            return self.decoder.decode_teacher_forcing(encoded, x, lengths)
        else:
            return self.decoder.decode(encoded, beam_width=beam_width)


def batch_pairwise_squared_distances(x):
    """Return the negated pairwise squared Euclidean distances between rows.

    Modified from
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3

    Input: x is an N x d matrix. A leading axis of size one is added
    internally, so the output has shape 1 x N x N with
    out[0, i, j] = -||x[i, :] - x[j, :]||^2.
    NaN entries are replaced by 0 and negative distances are clamped to 0
    before the sign is flipped.
    """
    batched = x.unsqueeze(0)
    n_batches, n_rows = batched.shape[0], batched.shape[1]

    squared_norms = (batched ** 2).sum(2)
    row_norms = squared_norms.view(n_batches, n_rows, 1)
    col_norms = squared_norms.view(n_batches, 1, n_rows)
    gram = torch.bmm(batched, batched.permute(0, 2, 1).contiguous())

    dist = row_norms + col_norms - 2.0 * gram
    dist[torch.isnan(dist)] = 0  # replace nan values with 0
    return -torch.clamp(dist, 0.0, np.inf)


def batch_pairwise_similarity(x, similarity_f):
    """Apply ``similarity_f`` to every ordered pair of rows of ``x``.

    ``similarity_f`` receives two tensors of shape ``(n, n, d)``; at index
    ``[i, j]`` the first holds ``x[j]`` and the second holds ``x[i]``. Its
    return value is passed through unchanged.
    """
    n = x.size(0)
    first = x[None].expand(n, -1, -1)       # first[i, j] == x[j]
    second = x[:, None].expand(-1, n, -1)   # second[i, j] == x[i]
    return similarity_f(first, second)


if __name__ == "__main__":
    # Ad-hoc smoke test: with a rate of 0.0, add_vectors / remove_vectors
    # must leave the lengths untouched.
    # NOTE(review): add_vectors and remove_vectors are neither defined nor
    # imported in the visible part of this file -- confirm where they live.
    X = torch.randn(64, 32, 16)
    X_len = torch.randint(high=32, low=1, size=(64,))

    # --- adding vectors ---
    added_X, added_X_len = add_vectors(0.0, (X, X_len))

    # slice out the valid prefix of every sample (the slices are currently
    # not compared)
    for i in range(64):
        X_i, X_added_i = X[i, :X_len[i]], added_X[i, :X_len[i]]
    assert torch.equal(X_len, added_X_len)

    # --- removing vectors ---
    removed_X, removed_X_len = remove_vectors(0.0, (X, X_len))

    for i in range(64):
        X_i, X_removed_i = X[i, :X_len[i]], removed_X[i, :X_len[i]]
    assert torch.equal(X_len, removed_X_len)
