import copy
import math
from typing import Optional, Any

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F


class TransformerEncoder(nn.Module):
    r"""Sequential stack of identical encoder layers with an optional final norm.

    The stack is built by cloning ``encoder_layer`` ``num_layers`` times, so
    every layer has its own independent copy of the parameters.

    Args:
        encoder_layer: the layer module that is cloned to build the stack (required).
        num_layers: how many clones of ``encoder_layer`` to stack (required).
        norm: module applied to the output of the last layer; skipped when
            ``None`` (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Run ``src`` through every layer in order, then through ``norm`` if set.

        Args:
            src: the input sequence fed to the first layer (required).
            mask: forwarded unchanged to every layer as ``src_mask`` (optional).
            src_key_padding_mask: forwarded unchanged to every layer under the
                same keyword (optional).

        Returns:
            The output of the last layer, normalized by ``self.norm`` when one
            was supplied.
        """
        hidden = src

        # Each layer consumes the previous layer's output; masks are shared.
        for layer in self.layers:
            hidden = layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )

        if self.norm is not None:
            hidden = self.norm(hidden)

        return hidden

class TransformerEncoderLayer(nn.Module):
    r"""One post-norm encoder layer: self-attention followed by a feedforward block.

    Both sub-blocks are wrapped as ``norm(x + dropout(sublayer(x)))``. The
    design follows "Attention Is All You Need" (Vaswani et al., 2017, Advances
    in Neural Information Processing Systems, pages 6000-6010).

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of attention heads (required).
        dim_feedforward: hidden width of the feedforward block (default=2048).
        dropout: dropout probability used inside attention, inside the
            feedforward block and on both residual branches (default=0.1).
        activation: name of the feedforward activation, ``"relu"`` or
            ``"gelu"`` (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__()

        # Self-attention sub-block.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feedforward sub-block: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # State that carries no 'activation' entry falls back to ReLU.
        state.setdefault('activation', F.relu)
        super().__setstate__(state)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Apply self-attention and the feedforward block to ``src``.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: passed to the attention module as ``attn_mask`` (optional).
            src_key_padding_mask: passed to the attention module as
                ``key_padding_mask`` (optional).

        Returns:
            The layer output after the second residual connection and norm.
        """
        # Self-attention: query, key and value are all `src`; the attention
        # weights returned alongside the output are not used.
        attn_out, _ = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = self.norm1(src + self.dropout1(attn_out))

        # Feedforward: linear -> activation -> dropout -> linear.
        ff_hidden = self.activation(self.linear1(src))
        ff_out = self.linear2(self.dropout(ff_hidden))
        return self.norm2(src + self.dropout2(ff_out))


def _get_activation_fn(activation):
    """Return the functional activation that matches ``activation``.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        RuntimeError: if ``activation`` is neither of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu))
    for name, fn in supported:
        if activation == name:
            return fn

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Args:
        module: the module to duplicate; it is not itself placed in the list.
        N: number of copies to create.

    Returns:
        An ``nn.ModuleList`` of length ``N`` whose entries share no parameters
        with each other or with ``module``.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


class XformerEncoder(nn.Module):
    """Position encoding followed by a Transformer encoder stack, built from a config.

    Args:
        cfg: configuration object providing ``D_MODEL``, ``N_HEAD``, ``D_FF``
            and ``XFMR_DROP`` for the encoder layer, ``N_XFMR_LAYER`` for the
            stack depth, and ``D_PE``, ``PE_DROP`` and ``N_PE`` for the
            position encoding.
    """

    def __init__(self, cfg):
        super().__init__()
        # Template layer; the encoder stack clones it N_XFMR_LAYER times.
        layer_template = TransformerEncoderLayer(
            d_model=cfg.D_MODEL,
            nhead=cfg.N_HEAD,
            dim_feedforward=cfg.D_FF,
            dropout=cfg.XFMR_DROP,
        )
        self.xfmr_encoder = TransformerEncoder(layer_template, cfg.N_XFMR_LAYER)
        self.pos_encoder = LearnedPositionEncoding(
            d_pe=cfg.D_PE,
            d_model=cfg.D_MODEL,
            p_drop=cfg.PE_DROP,
            n_pe=cfg.N_PE,
        )

    def forward(self, x, mask):
        """Add position information to ``x`` and encode it.

        Args:
            x: input sequence handed to the position encoder.
            mask: forwarded to the encoder stack as ``src_key_padding_mask``.

        Returns:
            The output of the encoder stack.
        """
        positioned = self.pos_encoder(x)
        return self.xfmr_encoder(positioned, src_key_padding_mask=mask)


class LearnedPositionEncoding(nn.Embedding):
    r"""Position-encoding table that is added to, or concatenated with, the input.

    The table is the embedding weight of shape ``(n_pe, d_pe)``. When
    ``d_pe == d_model`` the encoding is added to the input; otherwise it is
    concatenated along the last dimension, so the output has
    ``x.size(-1) + d_pe`` features.

    Args:
        d_pe: width of one position vector.
        d_model: feature width of the input; only compared against ``d_pe`` to
            choose between add and concatenate.
        p_drop: dropout probability applied to the result (default=0.1).
        n_pe: number of positions in the table, i.e. the longest supported
            sequence (default=100).
        is_fix: if ``True`` (default) the table is filled with the sinusoidal
            encoding and frozen; if ``False`` it keeps the default
            ``nn.Embedding`` initialization and stays trainable.
    """

    def __init__(self, d_pe: int, d_model: int, p_drop: float = 0.1, n_pe: int = 100, is_fix: bool = True):
        super().__init__(n_pe, d_pe)
        self.is_add = (d_pe == d_model)
        self.dropout = nn.Dropout(p=p_drop)

        if is_fix:
            position = torch.arange(0, n_pe, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_pe, 2).float() * (-math.log(10000.0) / d_pe))
            angles = position * div_term
            # Even columns take the sine, odd columns the cosine; together they
            # cover every column, so no separate zero-fill is required.
            with torch.no_grad():
                self.weight[:, 0::2] = torch.sin(angles)
                # For odd d_pe there is one odd column fewer than even columns,
                # so the cosine block is trimmed to fit.
                self.weight[:, 1::2] = torch.cos(angles)[:, :d_pe // 2]
            self.weight.requires_grad = False

    def forward(self, x):
        r"""Combine ``x`` with the encoding of its first ``x.size(0)`` positions.

        Args:
            x: input indexed as ``(position, batch, feature)``: dim 0 selects
                the table rows and the encoding is broadcast over dim 1.

        Returns:
            A new tensor (``x`` is left untouched) with dropout applied: ``x``
            plus the encoding when ``d_pe == d_model``, else ``x`` with the
            encoding concatenated on the last dimension.

        Raises:
            RuntimeError: if ``x.size(0)`` exceeds the number of positions in
                the table.
        """
        seq_len = x.size(0)
        if seq_len > self.num_embeddings:
            raise RuntimeError(
                "sequence length {} exceeds the number of position encodings {}".format(
                    seq_len, self.num_embeddings))

        # Read the parameter itself (not `.data`) so that a trainable table
        # receives gradients; a frozen table is unaffected.
        pos = self.weight[:seq_len].unsqueeze(1)
        if self.is_add:
            # Out-of-place add: never modify the caller's tensor. The cast
            # keeps the result in the input's dtype, as the in-place add did.
            x = x + pos.to(x.dtype)
        else:
            x = torch.cat((x, pos.expand(-1, x.size(1), -1)), dim=-1)
        return self.dropout(x)
