from torch import nn
import torch
import math

from torch_utils import make_mask


class DotProductAlphaEstimator(nn.Module):
    """
    Scores every token by its dot product with a single learned vector.

    The module owns one parameter, ``weights`` of shape ``(token_dim,)``,
    initialised from a standard normal distribution.
    """

    def __init__(self, token_dim):
        super(DotProductAlphaEstimator, self).__init__()
        self.weights = nn.Parameter(torch.randn(token_dim))
        self.token_dim = token_dim

    def forward(self, X, X_lengths):
        """
        Compute one log-alpha score per token.

        Args:
            X: tensor of shape ``(batch, max_len, token_dim)``.
            X_lengths: unused here; accepted so that this estimator can be
                called with the same ``(X, X_lengths)`` arguments as
                ``ContextualizedAlphaEstimator``.

        Returns:
            Tensor of shape ``(batch, max_len)`` whose entry ``[b, t]`` is
            ``dot(X[b, t], weights)``.
        """
        # A (batch, max_len, token_dim) x (token_dim,) matmul contracts the
        # last axis, giving the per-token dot product directly. Unlike a
        # view/bmm formulation this also accepts non-contiguous inputs.
        return torch.matmul(X, self.weights)


class ContextualizedAlphaEstimator(nn.Module):
    """
    Relates each token to the average of the tokens in its sequence.

    Every token vector is concatenated with the masked mean of its sequence
    and fed through a two-layer MLP (``Linear -> Tanh -> Linear``) that
    outputs a single log-alpha score per token.
    """

    def __init__(self, token_dim):
        super(ContextualizedAlphaEstimator, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * token_dim, token_dim), nn.Tanh(), nn.Linear(token_dim, 1))

    def forward(self, X, X_lengths):
        """
        Compute one log-alpha score per token.

        Args:
            X: tensor of shape ``(batch, max_len, token_dim)``.
            X_lengths: 1-D tensor with one length per sequence.

        Returns:
            Tensor of shape ``(batch, max_len)`` with the MLP output for
            every position.
        """
        batch_size = X.size(0)
        max_len = X.size(1)

        # Masked sum over the sequence axis.
        # NOTE(review): make_mask comes from torch_utils; it is assumed to
        # return a (batch, max_len) mask that is non-zero only for positions
        # below the sequence length -- confirm against its definition.
        mask = make_mask(batch_size, max_len, X_lengths)
        token_sum = (X * mask.unsqueeze(2)).sum(dim=1)

        # Divide by the length to get the mean. Lengths are clamped to at
        # least 1 so that an empty sequence does not produce 0 / 0 = NaN
        # scores; for lengths >= 1 the result is unchanged.
        safe_lengths = X_lengths.clamp(min=1).unsqueeze(1)
        averages = token_sum / safe_lengths

        # Pair every token with the mean of its sequence and score the pair.
        context = averages.unsqueeze(1).expand_as(X)
        features = torch.cat([context, X], dim=-1)
        log_alpha = self.mlp(features).squeeze(2)

        return log_alpha


def compute_g_from_logalpha(log_alpha, epsilon=0.1):
    """
    Compute the deterministic (noise-free) gate values from ``log_alpha``.

    The sigmoid of ``log_alpha`` is stretched from ``(0, 1)`` to
    ``(-epsilon, 1 + epsilon)`` and then clipped back to ``[0, 1]``, so
    sufficiently negative scores yield a gate of exactly 0 and sufficiently
    positive scores a gate of exactly 1. This is the same stretch that
    ``L0Drop.forward`` applies to its sampled value during training.

    Args:
        log_alpha: tensor of gate logits (any shape).
        epsilon: how far the stretched interval extends beyond 0 and 1.

    Returns:
        Tensor with the same shape as ``log_alpha`` and values in ``[0, 1]``.
    """
    # Stretch: scale by the interval width (1 + 2 * epsilon), then shift
    # down by epsilon. The subtraction must happen after the multiplication.
    s_bar = torch.sigmoid(log_alpha) * (1 + 2 * epsilon) - epsilon
    return s_bar.clamp(min=0.0, max=1.0)


class L0Drop(nn.Module):
    """
    Gates the vectors of a padded sequence batch and optionally drops the
    ones whose gate is closed.

    A per-token ``log_alpha`` is produced by ``alpha_estimator``. In training
    mode a gate is sampled from it with logistic noise; in eval mode the gate
    is computed deterministically via ``compute_g_from_logalpha``. Gates lie
    in ``[0, 1]``.

    Args:
        token_dim: size of the last dimension of the input vectors; passed
            to ``alpha_estimator`` and used to size the dummy vector.
        epsilon: stretch amount of the gate interval.
        temperature: divides the noisy logit when sampling and scales the
            log term of the open-gate probability used as loss.
        keep_dropped_vectors: if True, vectors whose gate is exactly 0 are
            replaced by a dummy vector and sequence lengths are unchanged.
        learned_dummy: whether the dummy vector is an ``nn.Parameter`` or a
            fixed random tensor (only relevant with ``keep_dropped_vectors``).
        alpha_estimator: class constructed as ``alpha_estimator(token_dim)``
            and called as ``estimator(X, X_length)``.
        target_ratio: if > 0, the loss is derived from the per-sequence
            ratio of expected open gates to sequence length and this target.
        target_mse: with ``target_ratio > 0``, use the squared error between
            ratio and target instead of ``max(target, ratio)``.
        discard_epsilon: a gate counts as open when it exceeds this value.
        append_gates: if True, the gate and ``log_alpha`` are appended as two
            extra features to every output vector and the output is not
            truncated along the sequence axis.
        dont_discard_vectors: if True (and ``keep_dropped_vectors`` is
            False), all vectors stay in place and lengths are unchanged.
        dont_apply_gates: if True, inputs are not multiplied by the gates.
    """

    def __init__(self, token_dim,
                 epsilon=0.1,
                 temperature=2. / 3.,
                 keep_dropped_vectors=False,
                 learned_dummy=True,
                 alpha_estimator=DotProductAlphaEstimator,
                 target_ratio=0.,
                 target_mse=False,
                 discard_epsilon=0.00001,
                 append_gates=False,
                 dont_discard_vectors=False,
                 dont_apply_gates=False):
        super(L0Drop, self).__init__()
        self.epsilon = epsilon
        self.temperature = temperature
        self.token_dim = token_dim
        self.keep_dropped_vectors = keep_dropped_vectors
        self.dont_discard_vectors = dont_discard_vectors
        self.learned_dummy = learned_dummy
        self.target_ratio = target_ratio
        self.target_mse = target_mse
        self.discard_epsilon = discard_epsilon
        self.append_gates = append_gates
        self.dont_apply_gates = dont_apply_gates

        if self.keep_dropped_vectors:
            if self.learned_dummy:
                self.dummy_value = nn.Parameter(torch.randn((token_dim)))
            else:
                # Plain tensor: neither a parameter nor a buffer, so it is
                # moved to the right device each time it is used in forward.
                self.dummy_value = torch.randn((token_dim))

        self.alpha_estimator = alpha_estimator(token_dim)

        if self.target_ratio > 0. and self.target_mse:
            self.mseloss = nn.MSELoss(reduction='none')

    def _sample_gates(self, log_alpha):
        """Sample training-time gates in [0, 1] from ``log_alpha``."""
        # sample from binary concrete according to equations 10 and 11 in
        # L0Drop paper
        uniform = torch.rand(
            log_alpha.size(0), log_alpha.size(1), device=log_alpha.device)
        logistic_noise = torch.log(uniform) - torch.log(1 - uniform)
        s = torch.sigmoid((logistic_noise + log_alpha) / self.temperature)

        # stretch to (-epsilon, 1 + epsilon), then clip back to [0, 1]
        s_bar = s * (1 + 2 * self.epsilon) - self.epsilon
        return s_bar.clamp(min=0.0, max=1.0)

    def _compact_open_gates(self, X, g):
        """Move the vectors with an open gate to the front of each sequence.

        Order is preserved; the remaining positions are zero-filled.
        """
        keep = (g > self.discard_epsilon).to(X.device)
        # The running count of kept vectors (minus one) is the slot each
        # kept vector lands in.
        target_slot = keep.long().cumsum(dim=1) - 1
        rows, cols = keep.nonzero(as_tuple=True)

        compacted = torch.zeros_like(X)
        compacted[rows, target_slot[rows, cols]] = X[rows, cols]
        return compacted

    def forward(self, X, X_length):
        """
        Gate (and possibly drop) the vectors in ``X``.

        Args:
            X: tensor of shape ``(batch, max_len, token_dim)``.
            X_length: 1-D integer tensor with the valid length of each
                sequence.

        Returns:
            ``((new_X, new_X_length), open_gates_loss)``. ``new_X`` holds
            the gated vectors, ``new_X_length`` the (possibly reduced)
            lengths, each at least 1. ``open_gates_loss`` has one entry per
            sequence.
        """
        # estimate log_alpha_i
        log_alpha = self.alpha_estimator(X, X_length)

        if self.training:
            g = self._sample_gates(log_alpha)
        else:
            g = compute_g_from_logalpha(log_alpha, self.epsilon)

        # apply gates to inputs
        X_orig = X
        if not self.dont_apply_gates:
            X = X * g.unsqueeze(2)

        if self.keep_dropped_vectors:
            # replace dropped vectors with dummy value
            dropped_ones = (g == 0)
            new_X = X.clone()
            new_X[dropped_ones] = self.dummy_value.to(new_X.device)
        elif self.dont_discard_vectors:
            new_X = X.clone()
        else:
            # only return open gates
            new_X = self._compact_open_gates(X, g)

        # mask out the tokens that are invalid anyways
        positions = torch.arange(g.size(1), device=X_length.device)
        mask = positions < X_length.unsqueeze(1)
        open_gates = (g > self.discard_epsilon) * mask
        num_open_gates = open_gates.sum(1)

        if self.keep_dropped_vectors or self.dont_discard_vectors:
            # length doesnt change (we just replaced with dummy value)
            new_X_length = X_length.clone()
        else:
            new_X_length = num_open_gates.detach().clone()

        # Probability of each gate being open, summed over valid positions.
        closed_prob = torch.sigmoid(
            self.temperature * math.log(self.epsilon / (1 + self.epsilon)) - log_alpha)
        open_gates_loss = ((1 - closed_prob) * mask).sum(1)

        if self.target_ratio > 0.:
            # Clamp the denominator so that empty sequences give a ratio of
            # 0 instead of 0 / 0 = NaN.
            ratio = open_gates_loss / X_length.clamp(min=1)
            target = torch.full_like(ratio, self.target_ratio)

            if self.target_mse:
                open_gates_loss = self.mseloss(ratio, target)
            else:
                open_gates_loss = torch.maximum(target, ratio)

        # if the length is zero, we have a problem. make minimum length one
        is_empty = new_X_length == 0
        new_X[is_empty, 0] = X_orig[is_empty, 0]
        new_X_length[is_empty] = 1

        if self.append_gates:
            new_X = torch.cat(
                [new_X, g.unsqueeze(-1), log_alpha.unsqueeze(-1)], dim=-1).contiguous()
        else:
            # Trim the sequence axis to the longest remaining sequence.
            new_X = new_X[:, :torch.max(new_X_length)]

        return (new_X, new_X_length), open_gates_loss


if __name__ == "__main__":
    # Smoke test: train an L0Drop layer on random data and print the
    # training loss, the number of dropped vectors and the eval loss.

    l0 = L0Drop(30, target_ratio=0.1)
    opt = torch.optim.Adam(l0.parameters(), lr=0.001)

    X = torch.randn((8, 5, 30))
    # Sample lengths in [1, 5]; a zero length would make the loss ratio
    # divide by zero.
    X_lengths = torch.randint(1, 6, (8,))

    # With target_ratio > 0 the loss term is bounded below by the target,
    # so "loss reaches 0" is not a usable stopping criterion on its own;
    # cap the number of steps instead.
    max_steps = 1000
    for _ in range(max_steps):
        l0.train()
        (X_new, X_lengths_new), l = l0(X, X_lengths)
        l_m = l.mean() + X_new.norm()
        loss = l_m.item()
        print("Train:", loss)

        opt.zero_grad()
        l_m.backward()
        opt.step()

        l0.eval()
        (X_new_eval, X_lengths_eval), l_eval = l0(X, X_lengths)
        print(X_lengths - X_lengths_eval)
        print("Eval:", l_eval.mean().item())

        # Also stops on a NaN loss, for which the comparison is False.
        if not loss > 0.:
            break
