import torch
import torch.nn as nn 
# ---------- Module implementations. ----------

# Hyper-parameters for the ``__main__`` demo at the bottom of this file.
# "num_layers" is passed to MultiGenPool as its number of parallel pools.
att_cfg = dict(
    num_layers=1,
    num_heads=2,
    hidden_dim=768,
    dropout=0.025,
    activation="gelu",
)

# Large *finite* stand-in for infinity, used to mask pre-softmax scores. It is
# exactly representable in float16 (whose max is 65504), and exp(-INF) underflows
# to 0, so masked positions get zero softmax weight whenever at least one
# position in the sequence is left unmasked.
INF = 32752


class MultiGenPool(nn.Module):
    """
    Apply multiple pooling layers and concatenate the output.

    Idea is that the model will learn to pool different things and generate
    better embeddings out of the sequence. Every pool receives the *same*
    input sequence; their pooled vectors are concatenated on the last axis.

    Args:
        n_pools: Number of independent GenPool layers.
        d_input: Feature size of the input sequence (forwarded to each GenPool).
        d_attn: Attention MLP hidden size (forwarded to each GenPool).
        n_heads: Number of attention heads (forwarded to each GenPool).
        dropout_prob: Dropout probability (forwarded to each GenPool).
        activation_name: Activation name (forwarded to each GenPool).
    """

    def __init__(
            self, n_pools: int, d_input: int, d_attn: int, n_heads: int, dropout_prob: float, activation_name: str):
        super().__init__()
        self.pools = nn.ModuleList(
            GenPool(d_input, d_attn, n_heads, dropout_prob, activation_name)
            for _ in range(n_pools))

    def forward(self, features, mask):
        """
        Pool the sequence with every pool and concatenate the results.

        Args:
            features: Input sequence features, shape (batch_size, seq_len, d_input).
            mask: Boolean mask, shape (batch_size, seq_len); True marks positions
                excluded from pooling.

        Returns:
            Concatenated pooled features, shape (batch_size, n_pools * d_input).
        """
        # Each pool must see the original sequence. Feeding one pool's output
        # into the next would hand it an already-pooled (batch, d_input) tensor.
        pooled_per_pool = [pool(features, mask) for pool in self.pools]
        return torch.cat(pooled_per_pool, dim=-1)


class GenPool(nn.Module):
    """
    Generalized pooling from 'Enhancing Sentence Embedding with Generalized Pooling.'

    At every sequence position, each of the ``n_heads`` heads runs a two-layer
    MLP (``d_input -> d_attn // n_heads -> d_input // n_heads``) over the full
    input feature vector. A softmax over the sequence axis turns those outputs
    into per-feature pooling weights. Head ``h`` supplies the weights for feature
    columns ``[h * d_input // n_heads, (h + 1) * d_input // n_heads)``. The
    weighted features are then summed over the sequence.

    Args:
        d_input: Feature size of the input sequence and of the pooled output.
        d_attn: Total hidden size of the attention MLPs across all heads;
            ``0`` means "same as ``d_input``".
        n_heads: Number of heads; must divide both ``d_attn`` and ``d_input``.
        dropout_prob: Dropout probability applied after each MLP layer and to
            the softmax weights.
        activation_name: Activation between the two MLP layers, resolved by
            ``make_activation_module``.

    Raises:
        ValueError: If ``d_attn`` or ``d_input`` is not divisible by ``n_heads``.
    """

    def __init__(
            self, d_input: int, d_attn: int, n_heads: int, dropout_prob: float, activation_name: str):
        super().__init__()

        if d_attn == 0:
            d_attn = d_input
        if d_attn % n_heads != 0:
            raise ValueError(
                f"attention pooling dim {d_attn} not divisible by {n_heads} heads")
        # forward() reassembles the per-head weight slices into d_input columns,
        # so the input dim has to split evenly over the heads as well.
        if d_input % n_heads != 0:
            raise ValueError(
                f"input dim {d_input} not divisible by {n_heads} heads")
        self.d_head = d_attn // n_heads
        self.d_head_output = d_input // n_heads
        self.num_heads = n_heads

        # Per-head weights stacked along a leading head axis, so all heads run
        # in one batched matmul. Values are filled in by reset_parameters().
        self.genpool_w1_head = nn.Parameter(
            torch.empty(n_heads, d_input, self.d_head), requires_grad=True)
        self.genpool_b1_head = nn.Parameter(
            torch.empty(n_heads, self.d_head), requires_grad=True)
        self.genpool_w2_head = nn.Parameter(
            torch.empty(n_heads, self.d_head, self.d_head_output), requires_grad=True)
        self.genpool_b2_head = nn.Parameter(
            torch.empty(n_heads, self.d_head_output), requires_grad=True)
        self.reset_parameters()

        self.activation = make_activation_module(activation_name)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.dropout2 = nn.Dropout(dropout_prob)
        self.dropout3 = nn.Dropout(dropout_prob)
        # dim 2 is the sequence axis of the (batch, heads, seq, d) score tensor.
        self.softmax = nn.Softmax(dim=2)
        self.softmax_temp = 1

        # Not used by this class. It is a Parameter, so it appears in state_dicts;
        # kept so existing checkpoints still load. NOTE(review): confirm nothing
        # external relies on it before removing.
        self.genpool_one = nn.Parameter(torch.ones(1), requires_grad=False)

    def reset_parameters(self) -> None:
        """
        Initialize the pooling MLP weights.

        Each head's weight matrix gets Xavier/Glorot-uniform init and both bias
        vectors are zeroed. All-zero weights must be avoided: every supported
        activation maps 0 to 0, so the hidden activations would be zero. The
        gradients of both weight tensors and of the first bias would then vanish,
        only the output bias could learn, and the scores would stay constant over
        the sequence, i.e. permanent mean pooling.
        """
        with torch.no_grad():
            for weight in (self.genpool_w1_head, self.genpool_w2_head):
                # Init per head: xavier_uniform_ on the stacked 3-D tensor would
                # compute its fans from the wrong axes.
                for head in range(weight.size(0)):
                    nn.init.xavier_uniform_(weight[head])
            self.genpool_b1_head.zero_()
            self.genpool_b2_head.zero_()

    def extra_repr(self) -> str:
        strs = []
        for p in [self.genpool_w1_head, self.genpool_b1_head,
                  self.genpool_w2_head, self.genpool_b2_head]:
            strs.append(f"pool linear {p.shape}")
        return "\n".join(strs)

    def forward(self, features: torch.FloatTensor, mask: torch.BoolTensor):
        """
        Pool each sequence in the batch into a single feature vector.

        Args:
            features: Input features, shape (batch_size, seq_len, d_input).
            mask: Boolean mask, shape (batch_size, seq_len); True marks positions
                that are excluded from pooling.

        Returns:
            Pooled features, shape (batch_size, d_input).
        """
        _batch_size, seq_len, input_dim = features.shape

        # First per-head FC:
        # (b, 1, t, d_input) @ (1, heads, d_input, d_head) -> (b, heads, t, d_head),
        # plus bias broadcast as (1, heads, 1, d_head).
        hidden = torch.matmul(features.unsqueeze(1), self.genpool_w1_head.unsqueeze(0))
        hidden += self.genpool_b1_head.unsqueeze(1).unsqueeze(0)
        hidden = self.activation(self.dropout1(hidden))

        # Second per-head FC -> (b, heads, t, d_head_output).
        scores = torch.matmul(hidden, self.genpool_w2_head.unsqueeze(0))
        scores += self.genpool_b2_head.unsqueeze(1).unsqueeze(0)
        scores = self.dropout2(scores)

        # Give masked positions a huge negative score so the softmax assigns them
        # (numerically) zero weight; the mask broadcasts as (b, 1, t, 1).
        scores.masked_fill_(mask.unsqueeze(1).unsqueeze(-1), -INF)

        # Softmax over the sequence, independently per head and feature.
        weights = self.softmax(scores / self.softmax_temp)
        weights = self.dropout3(weights)

        # (b, heads, t, d_head_output) -> (b, t, heads * d_head_output == d_input):
        # head h provides the weights for its contiguous slice of feature columns.
        weights = weights.transpose(1, 2).reshape(-1, seq_len, input_dim)

        # Weighted sum over the sequence axis.
        return (features * weights).sum(dim=1)

def make_activation_module(name: str) -> nn.Module:
    """
    Build a fresh activation module from its name.

    Args:
        name: One of ``'none'`` (identity), ``'relu'`` or ``'gelu'``.

    Returns:
        A new activation module instance.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    supported = (
        ('none', nn.Identity),
        ('relu', nn.ReLU),
        ('gelu', nn.GELU),
    )
    for candidate, factory in supported:
        if name == candidate:
            return factory()
    raise ValueError(f"{name} not found")

if __name__ == "__main__":
    # Smoke test: pool a random batch of 2 sequences (200 steps, 360 features)
    # with nothing masked, then print the pooled shape.
    demo_features = torch.rand((2, 200, 360))
    demo_mask = torch.zeros((2, 200), dtype=torch.bool)

    pooler = MultiGenPool(
        att_cfg['num_layers'],
        360,
        att_cfg['hidden_dim'],
        att_cfg['num_heads'],
        att_cfg['dropout'],
        att_cfg['activation'],
    )

    pooled = pooler(demo_features, demo_mask)
    print(pooled.shape)  # (batch, num_layers * inp_dim)