import math
import random

import torch


def make_mask(bsize, max_lens, lens):
    """Return a boolean mask marking the positions that lie inside each length.

    Args:
        bsize: number of rows the position grid is expanded to.
        max_lens: number of positions (columns) in the mask.
        lens: tensor of per-row lengths; the mask is created on its device.

    Returns:
        Boolean tensor of shape [batch_size, max_lens] where entry (b, p) is
        True exactly when ``p < lens[b]``.
    """
    positions = torch.arange(max_lens, device=lens.device)
    position_grid = positions.unsqueeze(0).expand(bsize, -1)
    row_limits = lens.unsqueeze(1)

    # [batch_size, max_lens]; a position is "on" while it is below the row's length
    return position_grid < row_limits


class PositionalEncoder(torch.nn.Module):
    """Combine inputs with a fixed table of sinusoidal position features.

    A table of shape [1, max_seq_len, d_model] is computed once in the
    constructor and stored as the buffer ``pe``. Even columns hold sines and
    odd columns hold cosines. In ``forward`` the table rows are either added
    to the input or, when ``learned_embeddings`` is set, concatenated with the
    input and projected back to ``d_model`` by a linear layer.
    """

    def __init__(self, d_model, max_seq_len=1000, mul_by_sqrt=True, learned_embeddings=False):
        """Precompute the position table and, optionally, the mixing layer.

        Args:
            d_model: size of the feature axis of the inputs.
            max_seq_len: number of rows (positions) in the table.
            mul_by_sqrt: if True, ``forward`` scales the input by
                ``sqrt(d_model)`` before combining it with the table.
            learned_embeddings: if True, a ``Linear(2 * d_model, d_model)``
                layer merges input and position features instead of a sum.
        """
        super().__init__()
        self.d_model = d_model
        self.mul_by_sqrt = mul_by_sqrt
        self.learned_embeddings = learned_embeddings

        table = torch.zeros(max_seq_len, d_model)
        positions = range(max_seq_len)
        # NOTE(review): the exponent is 2 * column / d_model with `column`
        # already being the feature index, and the cosine column uses its own
        # index rather than sharing the sine's frequency. This differs from
        # the schedule in the original Transformer paper -- confirm intended.
        for sin_col in range(0, d_model, 2):
            sin_scale = 10000 ** ((2 * sin_col) / d_model)
            table[:, sin_col] = torch.tensor(
                [math.sin(position / sin_scale) for position in positions])

            cos_col = sin_col + 1
            if cos_col < d_model:
                cos_scale = 10000 ** ((2 * cos_col) / d_model)
                table[:, cos_col] = torch.tensor(
                    [math.cos(position / cos_scale) for position in positions])

        # Leading axis of size 1 so the table can be expanded over the batch.
        self.register_buffer('pe', table.unsqueeze(0))

        if learned_embeddings:
            self.embedding_learner = torch.nn.Linear(2 * d_model, d_model)

    def forward(self, x, t=None):
        """Apply the position features to ``x``.

        Args:
            x: input tensor; axis 1 is read as the sequence axis and the
                position features are expanded to ``x``'s shape.
            t: optional index into the position axis of the table. When
                given, that single row is used for every step of ``x``;
                otherwise rows ``1 .. seq_len`` are used.

        Returns:
            ``x`` (optionally scaled) plus the position features, or the
            output of the learned mixing layer when ``learned_embeddings``
            is enabled.
        """
        scaled = x * math.sqrt(self.d_model) if self.mul_by_sqrt else x
        seq_len = scaled.size(1)

        if t is None:
            # Row 0 of the table is skipped here, so at most max_seq_len - 1
            # steps can be encoded. NOTE(review): confirm the offset is wanted.
            pos_features = self.pe[:, 1:seq_len + 1]
        else:
            pos_features = self.pe[:, t].unsqueeze(1)
        pos_features = pos_features.expand_as(scaled)

        if not self.learned_embeddings:
            return scaled + pos_features

        stacked = torch.cat([scaled, pos_features], dim=-1)
        return self.embedding_learner(stacked)


def remove_vectors(rate, encoded):
    """Randomly drop vectors from each padded sequence and re-pack the rest.

    Every valid position (index below the row's length) is kept independently
    with probability ``1 - rate``. The surviving vectors of a row are moved to
    the front of that row in their original order and the remainder of the row
    is zero-filled. A row that would lose all of its vectors keeps its first
    one instead.

    Args:
        rate: probability of dropping each valid vector.
        encoded: tuple ``(emb, lens)``. ``emb`` is indexed as
            ``[batch, position, ...]``; ``lens`` holds the number of valid
            positions per row and is moved to ``emb``'s device.

    Returns:
        Tuple ``(new_emb, new_lens)``: ``new_emb`` has the same shape and
        dtype as ``emb``; ``new_lens`` is the number of kept vectors per row.

    Raises:
        AssertionError: if a row has no valid first position (e.g. length 0),
            so that not even one vector can be kept.
    """
    emb, lens = encoded
    lens = lens.to(emb.device)

    keep_probs = torch.ones(emb.size(0), emb.size(1), device=emb.device) * \
        (1. - rate)
    keep = torch.bernoulli(keep_probs) > 0

    # Padding positions can never be kept.
    positions = torch.arange(emb.size(1), device=lens.device)
    valid = positions < lens.unsqueeze(1)
    keep = keep & valid

    emptied_rows = keep.sum(1) == 0
    if emptied_rows.any():
        # Make sure at least one gate is open: fall back to the first vector.
        keep[emptied_rows, 0] = True
        keep = keep & valid

    num_kept = keep.sum(1)
    assert not (num_kept == 0).any(), \
        "remove_vectors: a sequence has no valid vector to keep"

    # Row b receives its survivors in slots 0 .. num_kept[b] - 1. Boolean-mask
    # indexing walks both masks in row-major order and each row has the same
    # number of True entries in both, so this is an order-preserving
    # compaction done in a single gather/scatter.
    packed_slots = positions < num_kept.unsqueeze(1)
    new_emb = torch.zeros_like(emb)
    new_emb[packed_slots] = emb[keep]

    return new_emb, num_kept


def add_vectors(rate, encoded):
    """Randomly insert vectors borrowed from other sequences of the batch.

    After each valid position ``j`` of a row, with probability ``rate``, the
    vector at position ``j`` of a randomly chosen *other* row is inserted.
    Only rows that have a real (non-padding) vector at position ``j`` can
    donate; if no such row exists, nothing is inserted at that position.

    Args:
        rate: probability of inserting a vector after each valid position.
        encoded: tuple ``(emb, lens)``. ``emb`` is indexed as
            ``[batch, position, ...]``; ``lens`` holds the number of valid
            positions per row and is moved to ``emb``'s device.

    Returns:
        Tuple ``(new_emb, new_lens)``: ``new_emb`` has shape
        ``[batch, max(new_lens), ...]`` with ``emb``'s dtype and device and is
        zero-padded past each row's new length; ``new_lens`` is the original
        length plus the number of inserted vectors per row.
    """
    emb, lens = encoded
    lens = lens.to(emb.device)
    batch_size, max_len = emb.size(0), emb.size(1)

    add_probs = torch.ones(batch_size, max_len, device=emb.device) * rate
    gates = torch.bernoulli(add_probs) > 0

    # Mask out positions beyond each row's length.
    positions = torch.arange(max_len, device=lens.device)
    valid = positions < lens.unsqueeze(1)

    # A donor must be a different row with a real vector at the same position,
    # so at least two rows have to cover that position. Gates without a donor
    # are dropped up front, which keeps new_lens consistent with what is
    # actually inserted below.
    has_donor = valid.sum(0) >= 2
    gates = gates & valid & has_donor

    new_lens = lens + gates.sum(1)
    new_max_length = int(new_lens.max()) if new_lens.numel() > 0 else 0
    new_emb = emb.new_zeros(
        (batch_size, new_max_length) + tuple(emb.shape[2:]))

    # Move the small control tensors to the host once instead of reading
    # individual tensor elements inside the loops.
    lens_list = lens.tolist()
    gate_rows = gates.tolist()

    for i in range(batch_size):
        vector_index = 0

        for j in range(lens_list[i]):
            new_emb[i, vector_index] = emb[i, j]
            vector_index += 1

            if gate_rows[i][j]:
                # Strictly longer than j: a row of length j only has padding
                # at position j.
                donors = [
                    row for row, length in enumerate(lens_list)
                    if row != i and length > j
                ]
                new_emb[i, vector_index] = emb[random.choice(donors), j]
                vector_index += 1

    return new_emb, new_lens
