import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from reformer_pytorch import LSHSelfAttention


class Model(nn.Module):
    """Sequence classifier built from LSH self-attention encoder layers.

    Pipeline: token embedding -> positional encoding -> optional linear
    projection to the model width -> stacked ``Encoder`` layers -> mean over
    dimension 1 -> linear classification head.

    Args:
        args: Configuration object. Attributes read here: ``num_vocab``,
            ``dim_embd``, ``embedding_pretrained``, ``data_file``, ``freeze``,
            ``dim_model``, ``max_length``, ``dropout``, ``device``,
            ``num_head``, ``dim_inner``, ``n_centroid``, ``num_encoder`` and
            ``num_class``.
    """

    def __init__(self, args):
        super(Model, self).__init__()

        V = args.num_vocab
        D = args.dim_embd
        # Token id treated as padding when building the attention mask.
        self.padding_idx = 0

        if args.embedding_pretrained:
            # Load onto the CPU so a table saved from a GPU run can be read on
            # any host; the module is moved with the rest of the model later.
            pretrain_embed = torch.load(
                '{}glove_pretrain_embed.pth'.format(args.data_file),
                map_location='cpu',
            )['pretrain']
            # from_pretrained is a classmethod: call it on the class so no
            # throwaway randomly-initialised V x D table is allocated first.
            # NOTE(review): padding_idx is not set on this path, so with
            # freeze=False the padding row receives gradient updates — confirm
            # whether that is intended.
            self.embedding = nn.Embedding.from_pretrained(pretrain_embed, freeze=args.freeze)
        else:
            self.embedding = nn.Embedding(V, D, padding_idx=self.padding_idx)

        # Project embeddings to the model width only when the two sizes differ.
        self.emb_convert = None
        if args.dim_model != args.dim_embd:
            self.emb_convert = nn.Linear(args.dim_embd, args.dim_model)

        # Attribute name (including its spelling) is kept for backward compatibility.
        self.postion_embedding = Positional_Encoding(args.dim_embd, args.max_length, args.dropout, args.device)

        self.encoders = nn.ModuleList([
            Encoder(args.dim_model, args.num_head, args.dim_inner, args.dropout, args.n_centroid)
            for _ in range(args.num_encoder)])

        self.num_head = args.num_head
        self.fc1 = nn.Linear(args.dim_model, args.num_class)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids; assumed to be (batch, seq_len) —
                TODO confirm against the data loader.

        Returns:
            A 1-tuple ``(logits,)`` where ``logits`` is the output of the
            final linear layer.
        """
        out = self.embedding(x)
        # Positions are added at embedding width, before the optional projection.
        out = self.postion_embedding(out)

        if self.emb_convert is not None:
            out = self.emb_convert(out)

        # True for real tokens, False where x equals the padding id.
        mask = x.ne(self.padding_idx)

        for encoder in self.encoders:
            out = encoder(out, mask)

        # NOTE(review): this average runs over every position along dim 1,
        # padded ones included; the mask is only passed to the encoders.
        out = torch.mean(out, 1)
        out = self.fc1(out)
        return (out,)


class Encoder(nn.Module):
    """One encoder layer: LSH self-attention followed by a feed-forward block.

    Args:
        dim_model: Feature width passed to the attention (as ``dim``) and to
            the feed-forward block.
        num_head: Passed to the attention as ``heads``.
        hidden: Inner width of the feed-forward block.
        dropout: Dropout probability handed to both sub-modules.
        n_centroid: Passed to the attention as ``bucket_size``.
    """

    def __init__(self, dim_model, num_head, hidden, dropout, n_centroid):
        super().__init__()
        # Attention settings gathered in one place; hash rounds and the number
        # of local-attention heads are fixed at 4 and 0 respectively.
        attention_config = dict(
            dim=dim_model,
            heads=num_head,
            dropout=dropout,
            bucket_size=n_centroid,
            n_hashes=4,
            n_local_attn_heads=0,
        )
        self.attention = LSHSelfAttention(**attention_config)
        self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout)

    def forward(self, x, mask=None):
        """Run attention on ``x`` (with ``mask`` as ``input_mask``), then the feed-forward block."""
        attended = self.attention(x, input_mask=mask)
        return self.feed_forward(attended)


class Positional_Encoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table has shape ``(pad_size, embed)``. Entry ``[pos, i]`` is
    ``sin(pos / 10000 ** (2 * (i // 2) / embed))`` for even ``i`` and the
    cosine of the same angle for odd ``i``.

    Args:
        embed: Size of the feature (last) dimension of the inputs.
        pad_size: Number of positions in the table, i.e. the longest
            supported sequence.
        dropout: Dropout probability applied after the addition.
        device: Stored on ``self.device`` for backward compatibility; the
            table follows the input's device at call time.
    """

    def __init__(self, embed, pad_size, dropout, device):
        super(Positional_Encoding, self).__init__()
        self.device = device

        # Build the angles in float64, then round once to the default dtype.
        positions = torch.arange(pad_size, dtype=torch.float64).unsqueeze(1)
        # 0, 0, 2, 2, 4, 4, ... so each sin/cos column pair shares a frequency.
        paired_index = torch.arange(0, embed, 2, dtype=torch.float64).repeat_interleave(2)[:embed]
        angles = positions / torch.pow(10000.0, paired_index / embed)
        angles = angles.to(torch.get_default_dtype())

        table = angles.clone()
        table[:, 0::2] = torch.sin(angles[:, 0::2])
        table[:, 1::2] = torch.cos(angles[:, 1::2])

        # A buffer moves with module.to()/cuda(); persistent=False keeps it out
        # of state_dict so existing checkpoints still load unchanged.
        self.register_buffer('pe', table, persistent=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor shaped ``(..., seq_len, embed)`` with
                ``seq_len <= pad_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Use only the leading rows so sequences shorter than pad_size work,
        # and follow the input's device instead of a fixed one.
        pe = self.pe[: x.size(-2)].to(x.device)
        return self.dropout(x + pe)


class Position_wise_Feed_Forward(nn.Module):
    """Two-layer feed-forward block with a residual connection and LayerNorm.

    Args:
        dim_model: Size of the last dimension of the input and the output.
        hidden: Width of the intermediate linear layer.
        dropout: Dropout probability applied to the second linear layer's
            output (default 0.0).
    """

    def __init__(self, dim_model, hidden, dropout=0.0):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.fc1 = nn.Linear(dim_model, hidden)
        self.fc2 = nn.Linear(hidden, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        """Return ``layer_norm(dropout(fc2(relu(fc1(x)))) + x)``."""
        activated = F.relu(self.fc1(x))
        projected = self.dropout(self.fc2(activated))
        # Residual connection, then normalisation over the last dimension.
        return self.layer_norm(projected + x)
