import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """Transformer-encoder text classifier.

    Pipeline: token embedding -> sinusoidal positional encoding (+ dropout) ->
    optional linear projection from ``dim_embd`` to ``dim_model`` -> a stack of
    ``num_encoder`` encoder layers -> mean pooling over non-padding positions ->
    linear classifier.

    Args:
        args: configuration object. Attributes read here: ``num_vocab``,
            ``dim_embd``, ``dim_model``, ``embedding_pretrained``, ``data_file``,
            ``freeze``, ``max_length``, ``dropout``, ``device``, ``num_head``,
            ``dim_inner``, ``num_encoder`` and ``num_class``.
    """

    def __init__(self, args):
        super(Model, self).__init__()

        V = args.num_vocab
        D = args.dim_embd
        self.padding_idx = 0

        if args.embedding_pretrained:
            # NOTE(review): torch.load unpickles the file; only load trusted files.
            pretrain_embed = torch.load('{}glove_pretrain_embed.pth'.format(args.data_file))['pretrain']
            # from_pretrained is a classmethod: call it on the class rather than on a
            # throw-away, randomly initialised nn.Embedding(V, D). padding_idx keeps the
            # padding row out of gradient updates, matching the branch below.
            self.embedding = nn.Embedding.from_pretrained(
                pretrain_embed, freeze=args.freeze, padding_idx=self.padding_idx)
        else:
            self.embedding = nn.Embedding(V, D, padding_idx=self.padding_idx)

        # Project embeddings to the encoder width only when the two widths differ.
        self.emb_convert = None
        if args.dim_model != args.dim_embd:
            self.emb_convert = nn.Linear(args.dim_embd, args.dim_model)

        # Positional encoding works at embedding width, i.e. before emb_convert.
        self.postion_embedding = Positional_Encoding(args.dim_embd, args.max_length, args.dropout, args.device)

        self.encoders = nn.ModuleList([
            Encoder(args.dim_model, args.num_head, args.dim_inner, args.dropout)
            for _ in range(args.num_encoder)])

        self.num_head = args.num_head
        self.fc1 = nn.Linear(args.dim_model, args.num_class)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: token ids, presumably shaped (batch, seq_len); the id
                ``self.padding_idx`` (0) marks padding.

        Returns:
            A 1-tuple ``(logits,)`` with logits of shape (batch, num_class).
        """
        seq_len = x.size(1)
        out = self.embedding(x)
        out = self.postion_embedding(out)

        if self.emb_convert is not None:
            out = self.emb_convert(out)

        # True where the token is padding: (batch, seq_len).
        padding_mask = x.eq(self.padding_idx)
        # Multi-head attention flattens heads batch-major (row index b * num_head + h),
        # so repeat each sequence's mask num_head times consecutively, then broadcast
        # over query positions -> (batch * num_head, seq_len, seq_len). Padded keys
        # are suppressed in every head.
        mask = padding_mask.repeat_interleave(self.num_head, dim=0)
        mask = mask.unsqueeze(1).expand(-1, seq_len, -1)

        for encoder in self.encoders:
            out = encoder(out, mask)

        # Mean-pool over real tokens only. Averaging over padding positions would make
        # the sentence vector depend on how much padding the sequence carries.
        keep = (~padding_mask).unsqueeze(-1).to(out.dtype)  # (batch, seq_len, 1)
        lengths = keep.sum(dim=1).clamp(min=1.0)  # avoid 0/0 for all-padding rows
        out = (out * keep).sum(dim=1) / lengths
        out = self.fc1(out)
        return (out,)


class Encoder(nn.Module):
    """One Transformer encoder layer.

    Runs a multi-head self-attention sub-layer and then a position-wise
    feed-forward sub-layer. Each sub-layer applies its own residual connection
    and LayerNorm internally.

    Args:
        dim_model: model (feature) width.
        num_head: number of attention heads.
        hidden: inner width of the feed-forward sub-layer.
        dropout: dropout probability used by both sub-layers.
    """

    def __init__(self, dim_model, num_head, hidden, dropout):
        super().__init__()
        # Creation order is kept: attention first, then feed-forward.
        self.attention = Multi_Head_Attention(dim_model, num_head, dropout)
        self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout)

    def forward(self, x, mask=None):
        """Apply attention (with the optional ``mask``), then the feed-forward sub-layer."""
        return self.feed_forward(self.attention(x, mask))


class Positional_Encoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an embedded sequence, then apply dropout.

    Table entries: ``pe[pos, i] = sin(pos / 10000 ** (2 * (i // 2) / embed))`` for even
    ``i``, and the cosine of the same angle for odd ``i``.

    Args:
        embed: embedding width (last dimension of the input).
        pad_size: maximum sequence length covered by the table.
        dropout: dropout probability applied after the addition.
        device: kept for backward compatibility. The table is a module buffer, so it
            follows ``.to()``, and it is aligned to the input's device in ``forward``.
    """

    def __init__(self, embed, pad_size, dropout, device):
        super(Positional_Encoding, self).__init__()
        self.device = device
        position = torch.arange(pad_size, dtype=torch.float64).unsqueeze(1)  # (pad_size, 1)
        # Same exponent expression as the scalar formula: 2 * (i // 2) / embed.
        exponent = torch.tensor([i // 2 * 2.0 / embed for i in range(embed)], dtype=torch.float64)
        pe = (position / 10000.0 ** exponent).to(torch.get_default_dtype())  # (pad_size, embed)
        pe[:, 0::2] = torch.sin(pe[:, 0::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        # Non-persistent buffer: moves with the module and is built once rather than
        # re-wrapped and copied every forward. It stays out of state_dict, so existing
        # checkpoints still load.
        self.register_buffer('pe', pe, persistent=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``, where seq_len is ``x.size(-2)``.

        Sequences shorter than ``pad_size`` use only the first rows of the table.
        """
        pe = self.pe[: x.size(-2)].to(x.device)
        return self.dropout(x + pe)


class Scaled_Dot_Product_Attention(nn.Module):
    '''Scaled dot-product attention: ``softmax(Q K^T * scale) V``.'''
    def __init__(self):
        super(Scaled_Dot_Product_Attention, self).__init__()

    def forward(self, Q, K, V, scale=None, mask=None):
        '''
        Args:
            Q: [batch_size, len_Q, dim_Q]. Extra leading batch dims are also accepted.
            K: [batch_size, len_K, dim_K]
            V: [batch_size, len_V, dim_V]
            scale: factor the raw scores are multiplied by. The paper uses
                1 / sqrt(dim_K). No scaling is applied when None.
            mask: optional boolean tensor broadcastable to the score tensor
                [batch_size, len_Q, len_K]. True marks scores to suppress.
        Return:
            The attention output (context) tensor, [batch_size, len_Q, dim_V].
        '''
        scores = torch.matmul(Q, K.transpose(-2, -1))
        if scale is not None:
            scores = scores * scale
        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed -1e10.
            # -1e10 overflows (and raises) for float16 scores. For float32 the
            # softmax result is identical: masked entries still become exactly 0.
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V)


class Multi_Head_Attention(nn.Module):
    """Multi-head self-attention sub-layer with residual connection and post-LayerNorm.

    The input is projected to queries, keys and values. These are split into
    ``num_head`` heads of width ``dim_model // num_head``, attended per head with
    scale ``dim_head ** -0.5``, and re-merged. The result is projected back to
    ``dim_model``, passed through dropout, added to the input and layer-normalised.

    Args:
        dim_model: model width; must be divisible by ``num_head``.
        num_head: number of attention heads.
        dropout: dropout probability on the output projection.
    """

    def __init__(self, dim_model, num_head, dropout=0.0):
        super().__init__()
        self.num_head = num_head
        assert dim_model % num_head == 0
        self.dim_head = dim_model // self.num_head
        inner_width = num_head * self.dim_head
        # Submodule creation order is deliberate: it fixes the parameter-init RNG
        # sequence and the state_dict ordering.
        self.fc_Q = nn.Linear(dim_model, inner_width)
        self.fc_K = nn.Linear(dim_model, inner_width)
        self.fc_V = nn.Linear(dim_model, inner_width)
        self.attention = Scaled_Dot_Product_Attention()
        self.fc = nn.Linear(inner_width, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def _split_heads(self, tensor, batch_size):
        """Reshape (batch, seq, heads*dim_head) to (batch*heads, seq, dim_head), batch-major."""
        per_head = tensor.view(batch_size, -1, self.num_head, self.dim_head).transpose(1, 2)
        return per_head.contiguous().view(batch_size * self.num_head, -1, self.dim_head)

    def _merge_heads(self, tensor, batch_size):
        """Inverse of ``_split_heads``: (batch*heads, seq, dim_head) to (batch, seq, heads*dim_head)."""
        per_head = tensor.view(batch_size, self.num_head, -1, self.dim_head).transpose(1, 2)
        return per_head.contiguous().view(batch_size, -1, self.num_head * self.dim_head)

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        ``mask`` is forwarded unchanged to the attention module. It is expected to
        broadcast to the per-head score tensor (batch*num_head, seq, seq).
        """
        batch_size = x.size(0)
        queries, keys, values = (
            self._split_heads(projection(x), batch_size)
            for projection in (self.fc_Q, self.fc_K, self.fc_V)
        )
        scale = keys.size(-1) ** -0.5
        attended = self.attention(queries, keys, values, scale, mask)
        merged = self._merge_heads(attended, batch_size)
        residual = self.dropout(self.fc(merged))
        return self.layer_norm(x + residual)


class Position_wise_Feed_Forward(nn.Module):
    """Position-wise feed-forward sub-layer.

    Computes ``LayerNorm(x + Dropout(fc2(ReLU(fc1(x)))))``, i.e. a two-layer MLP
    applied independently at every position, with a residual connection and
    post-LayerNorm.

    Args:
        dim_model: input/output width.
        hidden: inner (expanded) width.
        dropout: dropout probability applied to the MLP output before the residual add.
    """

    def __init__(self, dim_model, hidden, dropout=0.0):
        super().__init__()
        # Creation order is kept so parameter initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(dim_model, hidden)
        self.fc2 = nn.Linear(hidden, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        expanded = F.relu(self.fc1(x))
        projected = self.dropout(self.fc2(expanded))
        # Residual connection, then normalisation.
        return self.layer_norm(x + projected)
