import sys
import torch
import torch.distributions as D
sys.path.append('../')


# def kl_gaussians(mu_one, sigma_one, mu_two, sigma_two):
#    nominator = ((mu_one - mu_two) ** 2) + sigma_one ** 2 - sigma_two ** 2
#    denominator = 2 * sigma_two ** 2
#    kl = nominator / denominator + (sigma_one / sigma_two).log()
#    return kl


def kl_gaussians(mu_one, sigma_one, mu_two, sigma_two):
    """
    KL divergence between two Gaussians that share one diagonal covariance.

    Only `sigma_one` enters the computation: it is squared and used as the
    per-dimension variance of both Gaussians. `sigma_two` is accepted but
    never read.

    Parameters:
    mu_one, mu_two : means; the last dimension is summed over
    sigma_one : scale of the shared covariance (squared before use)
    sigma_two : unused

    Returns the squared mean difference divided by 2 * sigma_one ** 2,
    summed over the last dimension.
    """
    # General KL between multivariate Gaussians N(mu_1, S1) and N(mu_2, S2):
    #   1/2 * (log(det(S2) / det(S1)) - n + tr(S2^-1 S1)
    #          + (mu_2 - mu_1)^T S2^-1 (mu_2 - mu_1))
    # With S1 == S2 == sigma^2 * I the log term is 0 and the trace is n, so
    # only the quadratic term survives:
    #   (mu_2 - mu_1)^2 / (2 * sigma^2)
    mean_gap = mu_two - mu_one
    shared_variance = sigma_one ** 2
    divergence = ((mean_gap ** 2) / (2 * shared_variance)).sum(-1)

    # a sum of squares over a positive variance can never be negative
    assert (divergence >= 0.).all()
    return divergence


def make_gmm(means, sigma, weights):
    """
    Builds a Gaussian mixture distribution with diagonal covariances.

    Parameters:
    means : [..., num_components, embedding size]
    sigma : same shape as means; forwarded as `cov_diag`, i.e. it is used
        as the per-dimension variance of each component
    weights : [..., num_components]; forwarded to `D.Categorical` as the
        mixture probabilities

    Returns a `D.MixtureSameFamily` whose components are
    `LowRankMultivariateNormal` distributions.

    NOTE(review): `kl_gaussians` in this file squares its sigma argument,
    whereas here sigma is used un-squared as the variance, so the two code
    paths only agree for sigma == 1 - confirm which convention is intended.
    """
    mix = D.Categorical(weights)

    # LowRankMultivariateNormal models the covariance as
    # cov_factor @ cov_factor^T + diag(cov_diag); an all-zero rank-1 factor
    # therefore leaves a purely diagonal covariance.
    # new_zeros inherits dtype and device from `means` (a plain torch.zeros
    # would always be default-dtype and clash with e.g. float64 inputs), and
    # deriving the shape from means.shape supports any number of leading
    # batch dimensions.
    cov_factor = means.new_zeros((*means.shape, 1))
    comp = D.LowRankMultivariateNormal(means, cov_factor, sigma)
    return D.MixtureSameFamily(mix, comp)


def sample_gmm(gmm_model, lens):
    """
    Placeholder: sampling is not implemented.

    Both arguments are accepted and ignored; the function always returns
    None.
    """
    # TODO: implement sampling from `gmm_model`; until then this is a no-op.
    return None


def _gmm_kl_variational_naive(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights):
    """
    Loop-based variational approximation of KL(X || Y) between two GMMs.

    For every component i of X the following term is accumulated:
        X_weights[:, i] * (
            log sum_j X_weights[:, j] * exp(-KL(x_i, x_j))
            - log sum_k Y_weights[:, k] * exp(-KL(x_i, y_k)))
    where the pairwise KL terms come from `kl_gaussians`.

    Parameters:
    X_means : [batch_size, num_vectors_X, embedding size]
    Y_means : [batch_size, num_vectors_Y, embedding size]
    X_sigma : [batch_size, num_vectors_X, embedding size]
    Y_sigma : [batch_size, num_vectors_Y, embedding size]
    X_weights : [batch_size, num_vectors_X]
    Y_weights : [batch_size, num_vectors_Y]

    Returns a tensor of shape [batch_size] (the float 0. when X has no
    components).
    """
    batch = X_means.size(0)
    num_x = X_means.size(1)
    num_y = Y_means.size(1)

    def weighted_affinity(i, means, sigma, weights, count):
        # sum over components c of  weights[:, c] * exp(-KL(x_i, component c))
        acc = 0.
        for c in range(count):
            pair_kl = kl_gaussians(
                X_means[:, i], X_sigma[:, i], means[:, c], sigma[:, c])
            # exactly one KL value per batch element
            assert pair_kl.size(0) == batch and len(pair_kl.size()) == 1
            acc = acc + (-1.0 * pair_kl).exp() * weights[:, c]
        assert (acc >= 0.).all()
        return acc

    result = 0.
    for i in range(num_x):
        # similarity of component i to its own mixture ...
        within = weighted_affinity(i, X_means, X_sigma, X_weights, num_x)
        # ... and to the other mixture
        across = weighted_affinity(i, Y_means, Y_sigma, Y_weights, num_y)
        result = result + X_weights[:, i] * (within.log() - across.log())

    return result


def _gmm_kl_variational_fast(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights):
    """
    Vectorised approximation of KL(X || Y) between two GMMs.

    Both mixtures are built with `make_gmm` and their log-densities are
    evaluated at the component means of X. The result is
        sum_i X_weights[:, i] * (log p_X(x_mean_i) - log p_Y(x_mean_i)).

    Parameters:
    X_means : [batch_size, num_vectors_X, embedding size]
    Y_means : [batch_size, num_vectors_Y, embedding size]
    X_sigma : [batch_size, num_vectors_X, embedding size]
    Y_sigma : [batch_size, num_vectors_Y, embedding size]
    X_weights : [batch_size, num_vectors_X]
    Y_weights : [batch_size, num_vectors_Y]

    Returns a tensor of shape [batch_size].
    """
    gmm_X = make_gmm(X_means, X_sigma, X_weights)
    gmm_Y = make_gmm(Y_means, Y_sigma, Y_weights)

    # log_prob broadcasts leading sample dimensions against the mixtures'
    # batch dimension, so move the component axis to the front
    # ([num_vectors_X, batch_size, emb]) to score every mean in one call
    # instead of looping over components.
    points = X_means.transpose(0, 1)
    log_ratio = gmm_X.log_prob(points) - gmm_Y.log_prob(points)

    # back to [batch_size, num_vectors_X] so it lines up with X_weights
    log_ratio = log_ratio.transpose(0, 1)
    return (log_ratio * X_weights).sum(-1)


def _gmm_kl_variational(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights, naive=False):
    """
    Dispatches to one of the two variational KL implementations.

    With `naive=True` the loop-based `_gmm_kl_variational_naive` is used,
    otherwise `_gmm_kl_variational_fast`. All tensor arguments are passed
    through unchanged and the chosen implementation's result is returned.
    """
    implementation = (
        _gmm_kl_variational_naive if naive else _gmm_kl_variational_fast)
    return implementation(
        X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights)


def gmm_kl(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights, approximation="variational_lower"):
    """
    Approximates the KL divergence KL(X || Y) between two GMMs.

    Parameters:
    X_means : [batch_size, num_vectors_X, embedding size]
    Y_means : [batch_size, num_vectors_Y, embedding size]
    X_sigma : [batch_size, num_vectors_X, embedding size]
    Y_sigma : [batch_size, num_vectors_Y, embedding size]
    X_weights : [batch_size, num_vectors_X]
    Y_weights : [batch_size, num_vectors_Y]
    approximation : name of the approximation; only "variational_lower"
        is supported

    Returns a tensor of shape [batch_size].
    Raises ValueError for any other `approximation`.
    """
    # reject unsupported approximations before doing any work
    if approximation != "variational_lower":
        raise ValueError(
            f"Unknown GMM KL-divergence approximation {approximation}.")

    divergence = _gmm_kl_variational(
        X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights)

    # exactly one value per batch element
    # NOTE: a non-negativity check on the result was disabled by the
    # original author and is intentionally not performed here.
    assert divergence.size(0) == X_means.size(0) and len(divergence.size()) == 1
    return divergence


def _combine_gmms(
        X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights):
    """
    Builds the equal-weight mixture of two GMMs.

    Means, sigmas and weights of X and Y are concatenated along dim 1 (the
    component axis); the concatenated weights are halved.

    Returns the tuple (M_means, M_sigma, M_weights).
    """
    def join_components(from_x, from_y):
        # stack the components of Y behind those of X
        return torch.cat([from_x, from_y], dim=1)

    return (
        join_components(X_means, Y_means),
        join_components(X_sigma, Y_sigma),
        join_components(X_weights, Y_weights) / 2.,
    )


def gmm_jsd(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights, approximation="variational_lower"):
    """
    Approximates the Jensen-Shannon divergence between two GMMs P and Q:

    JSD(P, Q) = (1/2) * (KL(P, M) + KL(Q, M)),  with M = (1/2) * (P + Q).

    Parameters:
    X_means : [batch_size, num_vectors_X, embedding size]
    Y_means : [batch_size, num_vectors_Y, embedding size]
    X_sigma : [batch_size, num_vectors_X, embedding size]
    Y_sigma : [batch_size, num_vectors_Y, embedding size]
    X_weights : [batch_size, num_vectors_X]
    Y_weights : [batch_size, num_vectors_Y]
    approximation : name of the KL approximation; only "variational_lower"
        is supported

    Returns a tensor of shape [batch_size].
    Raises ValueError for any other `approximation`.
    """
    # M is itself a GMM: it holds every component of P and of Q, each with
    # half of its original weight.
    mixture = _combine_gmms(
        X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights)
    M_means, M_sigma, M_weights = mixture

    if approximation != "variational_lower":
        raise ValueError(
            f"Unknown GMM KL-divergence approximation {approximation}.")
    kl_to_mixture = _gmm_kl_variational

    half_kl_p = kl_to_mixture(
        X_means, M_means, X_sigma, M_sigma, X_weights, M_weights) / 2
    half_kl_q = kl_to_mixture(
        Y_means, M_means, Y_sigma, M_sigma, Y_weights, M_weights) / 2
    jsd = half_kl_p + half_kl_q

    # exactly one value per batch element
    assert jsd.size(0) == X_means.size(0) and len(jsd.size()) == 1
    return jsd


def gmm_symkl(X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights, approximation="variational_lower"):
    """
    Approximates the symmetric KL divergence between two GMMs P and Q:

    SYMKL(P, Q) = (1/2) * (KL(P, Q) + KL(Q, P))

    Parameters:
    X_means : [batch_size, num_vectors_X, embedding size]
    Y_means : [batch_size, num_vectors_Y, embedding size]
    X_sigma : [batch_size, num_vectors_X, embedding size]
    Y_sigma : [batch_size, num_vectors_Y, embedding size]
    X_weights : [batch_size, num_vectors_X]
    Y_weights : [batch_size, num_vectors_Y]
    approximation : name of the KL approximation; only "variational_lower"
        is supported

    Returns a tensor of shape [batch_size].
    Raises ValueError for any other `approximation`.
    """
    if approximation != "variational_lower":
        raise ValueError(
            f"Unknown GMM KL-divergence approximation {approximation}.")
    directed_kl = _gmm_kl_variational

    # the same approximation is evaluated once in each direction
    forward = directed_kl(
        X_means, Y_means, X_sigma, Y_sigma, X_weights, Y_weights)
    backward = directed_kl(
        Y_means, X_means, Y_sigma, X_sigma, Y_weights, X_weights)
    symmetric = forward / 2 + backward / 2

    # exactly one value per batch element
    assert symmetric.size(0) == X_means.size(0) and len(symmetric.size()) == 1
    return symmetric


def prep_gmm_input(A, A_mask, weighting, sigma):
    """
    Turns a padded batch of vectors into GMM parameters.

    Parameters:
    A : [batch_size, num_vectors, features]; for weighting == "model" the
        first feature is read as an unnormalised weight logit and the
        remaining features are the means
    A_mask : [batch_size, num_vectors]; 1 for real vectors, 0 for padding
    weighting : one of
        "uniform"   - every unmasked vector gets weight 1 / (#unmasked)
        "model"     - masked softmax over A[:, :, 0]
        "magnitude" - proportional to the L2 norm of each unmasked vector
    sigma : value used to fill the sigma tensor

    Returns (A_means, A_sigma, A_weight) where A_sigma has the shape of
    A_means and masked positions of A_weight are 0.

    Raises ValueError for an unknown `weighting`.
    """
    A_mask = A_mask.float()

    if weighting == "uniform":
        A_means = A
        A_weight = A_mask / A_mask.sum(1).unsqueeze(1)
    elif weighting == "model":
        A_means = A[:, :, 1:]
        # push masked logits towards -inf so the softmax ignores them, then
        # zero them explicitly
        logits = A[:, :, 0] - (1 - A_mask) * 1e10
        A_weight = torch.nn.functional.softmax(logits, dim=1) * A_mask
    elif weighting == "magnitude":
        A_means = A
        # Mask the norms themselves, not only the normaliser: otherwise
        # padded positions holding non-zero vectors would keep a non-zero
        # weight, unlike in the other two schemes.
        norms = A.norm(dim=-1) * A_mask
        # small epsilon guards against division by zero
        A_weight = norms / (norms.sum(1) + 0.00001).unsqueeze(1)
    else:
        raise ValueError("Unknown weighting scheme " + str(weighting))

    A_sigma = torch.ones_like(A_means) * sigma
    return A_means, A_sigma, A_weight


def get_local_gmm_divergence(approximation, weighting, sigma, divergence_f=gmm_kl):
    """
    Creates a divergence function that works on raw (padded) vector batches.

    Parameters:
    approximation : forwarded to `divergence_f` as its last argument
    weighting : weighting scheme handed to `prep_gmm_input`; must be one of
        ['uniform', 'model', 'magnitude']
    sigma : sigma value handed to `prep_gmm_input`
    divergence_f : divergence to evaluate on the prepared GMM parameters
        (defaults to `gmm_kl`)

    Returns a function `local_gmm(X, Y, mask_X, mask_Y, info)` that yields
    the tuple (divergence, None).
    """

    def local_gmm(X, Y, mask_X, mask_Y, info):
        # `info` belongs to the returned function's signature but is not
        # used here.
        params_X = prep_gmm_input(X, mask_X, weighting, sigma)
        params_Y = prep_gmm_input(Y, mask_Y, weighting, sigma)

        X_means, X_sigma, X_weights = params_X
        Y_means, Y_sigma, Y_weights = params_Y

        divergence = divergence_f(
            X_means, Y_means, X_sigma, Y_sigma,
            X_weights, Y_weights, approximation)
        return divergence, None

    return local_gmm
