"""
Module to train emb2emb and the baseline model.
"""
import torch
from torch import nn
from random import choices
from emb2emb.utils import fast_gradient_iterative_modification
from torch.nn.utils.rnn import pad_sequence
import time
from emb2emb.classifier import BoVBinaryClassifier
from emb2emb.utils import Namespace
from emb2emb.architectures import BovToBovMapping, BovIdentity, BovOracle,\
    SimpleBovMapping
from emb2emb.losses import FlipLoss, HausdorffLoss, LocalBagLoss
from autoencoders.l0drop import compute_g_from_logalpha

# Training modes accepted by Emb2EmbTrainer.change_mode(). The initial mode
# decides which components Emb2EmbTrainer.__init__ freezes, and the current
# mode decides how Emb2EmbTrainer.forward() computes its training loss.
MODE_EMB2EMB = "emb2emb"  # encoder + decoder frozen; loss in embedding space
MODE_SEQ2SEQ = "seq2seq"  # nothing frozen; loss on decoder outputs
MODE_FINETUNEDECODER = "finetune_decoder"  # encoder + mapping frozen
MODE_SEQ2SEQFREEZE = "seq2seq_freeze"  # encoder + decoder frozen; decoder loss


class Emb2EmbTrainer(nn.Module):
    """
    Training/inference wrapper around an encoder, an embedding-to-embedding
    mapping (``emb2emb``) and a decoder.

    In training mode ``forward`` returns a loss; in eval mode it returns the
    decoder's output for the mapped embeddings. Which components receive
    gradients is fixed in ``__init__`` from the initial ``mode``:

    * ``MODE_EMB2EMB``: encoder and decoder frozen; ``loss_fn`` compares the
      mapped embeddings with the encoded targets.
    * ``MODE_SEQ2SEQFREEZE``: encoder and decoder frozen; ``loss_fn`` is
      applied to the decoder's ``(outputs, targets)``.
    * ``MODE_FINETUNEDECODER``: encoder and ``emb2emb`` frozen; ``loss_fn`` is
      applied to the decoder's ``(outputs, targets)``.
    * ``MODE_SEQ2SEQ``: nothing frozen; same decoder-side loss.

    When ``adversarial_regularization`` is given, a critic is trained to tell
    real embeddings from generated ones, and (in ``MODE_EMB2EMB``) the mapping
    is rewarded for increasing the critic's loss on generated embeddings.
    """

    def __init__(self, encoder, decoder, emb2emb, loss_fn, mode,
                 embedding_dim=16, gaussian_noise_std=0.,
                 adversarial_regularization=None,
                 fast_gradient_iterative_modification=False,
                 binary_classifier=None, fgim_decay=1.0, fgim_threshold=0.001,
                 fgim_customloss=False, fgim_start_at_y=False,
                 predict_done=False, tokenwise=False,
                 end_of_sequence_epsilon=0.0001,
                 use_end_of_sequence_vector=False, teacher_forcing=0.,
                 learned_positional_embeddings=False, force_eval_mode=False,
                 unaligned=False, track_input_output_distance=False,
                 select_input_length=False, gated_vectors=False):
        """
        Args:
            encoder: module mapping a batch of texts to
                ``(embeddings, lengths)``.
            decoder: module called as ``decoder((embeddings, lengths))`` in
                ``MODE_EMB2EMB`` / at inference, and as
                ``decoder((embeddings, lengths), target_batch=...)`` returning
                ``(outputs, targets)`` in the seq2seq modes during training.
            emb2emb: the mapping. Bag-of-vectors mappings (``BovToBovMapping``,
                ``BovIdentity``, ``BovOracle``, ``SimpleBovMapping``) are
                called as ``emb2emb(X, X_lens, Y=..., teacher_forcing=...)``
                and return ``(embeddings, lengths)``; any other mapping is
                called as ``emb2emb(X)`` and keeps the input lengths.
            loss_fn: embedding-space loss in ``MODE_EMB2EMB``; loss on the
                decoder's ``(outputs, targets)`` in the other modes.
            mode: one of ``MODE_EMB2EMB``, ``MODE_SEQ2SEQ``,
                ``MODE_FINETUNEDECODER``, ``MODE_SEQ2SEQFREEZE``.
            embedding_dim: size of the end-of-sequence vector and of the
                ``predict_done`` classifier's input.
            gaussian_noise_std: std of Gaussian noise added to the encoded
                inputs during training (0 disables it).
            adversarial_regularization: dict configuring the critic; ``None``
                or an empty dict disables adversarial training. Keys read:
                ``critic_input_dim``, ``critic_hidden_units``,
                ``critic_hidden_layers``, ``critic_lr``, ``device``,
                ``discriminate_moments``, ``real_data``,
                ``adversarial_remove_last``,
                ``adversarial_reconstruction_weight``, ``joint``,
                ``critic_rounds``, ``task_rounds``, ``adversarial_rounds``,
                ``joint_reconstruction_adversarial``.
            fast_gradient_iterative_modification: if true, apply FGIM with
                ``binary_classifier`` before decoding in ``MODE_EMB2EMB`` or
                at inference.
            binary_classifier: classifier used by FGIM.
            fgim_decay, fgim_threshold, fgim_customloss, fgim_start_at_y:
                forwarded to FGIM; with ``fgim_customloss`` FGIM also receives
                ``compute_loss`` and the input embeddings.
            predict_done: (BoV only) train a classifier predicting the last
                target position; used to cut outputs at inference.
            tokenwise: use the MLP critic on individual vectors instead of the
                bag-level ``BoVBinaryClassifier``.
            end_of_sequence_epsilon: maximum L2 distance to the
                end-of-sequence vector for a position to count as the end.
            use_end_of_sequence_vector: (BoV only) write
                ``self.end_of_sequence_vector`` right after the last valid
                position of every encoded sequence.
            teacher_forcing: forwarded to BoV mappings; a positive value makes
                training encode the targets before running the mapping.
            learned_positional_embeddings: forwarded to the BoV critic config.
            force_eval_mode: in ``MODE_EMB2EMB``, keep encoder and decoder in
                eval mode whenever ``train()`` is called.
            unaligned: always encode the targets (``Sy_batch``, or
                ``Sx_batch`` when ``Sy_batch`` is None) and pass them to the
                mapping.
            track_input_output_distance: at inference, accumulate distance
                statistics between outputs and encoded targets (only computed
                for local-bag losses).
            select_input_length: at inference, use the encoded targets'
                lengths as output lengths.
            gated_vectors: the encoder's last two embedding dimensions hold a
                gate value and its log_alpha (see ``_encode`` / ``_decode``).
        """
        super(Emb2EmbTrainer, self).__init__()
        # None (the default) and {} both mean "no adversarial regularization";
        # avoids a shared mutable default argument.
        if adversarial_regularization is None:
            adversarial_regularization = {}
        self.encoder = encoder
        self.gated_vectors = gated_vectors
        self.unaligned = unaligned
        self.decoder = decoder
        self.emb2emb = emb2emb
        self.loss_fn = loss_fn
        self.fgim_decay = fgim_decay
        self.fgim_threshold = fgim_threshold
        self.fgim_customloss = fgim_customloss
        self.fgim_start_at_y = fgim_start_at_y
        self.change_mode(mode)
        self.track_input_output_distance = track_input_output_distance
        self.gaussian_noise_std = gaussian_noise_std
        self.adversarial_regularization = adversarial_regularization
        self.iterations = 0
        self.fast_gradient_iterative_modification = fast_gradient_iterative_modification
        self.binary_classifier = binary_classifier
        # accumulated wall-clock timings (seconds) and output-length statistics
        self.total_time_fgim = 0.
        self.total_emb2emb_time = 0.
        self.total_inference_time = 0.
        self.test_out_lens_sum = 0.
        self.bovmapping = isinstance(
            emb2emb,
            (BovToBovMapping, BovIdentity, BovOracle, SimpleBovMapping))
        self.use_end_of_sequence_vector = use_end_of_sequence_vector
        self.select_input_length = select_input_length
        # NOTE(review): this is a plain tensor, not an nn.Parameter or buffer,
        # so it is not in self.parameters()/state_dict() and is not moved by
        # .to(). A freshly constructed trainer draws a new random vector, so it
        # will not match one saved with an earlier checkpoint — confirm whether
        # callers persist/optimize it separately.
        self.end_of_sequence_vector = torch.randn(
            embedding_dim, requires_grad=True)
        self.end_of_sequence_epsilon = end_of_sequence_epsilon
        self.tokenwise = tokenwise
        self.predict_done = predict_done
        self.teacher_forcing = teacher_forcing
        self.force_eval_mode = force_eval_mode
        self.in_output_distance_mean = 0.
        self.in_output_distance_min = 0.
        self.in_output_distance_batchsize = 0
        if self.bovmapping and self.predict_done:
            # outputs two unnormalized logits per position: [not done, done]
            self.is_done_clf = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim), nn.ReLU(),
                nn.Linear(embedding_dim, 2))
            self.is_done_loss = nn.CrossEntropyLoss(reduction='none')

        # freeze components according to the initial mode
        if mode in [MODE_FINETUNEDECODER, MODE_EMB2EMB, MODE_SEQ2SEQFREEZE]:
            for p in self.encoder.parameters():
                p.requires_grad = False

        if mode in [MODE_FINETUNEDECODER]:
            for p in self.emb2emb.parameters():
                p.requires_grad = False

        if mode in [MODE_EMB2EMB, MODE_SEQ2SEQFREEZE]:
            for p in self.decoder.parameters():
                p.requires_grad = False

        if self.adversarial_regularization:
            if not self.bovmapping or self.tokenwise or adversarial_regularization["discriminate_moments"]:
                # MLP critic applied to individual vectors (or bag means)
                hidden = adversarial_regularization["critic_hidden_units"]
                critic_layers = [nn.Linear(adversarial_regularization["critic_input_dim"],
                                           hidden),
                                 nn.ReLU()]
                for _ in range(adversarial_regularization["critic_hidden_layers"]):
                    critic_layers.append(nn.Linear(hidden, hidden))
                    critic_layers.append(nn.ReLU())
                critic_layers.append(nn.Linear(hidden, 2))
                critic = nn.Sequential(*critic_layers)
            else:
                # bag-level critic that consumes (embeddings, lengths)
                config = Namespace()
                config.n_layers = adversarial_regularization["critic_hidden_layers"]
                config.heads = 4
                config.hidden_size = 16
                config.input_dim = adversarial_regularization["critic_hidden_units"]
                config.embedding_dim = adversarial_regularization["critic_input_dim"]
                config.learned_positional_embeddings = learned_positional_embeddings
                critic = BoVBinaryClassifier(config)
            # Stored in a plain list so nn.Module does not register it: its
            # parameters stay out of self.parameters()/state_dict() and are
            # updated only through self.critic_optimizer.
            self.critic = [critic]
            print(self.critic[0])
            self.adversarial_discriminate_moments = adversarial_regularization[
                "discriminate_moments"]
            self.real_data = adversarial_regularization["real_data"]
            self.adversarial_remove_last = adversarial_regularization["adversarial_remove_last"]
            self.critic_loss = nn.CrossEntropyLoss(reduction='none')
            self.critic_optimizer = torch.optim.Adam(
                self._get_critic().parameters(), lr=adversarial_regularization["critic_lr"])
            dev = adversarial_regularization["device"]
            self._get_critic().to(dev)
            self.critic_loss.to(dev)

            # round-robin schedule state (see _adversarial_training)
            self.adversarial_reconstruction_weight = adversarial_regularization[
                "adversarial_reconstruction_weight"]
            self.joint = adversarial_regularization["joint"]
            self.compute_task_loss = False
            self.train_critic = True  # start with training the critic first
            self.compute_critic_loss = False
            self.round_robin_critic = adversarial_regularization["critic_rounds"]
            self.round_robin_task = adversarial_regularization["task_rounds"]
            self.round_robin_adversarial = adversarial_regularization["adversarial_rounds"]
            self.joint_reconstruction_adversarial = adversarial_regularization[
                "joint_reconstruction_adversarial"]

    def _get_critic(self):
        """Return the critic module (kept outside the registered submodules)."""
        return self.critic[0]

    def change_mode(self, new_mode):
        """
        Switch the training mode.

        Only ``self.mode`` changes; the ``requires_grad`` flags set in
        ``__init__`` from the initial mode are left untouched.

        Raises:
            ValueError: if ``new_mode`` is not one of the ``MODE_*`` values.
        """
        if new_mode not in [MODE_EMB2EMB, MODE_SEQ2SEQ, MODE_FINETUNEDECODER, MODE_SEQ2SEQFREEZE]:
            raise ValueError("Invalid mode.")
        self.mode = new_mode

    def _decode(self, output_embeddings, target_batch=None, Y_embeddings=None):
        """
        Decode mapped embeddings.

        Args:
            output_embeddings: ``(embeddings, lengths)`` from the mapping.
            target_batch: target texts, used by the decoder in the seq2seq
                modes during training.
            Y_embeddings: reference embeddings forwarded to FGIM when
                ``fgim_customloss`` is set.

        Returns:
            The decoder output in ``MODE_EMB2EMB`` or at inference; otherwise
            ``(outputs, targets)`` flattened to ``(N, vocab_size)`` and
            ``(N,)``.
        """
        if self.gated_vectors:
            # The last dimension holds log_alpha; recompute the gate value from
            # it and re-insert it right before log_alpha, i.e. undo the gate
            # removal performed in _encode.
            assert type(output_embeddings) is tuple

            o_embs = output_embeddings[0]
            log_alpha = o_embs[..., -1].contiguous()
            g = compute_g_from_logalpha(log_alpha, 0.1)
            o_embs = torch.cat(
                [o_embs[..., :-1].contiguous(),
                 g.unsqueeze(-1),
                 o_embs[..., -1].contiguous().unsqueeze(-1)], dim=-1).contiguous()
            output_embeddings = (o_embs, output_embeddings[1])

        if self.mode == MODE_EMB2EMB or not self.training:

            if self.fast_gradient_iterative_modification:
                # Fast Gradient Iterative Modification.
                # NOTE: 10e0 == 10.0, so these weights are 10, 100, 1000, 10000.
                weights = [10e0, 10e1, 10e2, 10e3]
                start_time = time.time()
                output_embeddings = fast_gradient_iterative_modification(output_embeddings,
                                                                         self.binary_classifier,
                                                                         decay_factor=self.fgim_decay,
                                                                         t=self.fgim_threshold,
                                                                         weights=weights,
                                                                         custom_loss=self.compute_loss if self.fgim_customloss else None,
                                                                         Y_embeddings=Y_embeddings if self.fgim_customloss else None,
                                                                         start_at_y=self.fgim_start_at_y)
                self.total_time_fgim = self.total_time_fgim + \
                    (time.time() - start_time)

            return self.decoder(output_embeddings)

        if self.mode in [MODE_SEQ2SEQ, MODE_FINETUNEDECODER, MODE_SEQ2SEQFREEZE]:
            outputs, targets = self.decoder(
                output_embeddings, target_batch=target_batch)
            vocab_size = outputs.size(-1)
            outputs = outputs.view(-1, vocab_size)
            targets = targets.view(-1)
            return outputs, targets

        raise ValueError(
            "Undefined behavior for encoding in mode " + self.mode)

    def _encode(self, S_batch):
        """
        Encode a batch of texts into ``(embeddings, lengths)``.

        Applies the ``gated_vectors`` and ``use_end_of_sequence_vector``
        post-processing when they are enabled.

        Raises:
            ValueError: if ``self.mode`` is not a known mode.
        """
        if self.mode not in [MODE_EMB2EMB, MODE_FINETUNEDECODER, MODE_SEQ2SEQ, MODE_SEQ2SEQFREEZE]:
            raise ValueError(
                "Undefined behavior for encoding in mode " + self.mode)

        if self.mode in [MODE_EMB2EMB, MODE_FINETUNEDECODER, MODE_SEQ2SEQFREEZE]:
            # the encoder is frozen in these modes, so keep it in eval mode
            self.encoder.eval()
        embeddings = self.encoder(S_batch)

        if self.gated_vectors:
            # drop the gate value (second-to-last dim) and keep only log_alpha
            assert type(embeddings) is tuple
            in_embs = embeddings[0]
            vec = in_embs[..., :-2].contiguous()
            log_alpha = in_embs[..., -1].contiguous()
            in_embs = torch.cat(
                [vec, log_alpha.unsqueeze(-1)], dim=-1).contiguous()
            embeddings = (in_embs, embeddings[1])

        if self.bovmapping and self.use_end_of_sequence_vector:
            # grow every sequence by one slot and write the end-of-sequence
            # vector right after its last valid position
            s_emb, s_len = embeddings
            padded = torch.zeros((s_emb.size(0), s_emb.size(1) + 1,
                                  s_emb.size(2)), device=s_emb.device)
            padded[:, :-1, :] = s_emb
            eos = self.end_of_sequence_vector.to(padded.device)
            for i in range(s_emb.size(0)):
                padded[i, s_len[i], :] = eos
            s_len = (s_len + 1).to(s_emb.device)
            embeddings = (padded, s_len)
        return embeddings

    def _critic_logits(self, embeddings, lengths):
        """
        Run the critic.

        Returns ``(logits, valid_lengths)``: the bag-level critic produces one
        prediction per example, so its valid length is reset to 1; otherwise
        the given lengths are returned unchanged.
        """
        if self.bovmapping and not (self.tokenwise or self.adversarial_discriminate_moments):
            logits = self._get_critic()(embeddings, lengths)
            lengths = torch.ones_like(lengths, device=embeddings.device)
        else:
            logits = self._get_critic()(embeddings)
        return logits, lengths

    def _masked_critic_loss(self, embeddings, lengths, label):
        """Mean critic cross-entropy against ``label``, masking padded positions."""
        lengths = lengths.to(embeddings.device)
        critic_logits, lengths = self._critic_logits(embeddings, lengths)
        bsize = embeddings.shape[0]
        max_len = critic_logits.shape[1]
        labels = torch.full((bsize * max_len,), label,
                            device=embeddings.device, dtype=torch.long)
        loss = self.critic_loss(critic_logits.view(-1, 2), labels)
        loss = loss.view(bsize, max_len)
        mask = self._make_mask(bsize, max_len, lengths)
        return (loss * mask.float()).mean()

    def _train_critic(self, real_embeddings,
                      real_length,
                      generated_embeddings,
                      generated_length):
        """
        Take one optimizer step on the critic.

        Real embeddings are labelled 1 and generated ones 0. Inputs are
        detached so the update does not touch the main computation graph.

        Returns:
            The critic loss computed before the update.
        """
        self._get_critic().train()

        real_embeddings = real_embeddings.detach().clone()
        generated_embeddings = generated_embeddings.detach().clone()

        loss = self._masked_critic_loss(real_embeddings, real_length, 1)
        loss = loss + self._masked_critic_loss(
            generated_embeddings, generated_length, 0)

        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        return loss

    def _make_mask(self, bsize, max_lens, lens):
        """Return a bool ``(bsize, max_lens)`` mask with ``[i, j] = j < lens[i]``, on ``lens.device``."""
        mask = torch.arange(max_lens, device=lens.device)
        mask = mask.unsqueeze(0).expand(bsize, -1)
        mask = mask < lens.unsqueeze(1)
        return mask

    def _test_critic(self, embeddings, out_lens):
        """
        Critic loss of generated ``embeddings`` labelled as fake (0).

        Embeddings are not detached, so gradients flow back into the mapping.
        """
        self._get_critic().eval()
        return self._masked_critic_loss(embeddings, out_lens, 0)

    def _adversarial_training(self, loss, output_embeddings, out_lens, Y_embeddings, Y_lens):
        """
        Apply one step of the round-robin adversarial schedule.

        Phases cycle critic -> adversarial -> task; a phase ends when the
        global iteration counter is a multiple of that phase's round count.
        ``joint`` enables all parts at every step.

        Returns:
            ``(loss, task_loss, critic_loss, train_critic_loss)`` where
            ``loss = task_loss - adversarial_reconstruction_weight * critic_loss``.
        """
        if self.bovmapping:
            with torch.no_grad():
                # no reference targets here (Y_embeddings may be unrelated real
                # data), so only target-independent length detection applies
                out_lens = self._find_end_of_sequence(
                    output_embeddings, out_lens)

        if self.bovmapping and self.adversarial_remove_last:
            out_lens = torch.maximum(torch.ones_like(
                out_lens, device=out_lens.device), out_lens - 1)
            Y_lens = torch.maximum(torch.ones_like(
                Y_lens, device=Y_lens.device), Y_lens - 1)

        if self.bovmapping and self.adversarial_discriminate_moments:

            def compute_means(embs, lens):
                # mean over the valid positions -> (batch, 1, dim), length 1
                lens = lens.to(embs.device)
                mask = self._make_mask(embs.size(0), embs.size(1), lens)
                summed = (embs * mask.unsqueeze(2)).sum(dim=1, keepdim=True)
                # view as (batch, 1, 1): dividing by (batch, 1) would
                # broadcast to (batch, batch, dim)
                means = summed / lens.view(-1, 1, 1)
                return means, torch.ones_like(lens, device=lens.device)

            output_embeddings, out_lens = compute_means(
                output_embeddings, out_lens)
            Y_embeddings, Y_lens = compute_means(Y_embeddings, Y_lens)

        if self.train_critic or self.joint:
            train_critic_loss = self._train_critic(
                Y_embeddings, Y_lens, output_embeddings, out_lens)
        else:
            train_critic_loss = torch.tensor(0.)

        if not self.compute_task_loss and not self.joint and not (self.joint_reconstruction_adversarial and self.compute_critic_loss):
            loss = loss * 0.

        task_loss = loss.clone()
        if self.compute_critic_loss or self.joint or (self.compute_task_loss and self.joint_reconstruction_adversarial):
            critic_loss = self._test_critic(output_embeddings, out_lens)
        else:
            critic_loss = torch.tensor(0.)

        # we want to fool the critic, i.e., we want its loss to be high =>
        # subtract adversarial loss
        loss = loss - self.adversarial_reconstruction_weight * critic_loss

        self.iterations += 1
        if self.train_critic and (self.iterations % self.round_robin_critic) == 0:
            self.compute_task_loss = False
            self.train_critic = False
            self.compute_critic_loss = True
        elif self.compute_critic_loss and (self.iterations % self.round_robin_adversarial) == 0:
            self.compute_task_loss = True
            self.train_critic = False
            self.compute_critic_loss = False
        elif self.compute_task_loss and (self.iterations % self.round_robin_task) == 0:
            self.compute_task_loss = False
            self.train_critic = True
            self.compute_critic_loss = False

        return loss, task_loss, critic_loss, train_critic_loss

    def compute_emb2emb(self, Sx_batch, Y=None):
        """
        Encode ``Sx_batch`` and run the mapping.

        Returns:
            ``(output_embeddings, X_embeddings, out_lens)``.
        """
        X_embeddings, X_lens = self._encode(Sx_batch)

        # optional input noise for regularization (training only)
        if self.training and self.gaussian_noise_std > 0.:
            X_embeddings = X_embeddings + \
                torch.randn_like(X_embeddings) * self.gaussian_noise_std

        # the mapping is timed only at inference
        if not self.training:
            s_time = time.time()

        if self.bovmapping:
            output_embeddings, out_lens = self.emb2emb(
                X_embeddings, X_lens, Y=Y, teacher_forcing=self.teacher_forcing)
        else:
            output_embeddings = self.emb2emb(X_embeddings)
            out_lens = X_lens

        if not self.training:
            self.total_emb2emb_time = self.total_emb2emb_time + \
                (time.time() - s_time)

        return output_embeddings, X_embeddings, out_lens

    def compute_loss(self, output_embeddings, out_lens, Y_embeddings):
        """
        Embedding-space loss between mapped and target embeddings.

        Args:
            output_embeddings: mapped embeddings.
            out_lens: per-example lengths of ``output_embeddings``.
            Y_embeddings: ``(target_embeddings, target_lengths)``.

        Returns:
            ``(loss, out_lens)``. With adversarial regularization, ``loss`` is
            the 4-tuple returned by ``_adversarial_training``.
        """
        Y_embeddings, Y_len = Y_embeddings
        out_lens, Y_len = out_lens.to(
            output_embeddings.device), Y_len.to(output_embeddings.device)
        # targets padded/truncated to the output's sequence length
        target = torch.zeros_like(output_embeddings)
        min_emb_len = min(output_embeddings.size(1), Y_embeddings.size(1))
        target[:, :min_emb_len] = Y_embeddings[:, :min_emb_len]

        if isinstance(self.loss_fn, FlipLoss):
            # zero every target vector at or beyond the target's length
            m = self._make_mask(target.size(
                0), target.size(1), Y_len).unsqueeze(-1).expand(-1, -1, target.size(-1))
            target = target * m
            out_lens_after_seq_end = self._find_end_of_sequence(
                output_embeddings, out_lens, (Y_embeddings, Y_len))
            loss = self.loss_fn(
                (output_embeddings, out_lens_after_seq_end), target)
        elif _is_localbag_loss(self.loss_fn):
            loss = self.loss_fn(output_embeddings,
                                Y_embeddings, out_lens, Y_len)
        else:
            loss = self.loss_fn(output_embeddings, target)

        # per-position losses: average over the overlap of output and target
        if len(loss.squeeze().size()) > 1:
            min_len = torch.min(out_lens, Y_len)
            mask = torch.arange(output_embeddings.size(
                1), device=output_embeddings.device)
            mask = mask.unsqueeze(0).expand(output_embeddings.size(0), -1)
            mask = mask < min_len.unsqueeze(1)
            loss = loss.mean(-1) * mask
            loss = loss.sum(dim=1) / min_len
        else:
            mask = None

        loss = loss.mean()

        if self.bovmapping and self.predict_done:
            # is_done_clf outputs logits; position Y_len - 1 is labelled "done"
            is_done_prob = self.is_done_clf(
                output_embeddings).view(-1, 2)  # [batch_size * len, 2]
            is_done = torch.arange(output_embeddings.size(
                1), device=output_embeddings.device)
            is_done = is_done.unsqueeze(0).expand(
                output_embeddings.size(0), -1)
            is_done = (is_done == (Y_len - 1).unsqueeze(1)
                       ).view(-1).long()
            is_done_loss = self.is_done_loss(
                is_done_prob, is_done).view(-1, output_embeddings.size(1))

            if mask is not None:
                is_done_loss = is_done_loss * mask
                is_done_loss = is_done_loss.sum(dim=1) / min_len
                is_done_loss = is_done_loss.mean()
            else:
                # the else case here happens when we use localbagloss in the
                # supervised case
                is_done_loss = is_done_loss.mean()

            loss = loss + is_done_loss

        if self.adversarial_regularization:

            if self.real_data == "input":
                real_data = Y_embeddings
                real_len = Y_len
            else:
                # sample "real" examples from the configured pool of texts
                real_data, real_len = self._encode(
                    choices(self.real_data, k=Y_embeddings.size(0)))

            loss = self._adversarial_training(
                loss, output_embeddings, out_lens, real_data, real_len)
        return loss, out_lens

    def forward(self, Sx_batch, Sy_batch):
        """
        Propagates through the emb2emb framework. Takes as input two lists of
        texts corresponding to the input and outputs. Returns the loss if in
        training mode (a 4-tuple ``(loss, task_loss, critic_loss,
        train_critic_loss)`` when adversarial regularization is enabled),
        otherwise returns the decoder's output.
        """
        # measure inference time it takes
        if not self.training:
            s_time = time.time()

        if self.unaligned and Sy_batch is None:
            Sy_batch = Sx_batch

        if (self.teacher_forcing > 0. and self.training) or self.unaligned or self.track_input_output_distance:
            Y = self._encode(Sy_batch)
        else:
            Y = None

        output_embeddings, X_embeddings, out_lens = self.compute_emb2emb(
            Sx_batch, Y=Y)

        if self.training:
            # compute loss depending on the mode
            if self.mode == MODE_EMB2EMB:
                if self.teacher_forcing:
                    Y_embeddings = Y
                else:
                    Y_embeddings = self._encode(Sy_batch)

                loss, out_lens = self.compute_loss(
                    output_embeddings, out_lens, Y_embeddings)
                if self.adversarial_regularization:
                    loss, task_loss, critic_loss, train_critic_loss = loss

            elif self.mode in [MODE_SEQ2SEQ, MODE_FINETUNEDECODER, MODE_SEQ2SEQFREEZE]:
                # for training with CE
                outputs, targets = self._decode(
                    (output_embeddings, out_lens), target_batch=Sy_batch)
                loss = self.loss_fn(outputs, targets)
                if self.adversarial_regularization:
                    # no adversarial step in these modes; zero placeholders
                    # (as in _adversarial_training) keep the return arity
                    task_loss = loss.clone()
                    critic_loss = torch.tensor(0.)
                    train_critic_loss = torch.tensor(0.)

            if self.adversarial_regularization:
                return loss, task_loss, critic_loss, train_critic_loss
            return loss

        # inference: decide where each output sequence ends, then decode
        if self.bovmapping and not isinstance(self.emb2emb, (BovIdentity, BovOracle)):
            out_lens_at_end = self._find_end_of_sequence(
                output_embeddings, out_lens, Y)
        else:
            # identity/oracle and non-BoV mappings keep the mapping's lengths
            out_lens_at_end = out_lens

        out_lens_at_end = out_lens_at_end.to(output_embeddings.device)
        self.test_out_lens_sum += out_lens_at_end.sum().item()
        out = self._decode((output_embeddings, out_lens_at_end),
                           Y_embeddings=X_embeddings)
        self.total_inference_time = self.total_inference_time + \
            (time.time() - s_time)

        if self.track_input_output_distance:
            Y_embeddings = self._encode(Sy_batch)
            if _is_localbag_loss(self.loss_fn):

                Y_len = Y_embeddings[1].to(Y_embeddings[0].device)
                input_output_distance = self.loss_fn(
                    output_embeddings,
                    Y_embeddings[0], out_lens_at_end, Y_len)

                losses = self.loss_fn(
                    output_embeddings, Y_embeddings[0], out_lens, Y_len, reduce=False)
                min_val, _ = losses.min(dim=1)

                self.in_output_distance_mean = self.in_output_distance_mean + \
                    input_output_distance.sum(0).item()
                self.in_output_distance_min = self.in_output_distance_min + \
                    min_val.sum(0).item()
                self.in_output_distance_batchsize = self.in_output_distance_batchsize + \
                    min_val.size(0)

        return out

    def _find_end_of_sequence(self, output_embeddings, out_lens=None, Y=None):
        """
        Determine the effective length of each generated sequence.

        Strategies, in order of precedence:
        ``predict_done`` (first position classified as done),
        ``use_end_of_sequence_vector`` (first position within
        ``end_of_sequence_epsilon`` of the end-of-sequence vector),
        ``select_input_length`` (lengths of ``Y``) and local-bag losses (the
        prefix length closest to ``Y``). The last two need ``Y``; when ``Y``
        is None they are skipped.

        Args:
            output_embeddings: generated embeddings.
            out_lens: lengths reported by the mapping; returned when no
                strategy applies (defaults to the full sequence length).
            Y: optional ``(target_embeddings, target_lengths)``.

        Returns:
            Per-example lengths (1-based position of the last kept vector).
        """
        if self.predict_done:
            # is_done_clf outputs logits (trained with CrossEntropyLoss), so
            # normalize before thresholding at probability 0.5
            is_done_logits = self.is_done_clf(output_embeddings)
            is_done_prob = torch.softmax(is_done_logits, dim=-1)
            is_done = (is_done_prob[:, :, 1] > 0.5).float()
            return first_nonzero_index(is_done)
        if self.use_end_of_sequence_vector:
            dist = (output_embeddings -
                    self.end_of_sequence_vector.to(output_embeddings.device)).norm(dim=-1)
            is_done = (dist < self.end_of_sequence_epsilon).float()
            return first_nonzero_index(is_done)
        if Y is not None and self.select_input_length:
            return Y[1].to(output_embeddings.device)
        if Y is not None and _is_localbag_loss(self.loss_fn):
            # we look for the prefix length where the distance is minimal
            Y_len = Y[1].to(output_embeddings.device)
            hausdorf_similarities = self.loss_fn(
                output_embeddings, Y[0], out_lens, Y_len, reduce=False)
            _, best_index = hausdorf_similarities.min(dim=1)
            return best_index + 1  # minimum index -> length
        if out_lens is None:
            out_lens = torch.full((output_embeddings.size(0),),
                                  output_embeddings.size(1),
                                  dtype=torch.long,
                                  device=output_embeddings.device)
        return out_lens

    def train(self, mode=True):
        r"""Set training mode on this module and its registered children.

        Overrides ``nn.Module.train`` so that, in ``MODE_EMB2EMB`` with
        ``force_eval_mode``, the encoder and decoder stay in eval mode. The
        critic lives in a plain list, is not a registered child, and is not
        affected here. Returns ``self``, like ``nn.Module.train`` (which
        ``nn.Module.eval`` relies on).
        """
        self.training = mode
        for module in self.children():
            module.train(mode)

        if self.mode == MODE_EMB2EMB and self.force_eval_mode:
            self.encoder.eval()
            self.decoder.eval()
        return self


def _is_localbag_loss(loss_fn):
    """Tell whether ``loss_fn`` is a local-bag style loss.

    Local-bag losses (``HausdorffLoss`` and ``LocalBagLoss``) take explicit
    output/target lengths and support ``reduce=False``, so callers dispatch
    on this predicate.
    """
    return isinstance(loss_fn, (HausdorffLoss, LocalBagLoss))


def first_nonzero_index(x):
    """
    Return, for each row of ``x``, the 1-based position of its first positive
    running total, i.e. the sequence length up to and including the first
    nonzero entry.

    Rows that never become positive get ``x.size(1)`` (the full row length):
    the result is clamped so it never points past the end of the row.

    Args:
        x: 2-D tensor, one row per sequence. Callers in this module pass
            0./1. float indicators.

    Returns:
        1-D ``torch.long`` tensor with one length per row.
    """
    running_total = torch.cumsum(x, dim=1)
    leading_positions = (running_total <= 0).sum(dim=1).long()
    # +1 turns the 0-based index of the first hit into a length; rows with
    # no hit would otherwise get size(1) + 1.
    return (leading_positions + 1).clamp(max=x.size(1))
