import torch
from torch import nn
from torch.nn import CosineEmbeddingLoss
from torch.nn.modules.loss import BCELoss

import math
from scipy import special

from emb2emb.hausdorff import _local_hausdorff_similarities, local_bag_losses


class CosineLoss(CosineEmbeddingLoss):
    """Cosine embedding loss in which every (predicted, true) pair is a positive pair.

    Keyword arguments are forwarded unchanged to ``CosineEmbeddingLoss``.
    The final reshape in ``forward`` expects one loss value per vector, i.e.
    the module is meant to be built with ``reduction='none'``.
    """

    def __init__(self, **args):
        super().__init__(**args)

    def forward(self, predicted, true):
        """Return the cosine loss between matching vectors.

        Args:
            predicted: tensor whose last dimension is the embedding dimension.
            true: tensor holding the same number of elements as ``predicted``.

        Returns:
            Tensor of shape ``[predicted.size(0), -1, 1]``.
        """
        batch_size = predicted.size(0)
        emb_dim = predicted.size(-1)
        leading_shape = predicted.size()[:-1]

        # A target of +1 for every pair asks the parent loss to pull the two
        # vectors together.
        positive = torch.ones(leading_shape).to(predicted.device).view(-1)

        flat_loss = super().forward(
            predicted.view(-1, emb_dim),
            true.view(-1, emb_dim),
            positive,
        )
        return flat_loss.view(batch_size, -1, 1)


def logcmkappox(d, z):  # approximation of LogC(m, k)
    """Closed-form approximation used in place of the exact ``Logcmk``.

    Args:
        d: dimensionality (a plain number).
        z: tensor of concentration values.

    Returns:
        Tensor with the same shape as ``z``.
    """
    order = d / 2 - 1
    # This square root appears twice in the formula; evaluate it once.
    root = torch.sqrt((order + 1) * (order + 1) + z * z)
    return root - (order - 1) * torch.log(order - 1 + root)


class Logcmk(torch.autograd.Function):
    """
    Log normalisation term ``log C_m(k)`` with a hand-written gradient.

    With ``m = 300`` fixed, the forward pass evaluates

        (m/2 - 1) * log(k) - log(I_{m/2-1}(k)) - (m/2) * log(2*pi)

    where ``I_v`` is the modified Bessel function of the first kind. It is
    computed through the exponentially scaled ``scipy.special.ive``
    (``ive(v, k) = I_v(k) * exp(-k)``), which is why ``- k`` appears in the
    code. The backward pass uses ``d/dk log C_m(k) = -I_{m/2}(k) / I_{m/2-1}(k)``.

    SciPy only works on NumPy data, so the input is explicitly detached and
    copied to the CPU before the Bessel call and the result is moved back to
    the input's device. Relying on implicit conversion instead fails for CUDA
    tensors and for double-precision inputs that require grad.
    """

    @staticmethod
    def _scaled_bessel(order, k):
        """Return ``scipy.special.ive(order, k)`` as a float64 tensor on ``k``'s device."""
        values = special.ive(order, k.detach().double().cpu().numpy())
        return torch.as_tensor(values, dtype=torch.double, device=k.device)

    @staticmethod
    def forward(ctx, k):
        """
        Compute ``log C_m(k)`` element-wise.

        Args:
            ctx: autograd context; ``k`` is saved on it for the backward pass.
            k: tensor of concentration values.

        Returns:
            float32 tensor with the same shape as ``k``.
        """
        m = 300
        ctx.save_for_backward(k)
        order = m / 2 - 1
        # Evaluate in float64; the Bessel values are extremely small.
        k64 = k.detach().double()
        answer = (order * torch.log(k64)
                  - torch.log(Logcmk._scaled_bessel(order, k64))
                  - k64
                  - (m / 2) * math.log(2 * math.pi))
        return answer.float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Chain ``grad_output`` with the analytic derivative of ``log C_m(k)``.

        The exponential scaling of ``ive`` cancels in the ratio, so the scaled
        values can be divided directly.
        """
        k, = ctx.saved_tensors
        m = 300
        ratio = (Logcmk._scaled_bessel(m / 2, k)
                 / Logcmk._scaled_bessel(m / 2 - 1, k))
        return grad_output * (-ratio).float()


class NLLvMF(nn.Module):
    """Negative log-likelihood style loss built on the approximate ``log C_m(k)``.

    The norm of each predicted vector is used as the concentration ``k`` and
    the loss per vector is ``-logcmkappox(k) - <predicted, true / ||true||>``.

    ``lambda1``, ``lambda2`` and ``padding_index`` are stored on the module but
    are not used by ``forward``.
    """

    def __init__(self, lambda1=0.02, lambda2=0.1, padding_index=0):
        super(NLLvMF, self).__init__()
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.padding_index = padding_index

    @staticmethod
    def logcmkappox(k, m=300):
        """Closed-form approximation of ``log C_m(k)``.

        Args:
            k: tensor of concentration values.
            m: dimensionality; defaults to 300, the value previously
                hard-coded here.
        """
        v = m / 2 - 1
        root = torch.sqrt((v + 1) * (v + 1) + k * k)
        return root - (v - 1) * torch.log(v - 1 + root)

    def forward(self, predicted, true):
        """Return the loss averaged over all vectors.

        Args:
            predicted: tensor whose last dimension is the embedding dimension.
            true: tensor with the same shape as ``predicted``.

        Returns:
            0-dim tensor: mean loss over the flattened vectors.
        """
        dim = predicted.size(-1)
        pred = predicted.reshape(-1, dim)
        target = true.reshape(-1, dim)

        # Ideally, we should use the real Logcmk after quite a few iters
        kappa = pred.norm(dim=-1)
        target_unit = target / target.norm(dim=-1, keepdim=True)
        alignment = (pred * target_unit).sum(dim=-1)

        per_vector = -self.logcmkappox(kappa, dim) - alignment
        return per_vector.mean()


class BagLoss(nn.Module):
    """Interface for losses that compare a predicted bag of vectors with a target bag."""

    def forward(self, X, Y, X_len=None, Y_len=None):
        """
        Subclasses must override this method.

        X : the predicted bag [batch_size, max_len, embedding-dim]
        Y : the target bag [batch_size, max_len, embedding-dim]
        X_len : the size of the predicted bag
        Y_len : the size of the target bag
        """
        raise NotImplementedError


class LocalBagLoss(BagLoss):
    """Bag loss evaluated per output position and reduced with position weights.

    ``bag_loss_f`` is called as ``bag_loss_f(X, Y, mask_X, mask_Y, None)`` and
    must return a pair whose first element holds one value per example. It is
    also handed to ``local_bag_losses`` for the general case.

    ``weighting``, ``windowsize`` and ``weighting_center`` are read by
    ``_get_weights``. ``input_center_factor`` is stored but not used in this
    class.
    """

    def __init__(self, bag_loss_f, detach=False, weighting="window", windowsize=3, weighting_center="input", input_center_factor=1.0):
        super(LocalBagLoss, self).__init__()
        self.detach = detach
        self.weighting = weighting
        self.windowsize = windowsize
        self.weighting_center = weighting_center
        self.bag_loss_f = bag_loss_f
        self.input_center_factor = input_center_factor

    def _get_index(self, X, Y, X_len, Y_len):
        """Return the positions ``0 .. max(X_len) - 1`` on ``X_len``'s device."""
        maxlen, _ = X_len.max(dim=0)
        index = torch.arange(maxlen, device=X_len.device)
        return index

    def forward(self, X, Y, X_len=None, Y_len=None, reduce=True):
        """
        X : the predicted bag [batch_size, max_len, embedding-dim]
        Y : the target bag [batch_size, max_len, embedding-dim]
        X_len : the size of the predicted bag
        Y_len : the size of the target bag
        reduce : if True return one weighted value per example, otherwise
                 return the per-position values [batch_size, max(X_len)]

        Both X_len and Y_len are required in practice although they default
        to None.
        """
        # True at the positions of Y that hold real (non-padding) vectors.
        mask_Y = torch.arange(Y.size(1), device=Y_len.device).unsqueeze(0)
        mask_Y = mask_Y.expand(Y_len.size(0), -1)
        mask_Y = mask_Y < Y_len.unsqueeze(1)

        if self.training and self.weighting == "window" and self.windowsize == 0:
            # Shortcut: with a zero-width window only one position per example
            # is weighted, so a single bag loss per example is enough. X is
            # masked to the first Y_len positions for that purpose.
            mask_X = torch.arange(X.size(1), device=X_len.device).unsqueeze(0)
            mask_X = mask_X.expand(X_len.size(0), -1)
            mask_X = mask_X < Y_len.unsqueeze(1)
            hausdorff_sims_local, _ = self.bag_loss_f(
                X, Y, mask_X, mask_Y, None)

            hausdorff_sims = torch.zeros(
                (X.size(0), X.size(1)), device=X.device)
            for i in range(X.size(0)):
                # Write example i's value into its own row only. Assigning to
                # the whole column would let examples of equal length
                # overwrite one another.
                hausdorff_sims[i, Y_len[i] - 1] = hausdorff_sims_local[i]
        else:
            hausdorff_sims = local_bag_losses(
                X, Y, mask_Y, self.bag_loss_f, self.detach, max_X_len=-1)

        # Keep only the positions up to the longest predicted bag.
        index = self._get_index(X, Y, X_len, Y_len)
        hausdorff_sims = hausdorff_sims[:, index]

        assert len(hausdorff_sims.size()) == 2

        if not reduce:
            return hausdorff_sims

        # Weighted average over the positions.
        weights = _get_weights(self, hausdorff_sims, Y_len, X_len)
        loss = (hausdorff_sims * weights).sum(-1)

        if self.weighting == "uniform":
            # Sanity check: uniform weights must reproduce the plain mean.
            assert torch.allclose(hausdorff_sims.mean(-1), loss, rtol=1e-04)

        # Unlike HausdorffLoss, the values are returned without a sign flip.
        return loss

# Per-position weights used by the bag losses to reduce local scores to one value per example.


def _get_weights(loss_f, hausdorff_sims, Y_len, X_len):
    num_vectors = float(hausdorff_sims.size(1))
    if loss_f.weighting == "uniform":
        weights = torch.ones_like(hausdorff_sims) / num_vectors
    elif loss_f.weighting == "window":
        if loss_f.weighting_center == "input":
            center = Y_len - 1
        elif loss_f.weighting_center == "optimum":
            _, center = hausdorff_sims.max(dim=-1)

        weights = torch.ones_like(hausdorff_sims)

        # 1...N
        rangetensor = torch.arange(hausdorff_sims.size(
            1), device=weights.device)
        rangetensor = rangetensor.expand((weights.size(0), -1))

        # mask out the entries that are outside of the window.
        center = center.unsqueeze(1)
        smaller = 1 - (rangetensor < (center - loss_f.windowsize)).float()
        greater = 1 - (rangetensor > (center + loss_f.windowsize)).float()
        weights = weights * smaller
        weights = weights * greater

        # adding some epsilon to avoid division by zero
        weight_sum = (weights.sum(1, keepdim=True) + 0.00001)
        weights = weights / weight_sum
    elif loss_f.weighting == "out_lens":
        weights = torch.ones_like(hausdorff_sims)
        rangetensor = torch.arange(hausdorff_sims.size(
            1), device=weights.device)
        rangetensor = rangetensor.expand((weights.size(0), -1))

        mask = rangetensor == (X_len - 1).unsqueeze(1)
        weights = weights * mask.float()
    elif loss_f.weighting == "uniform_till_input":

        center = Y_len - 1
        weights = torch.ones_like(hausdorff_sims)

        # 1...N
        rangetensor = torch.arange(hausdorff_sims.size(
            1), device=weights.device)
        rangetensor = rangetensor.expand((weights.size(0), -1))

        # mask out the entries that are outside of the window.
        center = center.unsqueeze(1)
        greater = 1 - (rangetensor > (center)).float()
        weights = weights * greater

        # adding some epsilon to avoid division by zero
        weight_sum = (weights.sum(1, keepdim=True) + 0.00001)
        weights = weights / weight_sum

    elif loss_f.weighting == "sumtoone":
        center = Y_len - 1
        weights = torch.ones_like(hausdorff_sims)

        # 1...N
        rangetensor = torch.arange(hausdorff_sims.size(
            1), device=weights.device)
        rangetensor = rangetensor.expand((weights.size(0), -1))

        center = center.unsqueeze(1)
        weights = weights / ((center + 1 - rangetensor) + 0.00001)

        # mask out the entries that are outside of the window.
        greater = 1 - (rangetensor > (center)).float()
        weights = weights * greater

        # adding some epsilon to avoid division by zero
        weight_sum = (weights.sum(1, keepdim=True) + 0.00001)
        weights = weights / weight_sum

    return weights


class HausdorffLoss(BagLoss):
    """Negated, position-weighted local Hausdorff similarity between two bags.

    The per-position similarities come from ``_local_hausdorff_similarities``.
    They are reduced with the weights from ``_get_weights`` and multiplied by
    -1 so that a higher similarity gives a lower loss.
    """

    def __init__(self, differentiable=True, softmax_temp=1.0, detach=True, weighting="window", windowsize=3, weighting_center="input", similarity_function="euclidean"):
        super().__init__()
        # options forwarded to _local_hausdorff_similarities
        self.differentiable = differentiable
        self.softmax_temp = softmax_temp
        self.detach = detach
        # options read by _get_weights
        self.weighting = weighting
        self.windowsize = windowsize
        self.weighting_center = weighting_center
        self.similarity_function = similarity_function

    def _get_index(self, X, Y, X_len, Y_len):
        """Return the positions ``0 .. max(X_len) - 1`` on ``X_len``'s device."""
        longest, _ = X_len.max(dim=0)
        return torch.arange(longest, device=X_len.device)

    def forward(self, X, Y, X_len=None, Y_len=None, reduce=True):
        """
        X : the predicted bag [batch_size, max_len, embedding-dim]
        Y : the target bag [batch_size, max_len, embedding-dim]
        X_len : the size of the predicted bag
        Y_len : the size of the target bag
        reduce : if True return one value per example, otherwise the
                 per-position values [batch_size, max(X_len)]
        """
        Y_len = Y_len.to(Y.device)

        # True wherever Y holds a real (non-padding) vector.
        positions = torch.arange(Y.size(1), device=Y_len.device).unsqueeze(0)
        mask_Y = positions.expand(Y_len.size(0), -1) < Y_len.unsqueeze(1)

        local_sims = _local_hausdorff_similarities(
            X, Y, mask_Y,
            similarity_function=self.similarity_function,
            naive=False,
            differentiable=self.differentiable,
            softmax_temp=self.softmax_temp,
            max_X_len=-1,
            detach=self.detach,
            naive_local=True)

        # Drop the positions beyond the longest predicted bag.
        local_sims = local_sims[:, self._get_index(X, Y, X_len, Y_len)]
        assert local_sims.dim() == 2

        if not reduce:
            # flip the sign in order to make the similarities a loss
            return local_sims * (-1.)

        # weighted average over the positions
        weights = _get_weights(self, local_sims, Y_len, X_len)
        pooled = (local_sims * weights).sum(-1)

        if self.weighting == "uniform":
            # uniform weights must reproduce the plain mean
            assert torch.allclose(local_sims.mean(-1), pooled)

        # flip the sign in order to make the similarities a loss
        return pooled * (-1.)


class FlipLoss(nn.Module):
    """Mix a base loss with a frozen classifier's loss towards label 0.

    The result is ``l * clf_loss + (1 - l) * baseloss``. The mixing weight
    ``l`` grows linearly from 0 to ``lambda_clfloss`` over the first
    ``increase_until`` training-mode calls.
    """

    def __init__(self, baseloss, classifier, lambda_clfloss=0.5,
                 increase_until=10000,
                 *args):
        super().__init__(*args)
        self.baseloss = baseloss
        # assumed to return logit for binary classification (sigmoid)
        self.classifier = classifier
        self.sigmoid = nn.Sigmoid()
        self.bce = BCELoss()
        self.lambda_clfloss = lambda_clfloss
        self.increase_until = increase_until

        # number of training-mode calls seen so far; drives the warm-up
        self.i = 0

        # the classifier is only used for scoring: freeze it and switch it
        # to evaluation mode
        for param in self.classifier.parameters():
            param.requires_grad = False
        self.classifier.eval()

    def _get_lambda(self):
        """Return the current mixing weight and, in training mode, advance the counter."""
        warmed_up = self.i >= self.increase_until
        if warmed_up:
            weight = self.lambda_clfloss
        else:
            weight = (float(self.i) / self.increase_until) * self.lambda_clfloss

        if self.training:
            self.i += 1

        return weight

    def forward(self, predicted, true):
        """Return the mixed loss.

        Args:
            predicted: either a tensor, or a tuple ``(embeddings, X_len)``.
                In the tuple case the embeddings are zeroed at positions
                ``>= X_len`` and the classifier is called with both values.
            true: target passed to the base loss.
        """
        X_len = None
        is_bag = isinstance(predicted, tuple)

        if is_bag:
            predicted, X_len = predicted[0], predicted[1]

            # make sure the predicted embeddings are zero at invalid positions
            positions = torch.arange(predicted.size(1), device=predicted.device)
            valid = positions.unsqueeze(0).unsqueeze(2).expand(
                predicted.size(0), -1, predicted.size(2))
            valid = valid < X_len.unsqueeze(1).unsqueeze(2)
            predicted = predicted * valid

        base = self.baseloss(predicted, true)

        if is_bag:
            logits = self.classifier(predicted, X_len)[..., 1]
        else:
            logits = self.classifier(predicted)
        probabilities = self.sigmoid(logits)

        # we are assuming that the "fake" example was trained to be
        # label '0'
        wanted = torch.zeros_like(probabilities, device=probabilities.device)
        classifier_loss = self.bce(probabilities, wanted)

        weight = self._get_lambda()
        return weight * classifier_loss + (1 - weight) * base


def batch_pairwise_similarity(x, similarity_f):
    """Apply ``similarity_f`` to every ordered pair of rows of ``x``.

    Args:
        x: 2-D tensor ``[n, dim]``.
        similarity_f: callable receiving two ``[n, n, dim]`` tensors. At index
            ``[i, j]`` the first holds ``x[j]`` and the second holds ``x[i]``.

    Returns:
        Whatever ``similarity_f`` returns.
    """
    n = x.size(0)
    varies_along_columns = x[None].expand(n, -1, -1)
    varies_along_rows = x[:, None].expand(-1, n, -1)
    return similarity_f(varies_along_columns, varies_along_rows)


class MeanSimilarityLoss(nn.Module):
    """Cosine loss between the mean vectors of the predicted and the target bag.

    A position counts towards a bag's length when its vector does not sum to
    zero. Each bag is summed over positions and divided by that length. A bag
    without any such position therefore divides by zero.
    """

    def __init__(self):
        super(MeanSimilarityLoss, self).__init__()
        self.cos_loss = CosineLoss(reduction='none')

    def forward(self, predicted, true):
        """Return the loss as a ``[batch_size, 1]`` tensor.

        Args:
            predicted: tensor ``[batch_size, len, embedding_dim]``.
            true: tensor ``[batch_size, len, embedding_dim]``.
        """

        def getlen(x):
            # number of positions whose vector does not sum to zero
            return (x.sum(2) != 0.).sum(1)

        p_len = getlen(predicted)
        t_len = getlen(true)
        predicted = predicted.sum(1) / p_len.unsqueeze(1)
        true = true.sum(1) / t_len.unsqueeze(1)

        # Flatten explicitly rather than squeeze(): squeeze() would also drop
        # the batch dimension when batch_size == 1 and trip the check below.
        loss = self.cos_loss(predicted, true).reshape(-1)
        assert len(loss.size()) == 1 and loss.size(0) == predicted.size(0)
        return loss.unsqueeze(1)


class AlignmentLoss(nn.Module):
    """Negated cosine alignment between predicted and target vector sequences.

    A full pairwise cosine-similarity matrix is built for each example. Each
    target vector is then matched against all predicted vectors (direction
    ``"input"``) and/or each predicted vector against all target vectors
    (direction ``"output"``). The match is a hard max when ``differentiable``
    is False, or a softmax-weighted sum of the similarities when it is True.

    Vectors that sum to zero are treated as padding. Pairs whose cosine
    similarity is exactly 0 are pushed down to -1e10 before matching.
    """

    def __init__(self, differentiable=True, directions=None, force_target_length=False):
        super(AlignmentLoss, self).__init__()
        self.cos_sim = nn.CosineSimilarity(dim=-1)
        self.differentiable = differentiable
        self.softmax1 = nn.Softmax(dim=-1)
        self.softmax2 = nn.Softmax(dim=-2)
        # Build a fresh list for every instance; a list literal used as the
        # default argument would be shared by all instances.
        self.directions = ["input", "output"] if directions is None else directions
        self.force_target_length = force_target_length

    def similarity_by_axis(self, all_similarities, axis):
        """Reduce the similarity matrix along ``axis`` and negate the result.

        ``axis`` is expected to be -1 or -2; any value other than -1 uses the
        softmax over dim -2 in the differentiable case.
        """
        if self.differentiable:
            if axis == -1:
                weights = self.softmax1(all_similarities)
            else:
                weights = self.softmax2(all_similarities)
            similarities = (weights * all_similarities).sum(axis)
        else:
            similarities, _ = all_similarities.max(dim=axis)

        similarities = similarities * -1  # we want to minimize the loss
        return similarities

    def forward(self, predicted, true):
        """Return the alignment loss as a ``[batch_size, 1]`` tensor.

        Args:
            predicted: tensor ``[batch_size, len, embedding_dim]``.
            true: tensor ``[batch_size, len, embedding_dim]`` (same shape).
        """
        assert predicted.size() == true.size()

        maxlen = predicted.size(1)
        embsize = predicted.size(2)

        # positions whose target vector sums to zero count as padding
        true_zeros = (true.sum(-1) == 0.).detach()
        true_len = (maxlen - (true.sum(-1) == 0.).sum(1)).detach()

        # force predicted length to be the same as true length
        if self.force_target_length:
            predicted = (1 - true_zeros.float()).unsqueeze(-1) * predicted
            pred_len = true_len
            pred_zeros = true_zeros
        else:
            pred_zeros = (predicted.sum(-1) == 0.).detach()
            pred_len = (maxlen - (predicted.sum(-1) == 0.).sum(1)).detach()

        # predicted_a[b, i, j] = predicted[b, j]; true_a[b, i, j] = true[b, i]
        predicted_a = predicted.unsqueeze(
            1).expand(-1, predicted.size(1), -1, -1)
        true_a = true.unsqueeze(2).expand(-1, -1, true.size(1), -1)

        # bring it in right format for cos sim
        predicted_a = predicted_a.reshape(-1, embsize)
        true_a = true_a.reshape(-1, embsize)

        # all_similarities: rows = true vectors, columns = predicted vectors
        all_similarities = self.cos_sim(predicted_a, true_a)
        all_similarities = all_similarities.view(-1, maxlen, maxlen)

        # Pairs with similarity exactly 0 (which includes every pair that
        # involves an all-zero vector) are pushed down to -1e10 so that they
        # lose the max / get ~zero softmax weight; the result is floored at
        # -1e10.
        mask = all_similarities == 0.
        all_similarities = all_similarities - mask * 1e10
        all_similarities = torch.max(torch.full_like(
            all_similarities, -1e10), all_similarities)

        # for each row/input, find the most similar column/predicted output
        if "input" in self.directions:
            similarities_true = self.similarity_by_axis(all_similarities, -1)
            similarities_input = (
                (similarities_true * (1 - true_zeros.float())).sum(1)) / true_len
        else:
            similarities_input = 0.

        if "output" in self.directions:
            # for each column/predicted output, find the most similar row/input
            similarities_output = self.similarity_by_axis(
                all_similarities, -2)
            similarities_output = (similarities_output *
                                   (1 - pred_zeros.float())).sum(1) / pred_len
        else:
            similarities_output = 0.

        similarities = similarities_input + similarities_output

        # output = [batch_size, 1]
        similarities = similarities.unsqueeze(-1)
        return similarities


class BacktranslationLoss(nn.Module):
    """Base loss evaluated on the prediction after passing it through ``emb2emb``."""

    def __init__(self, baseloss, emb2emb,
                 *args):
        super().__init__(*args)
        self.baseloss = baseloss
        self.emb2emb = emb2emb

    def forward(self, predicted, true):
        """Map ``predicted`` through ``emb2emb`` and score the result against ``true``."""
        mapped_back = self.emb2emb(predicted)
        return self.baseloss(mapped_back, true)


class CombinedBaseLoss(nn.Module):
    """
    Multiply the sim loss with the BT loss and average the product.
    """

    def __init__(self, baseloss, emb2emb,
                 *args):
        super().__init__(*args)
        self.baseloss = baseloss
        self.emb2emb = emb2emb

    def forward(self, predicted, true):
        """Return ``mean(baseloss(predicted, true) * baseloss(emb2emb(predicted), true))``."""
        direct = self.baseloss(predicted, true)
        # same base loss, applied after mapping the prediction through emb2emb
        round_trip = self.baseloss(self.emb2emb(predicted), true)
        return (direct * round_trip).mean()


class SumBaseLoss(nn.Module):
    """
    Add the sim loss and the BT loss.
    """

    def __init__(self, baseloss, emb2emb,
                 *args):
        super().__init__(*args)
        self.baseloss = baseloss
        self.emb2emb = emb2emb

    def forward(self, predicted, true):
        """Return ``baseloss(predicted, true) + baseloss(emb2emb(predicted), true)``."""
        direct = self.baseloss(predicted, true)
        # same base loss, applied after mapping the prediction through emb2emb
        round_trip = self.baseloss(self.emb2emb(predicted), true)
        return direct + round_trip
