import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Distribution(nn.Module):
    """Score a set of learned slots and select one slot per batch row.

    An MLP maps ``input + slot_emb`` to one scalar score per slot.  The
    scores are converted to a distribution over slots by either a masked
    softmax or a stick-breaking process, and ``forward`` returns a one-hot
    selection together with its forward and reverse cumulative sums.

    Args:
        input_size: Feature size of ``input`` and of each slot embedding.
        hidden_size: Width of the MLP hidden layer; also used to scale the
            scores (``1 / sqrt(hidden_size)``) in the softmax process.
        nslot: Number of slot embeddings (and of scores per batch row).
        dropout: Dropout probability applied before each linear layer.
        process: ``'softmax'`` or ``'stickbreaking'``.
        sample: If true, ``forward`` samples the slot from the predicted
            distribution while in training mode instead of taking argmax.

    Raises:
        ValueError: If ``process`` is not one of the supported names.
    """

    _PROCESSES = ('stickbreaking', 'softmax')

    def __init__(self, input_size, hidden_size, nslot, dropout, process='softmax', sample=False):
        super(Distribution, self).__init__()

        # Raise explicitly rather than `assert`: assertions are stripped
        # under `python -O`, which would let an invalid name through.
        if process not in self._PROCESSES:
            raise ValueError(
                "process must be one of {}, got {!r}".format(self._PROCESSES, process)
            )

        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
        )

        # One learned embedding per slot, initialised to zero.
        self.slot_emb = nn.Parameter(torch.zeros(nslot, input_size))

        self.hidden_size = hidden_size
        self.process_name = process
        self.sample = sample
        self.nslot = nslot

    def init_p(self, bsz, nslot):
        """Return an all-zero initial ``prev_p`` for ``forward``.

        The tensor takes its dtype and device from the module parameters
        and has shape ``(bsz, nslot)`` for the softmax process or
        ``(bsz, nslot + 1)`` for the stick-breaking process.

        NOTE(review): ``forward`` multiplies a mask of this width with
        scores of width ``self.nslot``; for stick-breaking the caller
        presumably passes ``nslot = self.nslot - 1`` here -- confirm.
        """
        reference = next(self.parameters())
        width = nslot + 1 if self.process_name == 'stickbreaking' else nslot
        # new_zeros keeps dtype/device of the parameter and does not track grad.
        return reference.new_zeros(bsz, width)

    @staticmethod
    def process_stickbreaking(beta, mask, nslot):
        """Turn per-slot break fractions into a stick-breaking distribution.

        Only the first ``nslot - 1`` masked fractions are used.  Slot ``i``
        receives ``beta_i * prod_{j < i} (1 - beta_j)`` and the last slot
        receives the remaining ``prod_j (1 - beta_j)``, so the result has
        ``nslot`` entries along the last dimension.
        """
        beta_masked = (beta * mask).narrow(-1, 0, nslot - 1)
        remaining = (1 - beta_masked).cumprod(-1)
        # Append a fraction of 1 for the final slot (it takes what is left)
        # and prepend a remaining length of 1 for the first slot.
        p = F.pad(beta_masked, (0, 1), value=1) * F.pad(remaining, (1, 0), value=1)
        return p

    @staticmethod
    def process_softmax(beta, mask=None):
        """Softmax over the last dimension with an optional multiplicative mask.

        The exponentiated scores are multiplied by ``mask`` (if given) and
        then L1-normalised along the last dimension.  Works for inputs of
        any rank; for 2-D input this is identical to normalising over
        ``dim=1``.

        NOTE: the subtracted maximum is taken over all entries, including
        those the mask later zeroes out.
        """
        # keepdim avoids hard-coding a 2-D layout when broadcasting the max.
        shifted = beta - beta.max(dim=-1, keepdim=True)[0]
        weights = torch.exp(shifted)
        if mask is not None:
            weights = weights * mask

        return F.normalize(weights, p=1, dim=-1)

    def forward(self, input, prev_p, ctrl_idx):
        """Predict a slot distribution and return a one-hot slot selection.

        Args:
            input: Tensor that broadcasts against ``slot_emb[None]`` of shape
                ``(1, nslot, input_size)``; its first dimension is the batch.
            prev_p: 2-D tensor ``(batch, slots)`` holding the previous
                selection (see ``init_p``).  Its cumulative sum, shifted left
                by one with the last column set to 1, masks the new scores.
            ctrl_idx: Optional integer index tensor with one slot index per
                batch row.  If ``None`` the index is sampled from the
                prediction (when ``sample`` is set and the module is in
                training mode) or taken as its argmax.

        Returns:
            Tuple ``(p, cp, rcp, p_predicted)`` where ``p`` is one-hot at
            ``ctrl_idx`` with the shape of ``prev_p``, ``cp`` is its
            cumulative sum along dim 1, ``rcp`` is its reverse cumulative
            sum along dim 1, and ``p_predicted`` is the predicted
            distribution over slots.
        """
        batch_size = input.size(0)

        # mask[:, i] = cumsum(prev_p)[:, i + 1]; the last column is always 1.
        prev_cp = torch.cumsum(prev_p, dim=1)
        mask = F.pad(prev_cp[:, 1:], (0, 1), value=1)

        beta = self.mlp(input + self.slot_emb[None, :, :]).squeeze(-1)
        if self.process_name == 'stickbreaking':
            beta = torch.sigmoid(beta)
            p_predicted = self.process_stickbreaking(beta, mask, self.nslot)
        else:
            # 'softmax' (the only other name accepted by __init__).
            beta = beta / math.sqrt(self.hidden_size)
            p_predicted = self.process_softmax(beta, mask)

        if ctrl_idx is None:
            if self.sample and self.training:
                ctrl_idx = torch.multinomial(p_predicted, 1).squeeze(-1)
            else:
                ctrl_idx = p_predicted.max(dim=-1)[1]

        p = torch.zeros_like(prev_p)
        # Build the row indices on p's device instead of always on the CPU.
        rows = torch.arange(batch_size, device=p.device)
        p[rows, ctrl_idx] = 1

        # NOTE: p is a hard one-hot and carries no gradient to p_predicted;
        # a straight-through variant, (p - p_predicted).detach() + p_predicted,
        # is intentionally not applied here.

        cp = torch.cumsum(p, dim=1)
        rcp = torch.cumsum(p.flip([1]), dim=1).flip([1])

        return p, cp, rcp, p_predicted