import torch
from torch import nn
from torch.distributions import Bernoulli
from torch.nn.modules.transformer import TransformerDecoderLayer
from transformers.modeling_bert import gelu_new as gelu_bert
from transformers.modeling_transfo_xl import RelPartialLearnableMultiHeadAttn

from page.const import *
from page.config import ModelConfig


class MultiheadAttentionWeights(nn.Module):
    """
    Compute masked, scaled dot-product attention scores (before softmax) for multi-head attention.

    Only the query and key projections live here; value projection and the weighted sum are done by the
    caller (e.g. ``MultiheadAttention``).

    Recognized config keys: ``hidden_dim`` (default 768) and ``num_heads`` (default 12). The config dict
    is kept as ``self.config``.
    """

    def __init__(self, **config):
        super().__init__()
        self.config = config

        hidden, heads = self.hidden_dim, self.num_heads
        assert hidden % heads == 0, \
            "Hidden dimension %s is not divisible by the number of heads %s." % (hidden, heads)

        # Creation order (q, then k) is kept so parameter initialization consumes the RNG identically.
        self.linear_q = nn.Linear(hidden, hidden)
        self.linear_k = nn.Linear(hidden, hidden)

        self.dim_head = hidden // heads
        self.sqrt_dim = self.dim_head ** 0.5

    def forward(self, query: torch.Tensor, key: torch.Tensor = None, key_ignorance_mask: torch.Tensor = None,
                attention_mask: torch.Tensor = None, relative_pos: torch.Tensor = None,
                head_at_last: bool = True) -> torch.Tensor:
        """
        :param query: [B, S, H] (B may be 1 and is then broadcast to the key's batch size).
        :param key: [B, T, H] or [1, T, H]; defaults to ``query`` (self-attention).
        :param key_ignorance_mask: bool [B, T], True for keys to be masked.
        :param attention_mask: bool [S, T] shared by the batch, or bool [B, S, T] per example; True = masked.
        :param relative_pos: accepted for interface compatibility; not used.
        :param head_at_last: If True return [B, S, T, N], otherwise [B, N, S, T].
        :return: scaled scores with masked positions filled with NEG_INF.
        """
        key = query if key is None else key

        # A 3-D mask is per-example; keep it apart from the 2-D (batch-shared) mask.
        batched_attention_mask = None
        if attention_mask is not None and attention_mask.dim() == 3:
            batched_attention_mask, attention_mask = attention_mask, None

        assert query.shape[0] == key.shape[0] or key.shape[0] == 1 or query.shape[0] == 1
        assert key_ignorance_mask is None or (key.shape[:2] == key_ignorance_mask.shape and
                                              key_ignorance_mask.dtype == torch.bool)
        assert attention_mask is None or (query.shape[1] == attention_mask.shape[0] and
                                          key.shape[1] == attention_mask.shape[1] and
                                          attention_mask.dtype == torch.bool)
        assert batched_attention_mask is None or (query.shape[:2] == batched_attention_mask.shape[:2] and
                                                  key.shape[1] == batched_attention_mask.shape[2] and
                                                  batched_attention_mask.dtype == torch.bool)

        query_len, key_len = query.shape[1], key.shape[1]
        batch_size = max(key.shape[0], query.shape[0])
        heads, width = self.num_heads, self.dim_head

        # Project both sides; the query is additionally scaled by 1/sqrt(head width).
        query = self.linear_q(query) / self.sqrt_dim
        key = self.linear_k(key)

        # Broadcast a singleton batch dimension up to the full batch size.
        if query.shape[0] == 1:
            query = query.expand(batch_size, -1, -1)
        if key.shape[0] == 1:
            key = key.expand(batch_size, -1, -1)

        # query: [B, S, H] -> [B, S, N, H/N] -> [B, N, S, H/N] -> [BN, S, H/N]
        query_heads = query.view(batch_size, query_len, heads, width).transpose(1, 2).flatten(0, 1).contiguous()
        # key: [B, T, H] -> [B, T, N, H/N] -> [B, N, H/N, T] -> [BN, H/N, T]
        key_heads = key.view(batch_size, key_len, heads, width).permute(0, 2, 3, 1).flatten(0, 1).contiguous()

        # [BN, S, T] -> [B, N, S, T]
        scores = torch.bmm(query_heads, key_heads).view(batch_size, heads, query_len, key_len).contiguous()

        # NOTE(review): the original author noted masks "should be applied after GELU for output weights";
        # no GELU happens in this module, so any such handling is up to callers — confirm.
        if attention_mask is not None:
            # [S, T] broadcasts over [B, N, S, T].
            scores.masked_fill_(attention_mask, NEG_INF)
        if key_ignorance_mask is not None:
            # [B, T] -> [B, 1, 1, T]
            scores.masked_fill_(key_ignorance_mask[:, None, None, :], NEG_INF)
        if batched_attention_mask is not None:
            # [B, S, T] -> [B, 1, S, T]
            scores.masked_fill_(batched_attention_mask[:, None, :, :], NEG_INF)

        # [B, N, S, T] -> [B, S, T, N] when the head axis is requested last.
        return scores.permute(0, 2, 3, 1).contiguous() if head_at_last else scores

    @property
    def hidden_dim(self) -> int:
        return self.config.get('hidden_dim', 768)

    @property
    def num_heads(self) -> int:
        return self.config.get('num_heads', 12)


class MultiheadAttention(nn.Module):
    """
    Multi-head attention: scores from MultiheadAttentionWeights, softmax, dropout, value projection,
    weighted sum and output projection. No residual connection or normalization is applied here.

    All config keys are forwarded to MultiheadAttentionWeights; ``dropout`` (default 0.0) is the dropout
    rate applied to the attention probabilities.
    """

    def __init__(self, **config):
        super().__init__()
        self.attn = MultiheadAttentionWeights(**config)
        self.dropout_attn = nn.Dropout(self.dropout_p)
        width = self.attn.hidden_dim
        self.linear_v = nn.Linear(width, width)
        self.linear_out = nn.Linear(width, width)

    def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, key_ignorance_mask: torch.Tensor = None,
                attention_mask: torch.Tensor = None, return_weights: bool = False, **kwargs):
        """
        :param query: [B, S, H] (or [1, S, H]).
        :param key_value: [B, T, H] or [1, T, H], used for both keys and values; defaults to ``query``.
        :param key_ignorance_mask: bool [B, T], True for keys to be masked.
        :param attention_mask: bool [S, T] or [B, S, T], True for masked pairs.
        :param return_weights: If True, also return the masked pre-softmax scores as [B, S, T, N].
        :param kwargs: ignored.
        :return: [B, S, H] output, or ``(output, scores)`` when return_weights.
        """
        if key_value is None:
            key_value = query

        # Pre-softmax scores, heads kept on axis 1: [B, N, S, T].
        scores = self.attn(query=query, key=key_value, key_ignorance_mask=key_ignorance_mask,
                           attention_mask=attention_mask, head_at_last=False)
        batch_size, _, query_len, key_len = scores.shape
        heads, width = self.attn.num_heads, self.attn.dim_head

        # Softmax over keys, then dropout (applied after softmax as in the original Transformer).
        # Rows whose keys are all masked may turn into NaN (presumably when NEG_INF is -inf — confirm);
        # those entries are replaced with 0. Then [B, N, S, T] -> [BN, S, T].
        probs = self.dropout_attn(scores.softmax(dim=-1))
        probs = probs.masked_fill(torch.isnan(probs), 0.0).view(-1, query_len, key_len)

        # Values: [1 or B, T, H] -> [1 or B, T, N, H/N] -> [1 or B, N, T, H/N], broadcast to B if needed.
        value_batch = key_value.shape[0]
        value = self.linear_v(key_value).view(value_batch, key_len, heads, width).transpose(1, 2)
        if value_batch == 1:
            value = value.expand(batch_size, -1, -1, -1)
        value = value.flatten(0, 1).contiguous()  # [BN, T, H/N]

        # Weighted sum: [BN, S, H/N] -> [B, N, S, H/N] -> [B, S, N, H/N] -> [B, S, H], then project.
        context = torch.bmm(probs, value).view(batch_size, heads, query_len, width)
        output = self.linear_out(context.transpose(1, 2).flatten(2, 3).contiguous())

        if not return_weights:
            return output
        return output, scores.permute(0, 2, 3, 1).contiguous()

    @property
    def dropout_p(self):
        return self.attn.config.get('dropout', 0.0)


class WrappedMultiheadAttention(nn.MultiheadAttention):
    """
    ``torch.nn.MultiheadAttention`` exposed through the same call signature as the other attention modules
    in this file (``query``, ``key_value``, ``key_ignorance_mask``, ``attention_mask``, ``return_weights``).

    :param batch_first: If True (default), ``query``/``key_value`` and the output are [B, S, H];
        otherwise they are [S, B, H].
    :param config: ``hidden_dim`` (default 768), ``num_heads`` (default 12), ``dropout`` (default 0).
        The dict is stored as ``self.config``.
    """

    def __init__(self, batch_first=True, **config):
        super().__init__(embed_dim=config.get('hidden_dim', 768),
                         num_heads=config.get('num_heads', 12),
                         dropout=config.get('dropout', 0),
                         bias=True, add_bias_kv=False)

        # Newer PyTorch releases define `batch_first` on nn.MultiheadAttention (False after the call above)
        # and, when it is True, transpose inputs/outputs inside their own forward. Older releases have no such
        # attribute. Record which case applies *before* overwriting the attribute, so forward() never
        # transposes twice (which would make the parent attend across the batch instead of the sequence).
        self._native_batch_first = hasattr(self, 'batch_first')

        self.config = config
        self.batch_first = batch_first

    def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, key_ignorance_mask: torch.Tensor = None,
                attention_mask: torch.Tensor = None, return_weights: bool = False, **kwargs):
        """
        :param query: [B, S, H], or [S, B, H] when ``batch_first`` is False.
        :param key_value: tensor used as both key and value; defaults to ``query`` (self-attention).
        :param key_ignorance_mask: bool [B, T], True for keys to ignore; passed as ``key_padding_mask``.
        :param attention_mask: bool mask, True where attention is blocked; converted to a float mask
            holding NEG_INF at those positions and 0 elsewhere.
        :param return_weights: If True, also return the attention weights computed by PyTorch.
        :param kwargs: ignored.
        :return: output in the same layout as ``query``, or ``(output, weights)`` when return_weights.
        """
        key = key_value if key_value is not None else query

        if attention_mask is not None:
            # Target attention mask is a bool tensor, but this call passes PyTorch an additive float mask.
            attention_mask = torch.zeros_like(attention_mask, dtype=torch.float) \
                .masked_fill_(attention_mask, NEG_INF)

        if self.batch_first and not self._native_batch_first:
            # This PyTorch only understands [S, B, H]: transpose in and out here.
            query_batch_second = query.transpose(0, 1)
            key_batch_second = key.transpose(0, 1)
            result = super().forward(query=query_batch_second, key=key_batch_second, value=key_batch_second,
                                     key_padding_mask=key_ignorance_mask,
                                     attn_mask=attention_mask, need_weights=return_weights)
            result = result[0].transpose(0, 1), result[1]
        else:
            # Either the tensors are already [S, B, H], or PyTorch honors `self.batch_first` by itself.
            result = super().forward(query=query, key=key, value=key, key_padding_mask=key_ignorance_mask,
                                     attn_mask=attention_mask, need_weights=return_weights)

        return result if return_weights else result[0]


class RelativeMultiheadAttention(nn.Module):
    """
    Self-attention sub-layer with Transformer-XL style relative-position terms.

    Besides the attention itself, the module owns the output projection, output dropout, the residual
    connection and a LayerNorm, so it returns the finished sub-layer output.

    :param pre_layernorm: If True, LayerNorm is applied to the attention input and the residual sum is
        returned un-normalized (pre-LN). If False (default), LayerNorm is applied after the residual
        connection (post-LN).
    :param config: ``hidden_dim`` (768), ``intermediate_dim`` (3072; not used by this module),
        ``num_heads`` (12), ``dropout`` (0.0) and ``layernorm_eps`` (1e-12).
    """

    def __init__(self, pre_layernorm=False, **config):
        super().__init__()
        self.config = config

        assert self.hidden_dim % self.num_heads == 0, \
            "Hidden dimension %s is not divisible by the number of heads %s." % (self.hidden_dim, self.num_heads)

        self.dim_head = self.hidden_dim // self.num_heads
        self.scale_factor = self.dim_head ** -0.5
        self.pre_layernorm = pre_layernorm

        self.linear_qkv = nn.Linear(self.hidden_dim, 3 * self.hidden_dim, bias=False)
        self.linear_out = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

        self.norm = nn.LayerNorm(self.hidden_dim, eps=self.layernorm_eps)
        self.dropout_out = nn.Dropout(self.dropout_p)
        self.dropout_attn = nn.Dropout(self.dropout_p)

        self.bias_key = nn.Parameter(torch.zeros(self.num_heads, self.dim_head), requires_grad=True)
        self.bias_pos = nn.Parameter(torch.zeros(self.num_heads, self.dim_head), requires_grad=True)
        self.linear_pos = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    @property
    def hidden_dim(self):
        return self.config.get('hidden_dim', 768)

    @property
    def intermediate_dim(self):
        return self.config.get('intermediate_dim', 3072)

    @property
    def num_heads(self) -> int:
        return self.config.get('num_heads', 12)

    @property
    def dropout_p(self):
        return self.config.get('dropout', 0.0)

    @property
    def layernorm_eps(self) -> float:
        """
        :rtype: float
        :return: Epsilon to avoid zero-division in LayerNorm.
        """
        return self.config.get('layernorm_eps', 1E-12)

    def _rel_shift(self, x):
        """
        Transformer-XL relative shift applied to the last two axes.

        Prepends a zero column, reinterprets each [Q, K+1] matrix as [K+1, Q], drops its first row and
        reinterprets the remainder as [Q, K].

        :param x: [B, N, Q, K] position scores.
        :return: shifted tensor with the same shape as ``x``.
        """
        batch_sz, num_heads, query_len, key_len = x.shape
        x_padded = torch.cat([torch.zeros(batch_sz, num_heads, query_len, 1, device=x.device, dtype=x.dtype), x],
                             dim=-1)
        x_padded = x_padded.view(batch_sz, num_heads, key_len + 1, query_len)
        return x_padded[:, :, 1:, :].view_as(x)

    def forward(self, query: torch.Tensor, key_ignorance_mask: torch.Tensor = None, attention_mask: torch.Tensor = None,
                key_position: torch.Tensor = None, return_weights: bool = False, **kwargs):
        """
        :param query: [B, T, H] input; queries, keys and values are all projected from it.
        :param key_ignorance_mask: bool [B, T], True for keys to be masked.
        :param attention_mask: bool mask broadcastable to [B, N, T, T] (e.g. [T, T]), True for blocked pairs.
        :param key_position: [T, H] position embeddings; required despite the ``None`` default.
        :param return_weights: If True, also return the masked pre-softmax scores as [B, T, T, N].
        :param kwargs: ignored.
        :return: [B, T, H] sub-layer output, or ``(output, scores)`` when return_weights.
        :raises ValueError: if ``key_position`` is not given.
        """
        if key_position is None:
            raise ValueError("RelativeMultiheadAttention requires `key_position`.")

        batch_sz, query_sz = query.shape[:2]
        keypos_sz = key_position.shape[0]

        assert query_sz == keypos_sz, "Query and Key-position should have the same length!"

        # The residual branch always carries the raw input; in pre-LN mode only the attention input is
        # normalized. Adding the normalized input instead would re-normalize the residual stream every layer.
        residual = query
        query_in = self.norm(query) if self.pre_layernorm else query

        query, key, value = self.linear_qkv(query_in).chunk(chunks=3, dim=-1)
        keypos = self.linear_pos(key_position)

        query = query.view(batch_sz, query_sz, self.num_heads, self.dim_head)  # [B, T, N, H/N]
        key = key.view(batch_sz, query_sz, self.num_heads, self.dim_head)  # [B, T, N, H/N]
        value = value.view(batch_sz, query_sz, self.num_heads, self.dim_head)  # [B, T, N, H/N]
        keypos = keypos.view(keypos_sz, self.num_heads, self.dim_head)  # [T, N, H/N]

        # Content term: (query + content bias) . key; position term: (query + position bias) . position.
        query_key_part = torch.einsum('bind,bjnd->bnij', query + self.bias_key, key)  # [B, N, T, T]
        query_pos_part = torch.einsum('bind,jnd->bnij', query + self.bias_pos, keypos)  # [B, N, T, T]
        query_pos_part = self._rel_shift(query_pos_part)

        attention_weights = (query_key_part + query_pos_part) * self.scale_factor  # [B, N, T, T]

        # Apply masks
        if attention_mask is not None:
            # Recap: attention mask has shape [S=T, T] and broadcast it.
            attention_weights.masked_fill_(attention_mask, NEG_INF)

        if key_ignorance_mask is not None:
            # Recap: ignorance mask has shape [B, T] -> [B, 1, 1, T] and apply it.
            attention_weights.masked_fill_(key_ignorance_mask.unsqueeze(1).unsqueeze(1), NEG_INF)

        # Retrieve shape arguments
        batch_size, _, query_len, key_len = attention_weights.shape

        # Softmax over keys, then dropout (applied after softmax as in the original Transformer).
        # Rows whose keys are all masked may turn into NaN (presumably when NEG_INF is -inf — confirm);
        # those entries are replaced with 0. Shape [B, N, T, T] -> [BN, T, T].
        attn = attention_weights.softmax(dim=-1)
        attn = self.dropout_attn(attn)
        attn = attn.masked_fill(torch.isnan(attn), 0.0).view(-1, query_len, key_len)

        # Transpose value matrix: [B, T, N, H/N] -> [B, N, T, H/N] -> [BN, T, H/N].
        value = value.transpose(1, 2).flatten(0, 1).contiguous()

        # Weighted sum: [BN, T, H/N] -> [B, N, T, H/N] -> [B, T, N, H/N] -> [B, T, H].
        output = torch.bmm(attn, value) \
            .view(batch_size, self.num_heads, query_len, self.dim_head) \
            .transpose(1, 2).flatten(2, 3).contiguous()

        # Output projection, dropout and residual connection. [B, T, H].
        output = self.dropout_out(self.linear_out(output)) + residual

        if not self.pre_layernorm:
            output = self.norm(output)

        if return_weights:
            return output, attention_weights.permute(0, 2, 3, 1).contiguous()
        else:
            return output


class WrappedTransformerLayer(TransformerDecoderLayer):
    """
    ``TransformerDecoderLayer`` configured from a ModelConfig.

    Differences from the stock layer: the activation is replaced with ``gelu_bert``, all three LayerNorms
    use ``config.layernorm_eps``, the target attention mask is given as a bool tensor, and tensors may be
    batch-first.

    :param config: supplies hidden_dim, num_decoder_heads, intermediate_dim, dropout_layer, layernorm_eps.
    :param batch_first: If True (default), inputs/outputs are [B, S, H]; otherwise [S, B, H].
    """

    def __init__(self, config: ModelConfig, batch_first=True):
        super().__init__(d_model=config.hidden_dim,
                         nhead=config.num_decoder_heads,
                         dim_feedforward=config.intermediate_dim,
                         dropout=config.dropout_layer,
                         activation='relu')

        self.batch_first = batch_first

        # Swap the built-in ReLU for GELU.
        self.activation = gelu_bert

        eps = config.layernorm_eps
        for norm in (self.norm1, self.norm2, self.norm3):
            norm.eps = eps

    def forward(self, target: torch.Tensor, memory: torch.Tensor, target_attention_mask: torch.Tensor = None,
                target_ignorance_mask: torch.Tensor = None, memory_ignorance_mask: torch.Tensor = None,
                **kwargs) -> torch.Tensor:
        """
        :param target: decoder input, [B, S, H] when batch_first else [S, B, H].
        :param memory: encoder output, same layout as ``target``.
        :param target_attention_mask: bool mask, True where attention is blocked (becomes NEG_INF).
        :param target_ignorance_mask: target key padding mask.
        :param memory_ignorance_mask: memory key padding mask.
        :param kwargs: ignored.
        :return: layer output in the same layout as ``target``.
        """
        if target_attention_mask is not None:
            # The parent receives an additive float mask: NEG_INF where blocked, 0 elsewhere.
            target_attention_mask = torch.zeros_like(target_attention_mask, dtype=torch.float) \
                .masked_fill_(target_attention_mask, NEG_INF)

        if not self.batch_first:
            return super().forward(tgt=target, memory=memory, tgt_mask=target_attention_mask,
                                   tgt_key_padding_mask=target_ignorance_mask,
                                   memory_key_padding_mask=memory_ignorance_mask)

        # The parent layer works on [S, B, H]: swap the first two axes on the way in and out.
        output = super().forward(tgt=target.transpose(0, 1), memory=memory.transpose(0, 1),
                                 tgt_mask=target_attention_mask,
                                 tgt_key_padding_mask=target_ignorance_mask,
                                 memory_key_padding_mask=memory_ignorance_mask)
        return output.transpose(0, 1)


class TransformerLayer(nn.Module):
    """
    Post-LN Transformer decoder layer built on MultiheadAttention.

    The layer has three sub-layers: self-attention, optional attention over ``memory``, and a GELU
    feed-forward block. Each sub-layer output goes through dropout, is added to its input, and is then
    passed through LayerNorm.

    :param config: supplies hidden_dim, num_decoder_heads, layernorm_eps, dropout_attn, dropout_layer
        and intermediate_dim.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        attention_config = dict(hidden_dim=config.hidden_dim, num_heads=config.num_decoder_heads,
                                layernorm_eps=config.layernorm_eps, dropout=config.dropout_attn)
        self.attn = MultiheadAttention(**attention_config)
        self.mem = MultiheadAttention(**attention_config)

        for name in ('dropout_attn', 'dropout_mem', 'dropout_expand', 'dropout_out'):
            setattr(self, name, nn.Dropout(config.dropout_layer))

        self.lin_expand = nn.Linear(config.hidden_dim, config.intermediate_dim)
        self.lin_collapse = nn.Linear(config.intermediate_dim, config.hidden_dim)

        for name in ('norm_attn', 'norm_mem', 'norm_out'):
            setattr(self, name, nn.LayerNorm(config.hidden_dim, eps=config.layernorm_eps))

    def forward(self, target, target_ignorance_mask=None, target_attention_mask=None,
                memory=None, memory_ignorance_mask=None):
        """
        :param target: [B, S, H] decoder states.
        :param target_ignorance_mask: bool [B, S], True for target keys to be masked.
        :param target_attention_mask: bool [S, S] or [B, S, S], True for blocked pairs.
        :param memory: optional [B, T, H] states to attend over; skipped when None.
        :param memory_ignorance_mask: bool [B, T], True for memory keys to be masked.
        :return: [B, S, H] updated states.
        """
        attended = self.attn(query=target, attention_mask=target_attention_mask,
                             key_ignorance_mask=target_ignorance_mask)
        target = self.norm_attn(target + self.dropout_attn(attended))

        if memory is not None:
            attended = self.mem(query=target, key_value=memory, key_ignorance_mask=memory_ignorance_mask)
            target = self.norm_mem(target + self.dropout_mem(attended))

        hidden = gelu_bert(self.lin_expand(target))
        output = self.lin_collapse(self.dropout_expand(hidden))
        return self.norm_out(target + self.dropout_out(output))


class TransformerXLLayer(TransformerLayer):
    """
    TransformerLayer variant whose self-attention is RelativeMultiheadAttention and whose memory attention
    is WrappedMultiheadAttention.

    The inherited ``TransformerLayer.forward`` cannot supply the ``key_position`` that
    RelativeMultiheadAttention needs, so this class defines its own ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)

        # Override attention and memory layers
        self.attn = RelativeMultiheadAttention(hidden_dim=config.hidden_dim, num_heads=config.num_decoder_heads,
                                               layernorm_eps=config.layernorm_eps, dropout=config.dropout_attn)
        self.mem = WrappedMultiheadAttention(hidden_dim=config.hidden_dim, num_heads=config.num_decoder_heads,
                                             layernorm_eps=config.layernorm_eps, dropout=config.dropout_attn)

    def forward(self, target, target_ignorance_mask=None, target_attention_mask=None,
                memory=None, memory_ignorance_mask=None, key_position=None):
        """
        :param target: [B, T, H] decoder states.
        :param target_ignorance_mask: bool [B, T], True for target keys to be masked.
        :param target_attention_mask: bool mask broadcastable to [B, N, T, T], True for blocked pairs.
        :param memory: optional states to attend over; skipped when None.
        :param memory_ignorance_mask: bool mask for memory keys, True for keys to ignore.
        :param key_position: [T, H] position embeddings for the relative self-attention (required).
        :return: [B, T, H] updated states.
        :raises ValueError: if ``key_position`` is not given.
        """
        if key_position is None:
            raise ValueError("TransformerXLLayer requires `key_position` for its relative self-attention.")

        # RelativeMultiheadAttention already applies output dropout, the residual connection and LayerNorm,
        # so its result replaces `target` directly (the inherited norm_attn / dropout_attn are not used).
        target = self.attn(query=target, attention_mask=target_attention_mask,
                           key_ignorance_mask=target_ignorance_mask, key_position=key_position)

        if memory is not None:
            attended = self.mem(query=target, key_value=memory, key_ignorance_mask=memory_ignorance_mask)
            target = self.norm_mem(target + self.dropout_mem(attended))

        output = self.lin_collapse(self.dropout_expand(gelu_bert(self.lin_expand(target))))
        return self.norm_out(target + self.dropout_out(output))


# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'MultiheadAttentionWeights',
    'MultiheadAttention',
    'WrappedMultiheadAttention',
    'RelativeMultiheadAttention',
    'TransformerLayer',
    'WrappedTransformerLayer',
    'TransformerXLLayer',
]
