import sys
from autoencoders.l0drop import compute_g_from_logalpha
sys.path.append('../')

import torch
from emb2emb.classifier import binary_clf_predict, freeze
from torch.nn import BCEWithLogitsLoss, BCELoss


def _local_hausdorff_similarities(X, Y, mask_Y, similarity_function="euclidean", naive=True, differentiable=False, softmax_temp=1.0, max_X_len=-1, detach=True, naive_local=True, alpha=0.5):
    if max_X_len > -1:
        max_len = max_X_len
    else:
        max_len = X.size(1)

    all_hausdorff_similarities = torch.zeros(
        X.size(0), max_len, device=X.device)

    similarities = None

    if detach:
        current_X = X.clone().detach()
    else:
        current_X = X

    for i in range(max_len):

        if detach:
            current_X = current_X.clone().detach()
            current_X[:, i] = X[:, i]

        mask_X = torch.arange(X.size(1), device=X.device).unsqueeze(
            0).expand(X.size(0), -1)
        mask_X = mask_X <= i

        # TODO: implement a smart way of ensuring that we detach AFTER
        # computing the similarity matrix!
        if naive_local:
            res, _ = hausdorff_similarity(current_X, Y, mask_X.detach(), mask_Y, similarity_function, naive,
                                          differentiable, softmax_temp, similarities=similarities, return_similarities=True, alpha=alpha)
        else:
            if detach:
                raise NotImplementedError(
                    "Only naive way is available when trying to detach!")
            res, similarities = hausdorff_similarity(current_X, Y, mask_X.detach(), mask_Y, similarity_function, naive,
                                                     differentiable, softmax_temp, similarities=similarities, return_similarities=True, alpha=alpha)
        all_hausdorff_similarities[:, i] = res

    return all_hausdorff_similarities


def get_local_hausdorff_similarities_function(similarity_function, naive, differentiable, softmax_temp, naive_local, alpha, magnitude_weighting, force_expected_gate_value):
    """Build a local bag-loss function based on ``hausdorff_similarity``.

    The returned callable has the signature
    ``(X, Y, mask_X, mask_Y, info) -> (loss, info)``. The loss is the negated
    Hausdorff similarity, so higher similarity means lower loss.

    When ``naive_local`` is True the similarity matrix is recomputed on every
    call and the returned info is None. Otherwise ``info`` is used as a
    precomputed similarity matrix (or None to compute one), and the matrix
    that was used is returned as the new info.
    """
    shared_kwargs = dict(
        alpha=alpha,
        magnitude_weighting=magnitude_weighting,
        force_expected_gate_value=force_expected_gate_value,
    )

    def local_hausdorff(X, Y, mask_X, mask_Y, info):
        if naive_local:
            similarity = hausdorff_similarity(
                X, Y, mask_X, mask_Y, similarity_function, naive,
                differentiable, softmax_temp,
                similarities=None,
                return_similarities=False,
                **shared_kwargs)
            new_info = None
        else:
            similarity, new_info = hausdorff_similarity(
                X, Y, mask_X, mask_Y, similarity_function, naive,
                differentiable, softmax_temp,
                similarities=info,
                return_similarities=True,
                **shared_kwargs)
        # Negate the similarity to turn it into a loss.
        return similarity * (-1.), new_info

    return local_hausdorff


def get_local_classifier_loss(bag_classifier, target=1.0, params=None, free_bits=None):
    """Build a local bag loss that pushes a frozen bag classifier towards ``target``.

    ``bag_classifier`` is switched to eval mode and frozen (via ``freeze``)
    once, when this factory is called.

    Args:
        bag_classifier: classifier evaluated through ``binary_clf_predict``;
            its output is treated as a logit (it is fed to BCEWithLogitsLoss).
        target: desired label for every example; converted to float.
        params: forwarded to ``binary_clf_predict``.
        free_bits: if truthy, a probability ``p``. Each example's loss is
            floored at ``BCE(p, target)``, so examples already below that
            floor contribute a constant loss. None or 0 disables this.

    Returns:
        ``classifier_loss(X, Y, mask_X, mask_Y, info) -> (loss, None)``, where
        ``loss`` has one entry per example. ``Y``, ``mask_Y`` and ``info`` are
        ignored.
    """
    loss_f = BCEWithLogitsLoss(reduction='none')
    # Only needed when the free-bits floor is active (same truthiness test
    # as inside classifier_loss).
    loss_freebits = BCELoss(reduction='none') if free_bits else None

    bag_classifier.eval()
    freeze(bag_classifier)

    def classifier_loss(X, Y, mask_X, mask_Y, info):
        X_len = mask_X.sum(-1)
        preds = binary_clf_predict(bag_classifier, (X, X_len), params)

        # full_like matches preds' shape, dtype and device. torch.full with
        # an int fill value (e.g. target=1) may infer an integer dtype, which
        # BCEWithLogitsLoss does not accept as a target.
        targets = torch.full_like(preds, float(target))
        loss = loss_f(preds, targets)

        if free_bits:
            # Floor each example's loss at the BCE of predicting probability
            # `free_bits`; examples already below the floor get the constant
            # floor value instead of their own loss.
            floor = loss_freebits(
                torch.full_like(preds, float(free_bits)), targets)
            loss = torch.maximum(floor, loss)

        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also remove the batch dimension when the batch size is 1.
        if loss.dim() > 1:
            loss = loss.squeeze(-1)
        return loss, None

    return classifier_loss


def get_local_regression_loss(bag_classifier):
    """Build a local bag loss whose value is the raw output of a frozen model.

    ``bag_classifier`` is switched to eval mode and frozen (via ``freeze``)
    once, when this factory is called. The returned callable has the
    signature ``(X, Y, mask_X, mask_Y, info) -> (loss, None)``; ``Y``,
    ``mask_Y`` and ``info`` are ignored.
    """
    bag_classifier.eval()
    freeze(bag_classifier)

    def classifier_loss(X, Y, mask_X, mask_Y, info):
        lengths = mask_X.sum(-1)
        # The model prediction itself is used as the loss (squeezed).
        prediction = binary_clf_predict(bag_classifier, (X, lengths))
        return prediction.squeeze(), None

    return classifier_loss


def get_weighted_localbagloss_function(func_list, weights):
    """Combine several local bag-loss functions into a weighted sum.

    Every function in ``func_list`` has the signature
    ``(X, Y, mask_X, mask_Y, info) -> (loss, info)``. The combined function
    returns ``(sum_k weights[k] * loss_k, [info_0, info_1, ...])``. When it is
    called with such a list as ``info`` (e.g. the one returned by a previous
    call), element ``k`` is handed back to function ``k``; with ``info=None``
    every function receives None.

    Raises:
        ValueError: if ``func_list`` and ``weights`` differ in length.
    """
    # Materialize once. If either argument were a one-shot iterator, zip()
    # would exhaust it on the first call and all later calls would silently
    # return None.
    funcs = tuple(func_list)
    ws = tuple(weights)
    if len(funcs) != len(ws):
        raise ValueError(
            "func_list and weights must have the same length "
            "(got {} and {})".format(len(funcs), len(ws)))

    def combined_func(X, Y, mask_X, mask_Y, info):
        out_info = []
        res_total = None
        for i, (f, w) in enumerate(zip(funcs, ws)):
            info_f = info[i] if info is not None else None
            res_f, out_info_f = f(X, Y, mask_X, mask_Y, info_f)

            weighted = res_f * w
            res_total = weighted if res_total is None else res_total + weighted
            out_info.append(out_info_f)

        return res_total, out_info

    return combined_func


def local_bag_losses(X, Y, mask_Y, bag_loss_function, detach=False, max_X_len=-1):
    """Evaluate ``bag_loss_function`` on every prefix of X.

    At step ``i`` the loss is computed with a mask that keeps positions
    ``0..i`` of X, so column ``i`` of the result is the loss of the prefix of
    length ``i + 1``.

    Args:
        X: tensor whose dim 0 is the batch and dim 1 is the sequence.
        Y, mask_Y: forwarded unchanged to ``bag_loss_function``.
        bag_loss_function: callable ``(X, Y, mask_X, mask_Y, info) ->
            (loss, info)``; ``loss`` must be assignable to a [batch_size]
            column. When ``detach`` is False, the info returned by one step
            is passed to the next step (e.g. a cached similarity matrix).
            When ``detach`` is True, info is always None, because each step
            uses a differently-detached input and a cache from an earlier
            step would route gradients incorrectly.
        detach: if True, at step ``i`` only position ``i`` of X carries
            gradient; all other positions are detached copies.
        max_X_len: number of prefixes to evaluate; -1 means ``X.size(1)``.

    Returns:
        Tensor of shape [X.size(0), max_len] (default dtype, on X's device).
    """
    max_len = max_X_len if max_X_len > -1 else X.size(1)
    batch_size = X.size(0)
    losses = torch.zeros(batch_size, max_len, device=X.device)

    # Loop-invariant: position index of every vector, broadcast over the
    # batch. Each step's prefix mask is derived from it.
    positions = torch.arange(X.size(1), device=X.device).unsqueeze(
        0).expand(batch_size, -1)

    current_X = X.clone().detach() if detach else X
    info = None

    for i in range(max_len):
        if detach:
            current_X = current_X.clone().detach()
            current_X[:, i] = X[:, i]

        mask_X = (positions <= i).detach()

        res, out_info = bag_loss_function(current_X, Y, mask_X, mask_Y, info)

        # Reuse the function's state (e.g. the similarity matrix) only when
        # the input tensor is identical across steps.
        if not detach:
            info = out_info

        losses[:, i] = res

    return losses


def gated_softmax(values, gates, dim=-1):
    """
    Softmax along ``dim`` where each candidate's exponentiated score is
    scaled by its gate before normalization:

        weights_k = exp(v_k) * g_k / (sum_j exp(v_j) * g_j + 1e-5)

    where ``k`` indexes positions along ``dim``.

    values: [batch_size, seq_len_1, seq_len_2]
    gates: [batch_size, values.size(dim)] -- one gate per candidate along
        ``dim`` (dim 0 of ``gates`` is the batch)
    """
    ndim = values.dim()
    dim = dim % ndim

    # Lay the gates out along `dim` (batch on axis 0, size 1 elsewhere), so
    # that gate k scales candidate k. Unsqueezing at -1 unconditionally only
    # works for dim=1; for dim=2 it either fails to broadcast or, when both
    # lengths match, scales the wrong axis and cancels out of the
    # normalization.
    gate_shape = [1] * ndim
    gate_shape[0] = gates.size(0)
    gate_shape[dim] = gates.size(-1)
    gates = gates.reshape(gate_shape)

    # Shift by the per-slice max so exp() cannot overflow. The constant is
    # rescaled by exp(-shift) so the result equals the unshifted formula
    # above exactly. Non-finite maxima (e.g. an all -inf slice) use shift 0,
    # matching the unshifted behavior.
    shift = values.max(dim=dim, keepdim=True)[0].detach()
    shift = shift.masked_fill(~torch.isfinite(shift), 0.0)

    exp_values = (values - shift).exp() * gates
    normalization = exp_values.sum(dim=dim, keepdim=True)
    weights = exp_values / (normalization + 0.00001 * torch.exp(-shift))
    return weights


def hausdorff_similarity(X, Y, mask_X=None, mask_Y=None,
                         similarity_function="euclidean",
                         naive=True,
                         differentiable=False,
                         softmax_temp=1.0,
                         similarities=None,
                         return_similarities=False,
                         alpha=0.5,
                         magnitude_weighting=None,
                         force_expected_gate_value=-1):
    """Compute a (soft) Hausdorff similarity between two batches of vector bags.

    For every vector in X, its best match in Y is found (and vice versa). The
    best-match similarities are averaged over each bag and mixed as
    ``alpha * align_X + (1 - alpha) * align_Y``.

    Parameters:
    X : [batch_size, num_vectors_X, embedding size]
    Y : [batch_size, num_vectors_Y, embedding size]
    mask_X, mask_Y : optional [batch_size, num_vectors_X/Y] masks (1/True
        marks a real vector). Each mask is applied on its own; a missing
        mask means every vector on that side counts.
    similarity_function : either 'euclidean' (negative L2 distance) or 'cosine'
    naive : if true, compute the pairwise similarities with explicit loops
    differentiable : if true, replace the hard max by a softmax-weighted average
    softmax_temp : temperature for the differentiable version; the closer to
        zero, the closer the approximation to the real max
    similarities : if provided, use this [batch_size, num_vectors_X,
        num_vectors_Y] similarity matrix instead of computing a new one
    return_similarities : if true, also return the (unmasked) similarity matrix
    alpha : weight of the X -> Y direction
    magnitude_weighting : None, 'norm' (vector norms act as gates) or 'gates'
        (the last embedding dimension is turned into gates via
        compute_g_from_logalpha and stripped before computing similarities)
    force_expected_gate_value : 'gates' mode only. If > -1, the X gates are
        rescaled to sum to max(force_expected_gate_value * len(X), 1);
        requires mask_X.

    Returns:
    [batch_size] similarities, or (similarities, similarity matrix) when
    return_similarities is true.

    Raises:
    ValueError : if magnitude_weighting is set to an unsupported value.
    """
    use_gates = bool(magnitude_weighting)
    if use_gates and magnitude_weighting not in ("norm", "gates"):
        # Without this check an unknown mode fails later with a NameError on
        # X_gates / Y_gates.
        raise ValueError(
            "magnitude_weighting must be None, 'norm' or 'gates'; "
            "got {!r}".format(magnitude_weighting))

    if magnitude_weighting == "norm":
        Y_gates = Y.norm(dim=-1)
        X_gates = X.norm(dim=-1)
    elif magnitude_weighting == "gates":
        X_gates_logalpha = X[..., -1].contiguous()
        X_gates = compute_g_from_logalpha(X_gates_logalpha, epsilon=0.1)

        if force_expected_gate_value > -1:
            # The tiny mask term keeps the normalizing sum non-zero for
            # sequences whose gates are all closed.
            X_gates = X_gates + (mask_X * 0.000001)
            X_len = mask_X.sum(1, keepdim=True)
            expected_num_gates = torch.maximum(
                force_expected_gate_value * X_len, torch.full_like(X_len, 1.0))
            X_gates = (X_gates / X_gates.sum(1, keepdim=True)) * \
                (expected_num_gates)

        # Strip the gate dimension from the vectors themselves.
        X = X[..., :-1].contiguous()
        Y_gates_logalpha = Y[..., -1].contiguous()
        Y_gates = compute_g_from_logalpha(Y_gates_logalpha, epsilon=0.1)
        Y = Y[..., :-1].contiguous()

    if similarities is None:
        # similarities : [batch_size, num_vectors_X, num_vectors_Y]
        if naive:
            similarities = _pairwise_similarities_naive(
                X, Y, similarity_function)
        else:
            similarities = _pairwise_similarities(X, Y, similarity_function)

    # Broadcastable float masks: [B, Lx, 1] and [B, 1, Ly].
    mask_X_ext = None if mask_X is None else mask_X.unsqueeze(2).float()
    mask_Y_ext = None if mask_Y is None else mask_Y.unsqueeze(1).float()

    # Push masked entries to -1e10 so they never win a max / get softmax mass.
    masked_similarities = similarities
    if mask_X_ext is not None:
        masked_similarities = masked_similarities * mask_X_ext - \
            (1 - mask_X_ext) * 1e10
    if mask_Y_ext is not None:
        masked_similarities = masked_similarities * mask_Y_ext - \
            (1 - mask_Y_ext) * 1e10
    if mask_X_ext is not None or mask_Y_ext is not None:
        masked_similarities = torch.max(torch.full_like(
            masked_similarities, -1e10), masked_similarities)

    if not differentiable:
        X_similarities, _ = masked_similarities.max(dim=2)
        Y_similarities, _ = masked_similarities.max(dim=1)
    else:
        scaled = masked_similarities / softmax_temp

        # Y_weights: distribution over Y for every X vector (dim 2);
        # X_weights: distribution over X for every Y vector (dim 1).
        if use_gates:
            Y_weights = gated_softmax(scaled, Y_gates, dim=2)
            X_weights = gated_softmax(scaled, X_gates, dim=1)
        else:
            Y_weights = torch.nn.functional.softmax(scaled, dim=2)
            X_weights = torch.nn.functional.softmax(scaled, dim=1)

        # Each weight matrix is masked along the axis it is a distribution
        # over.
        if mask_Y_ext is not None:
            Y_weights = Y_weights * mask_Y_ext
        if mask_X_ext is not None:
            X_weights = X_weights * mask_X_ext

        X_similarities = (masked_similarities * Y_weights).sum(dim=2)
        Y_similarities = (masked_similarities * X_weights).sum(dim=1)

    # X/Y_similarities : [batch_size, num_X/Y_vectors]

    def mean_with_mask(T, m):
        # Weighted mean over dim 1; no weights means a plain mean.
        if m is None:
            return T.mean(1)
        return (T * m).sum(1) / m.sum(1).float()

    weights_X, weights_Y = mask_X, mask_Y
    if use_gates:
        weights_X = X_gates if mask_X is None else mask_X * X_gates
        weights_Y = Y_gates if mask_Y is None else mask_Y * Y_gates

    align_X = mean_with_mask(X_similarities, weights_X)
    align_Y = mean_with_mask(Y_similarities, weights_Y)

    result = alpha * align_X + (1 - alpha) * align_Y

    if return_similarities:
        return result, similarities
    return result


def _pairwise_similarities_naive(X, Y, distance_function):

    num_vectors_X = X.size(1)
    num_vectors_Y = Y.size(1)

    if distance_function == "cosine":
        cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    distances = torch.zeros((X.size(0), num_vectors_X, num_vectors_Y))

    for b in range(X.size(0)):

        for i in range(num_vectors_X):

            for j in range(num_vectors_Y):

                x_b = X[b, i, :]
                y_b = Y[b, j, :]

                if distance_function == "euclidean":

                    d = (x_b - y_b).norm()
                elif distance_function == "cosine":

                    d = - cosine_sim(x_b.unsqueeze(0), y_b.unsqueeze(0))

                # take the inverse to make it similarity
                distances[b, i, j] = -d

    return distances


def _pairwise_similarities(X, Y, distance_function):
    if distance_function == "euclidean":
        result = torch.cdist(X, Y) * -1
    elif distance_function == "cosine":

        cos_sim = torch.nn.CosineSimilarity(dim=-1)

        X_len = X.size(1)
        Y_len = Y.size(1)
        embsize = X.size(2)

        X = X.unsqueeze(
            1).expand(-1, Y_len, -1, -1)
        Y = Y.unsqueeze(2).expand(-1, -1, X_len, -1)

        # bring it in right format for cos sim
        X = X.reshape(-1, embsize)
        Y = Y.reshape(-1, embsize)

        # all_similarities: rows = true_a, columns = predicted_a
        all_similarities = cos_sim(X, Y)

        # return to matrix format
        all_similarities = all_similarities.view(-1, Y_len, X_len)
        all_similarities = all_similarities.transpose(1, 2)

        result = all_similarities

    return result


# NOTE(review): the triple-quoted string below is an inert module-level
# string literal, never executed. It holds an older similarity-loss
# implementation: what appears to be the tail of a class's __init__ plus
# `similarity_by_axis` and `forward` methods. It seems to be kept only for
# reference -- consider deleting it and relying on version-control history.
'''
        self.cos_sim = nn.CosineSimilarity(dim=-1)
        self.differentiable = differentiable
        self.softmax1 = nn.Softmax(dim=-1)
        self.softmax2 = nn.Softmax(dim=-2)
        self.directions = directions
        self.force_target_length = force_target_length

    def similarity_by_axis(self, all_similarities, axis):

        if self.differentiable:
            # print(all_similarities)

            if axis == -1:
                weights = self.softmax1(all_similarities)
            else:
                weights = self.softmax2(all_similarities)
            similarities = weights * all_similarities
            # print(similarities.size())
            # print(similarities)
            similarities = similarities.sum(axis)
        else:
            similarities, _ = all_similarities.max(dim=axis)

        similarities = similarities * -1  # we want to minimize the loss
        return similarities

    def forward(self, predicted, true):
        # predicted = [batch_size, len, embedding_dim]
        # true = [batch_size, len, embedding_dim]

        # for each true one, find the output that is closes to it
        assert predicted.size() == true.size()

        # make a matrix
        maxlen = predicted.size(1)
        embsize = predicted.size(2)
        minuslen = (true.sum(-1) == 0.).sum(1)
        true_len = (maxlen - minuslen).detach()
        true_zeros = (true.sum(-1) == 0.).detach()

        # force predicted length to be the same as true length
        if self.force_target_length:
            predicted = (1 - true_zeros.float()).unsqueeze(-1) * predicted
            pred_len = true_len
            pred_zeros = true_zeros
        else:
            pred_zeros = (predicted.sum(-1) == 0.).detach()
            pred_len = (maxlen - (predicted.sum(-1) == 0.).sum(1)).detach()

        predicted_a = predicted.unsqueeze(
            1).expand(-1, predicted.size(1), -1, -1)
        true_a = true.unsqueeze(2).expand(-1, -1, true.size(1), -1)

        # TODO: how do we mask zeros in the predicted output?
        #mask = ((true_a.sum(-1)) == 0.).transpose(1,2)
        # print(mask)
        #othermask = (predicted_a.sum(-1)) == 0.
        # print(othermask.sum())

        # bring it in right format for cos sim
        predicted_a = predicted_a.reshape(-1, embsize)
        true_a = true_a.reshape(-1, embsize)

        # all_similarities: rows = true_a, columns = predicted_a
        all_similarities = self.cos_sim(predicted_a, true_a)

        # return to matrix format
        all_similarities = all_similarities.view(-1, maxlen, maxlen)

        mask = all_similarities == 0.

        # mask both sides of the symmetrical similarity matrix
        all_similarities = all_similarities - mask * \
            1e10  # - mask.transpose(1,2) * 1e10
        all_similarities = torch.max(torch.full_like(
            all_similarities, -1e10), all_similarities)
        #all_similarities = all_similarities - mask.transpose(1,2) * 1e10
        # print(all_similarities)
        #assert torch.equal(all_similarities,all_similarities.transpose(1,2))

        # for each row/input, find the most similar column/predicted output
        if "input" in self.directions:
            similarities_true = self.similarity_by_axis(all_similarities, -1)
            # print(true_zeros.size())
            # print(similarities_true.size())
            # print(true_len.size())
            similarities_true = (
                (similarities_true * (1 - true_zeros.float())).sum(1)) / true_len
            similarities_input = similarities_true
        else:
            similarities_input = 0.

        if "output" in self.directions:
            # for each column/predicted output, find the most similar row/input
            similarities_output = self.similarity_by_axis(
                all_similarities, -2)
            similarities_output = (similarities_output *
                                   (1 - pred_zeros.float())).sum(1) / pred_len
        else:
            similarities_output = 0.

        similarities = similarities_input + similarities_output

        # output = [batch_size, len]
        similarities = similarities.unsqueeze(-1)
        return similarities

'''
