import torch.nn as nn
import torch
import numpy as np
import torch.functional as f
from multiHeadAttention import *
from torch.autograd import Variable


class Encoder(nn.Module):
    """A single post-norm Transformer encoder layer.

    Applies a multi-head attention sub-layer followed by a position-wise
    feed-forward sub-layer. Each sub-layer is wrapped in a residual
    connection followed by ``nn.LayerNorm``.

    Args:
        d_model: Feature size of the layer's inputs/outputs. Used for both
            LayerNorms and as the outer dimension of the feed-forward net.
        hidden_size: Inner width of the feed-forward net. Also passed as the
            third positional argument to ``MultiHeadAttention``.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        mask: Stored on the instance but currently NOT forwarded to the
            attention module (the original call had it commented out).
            NOTE(review): confirm whether MultiHeadAttention should take it.
        dropout: Dropout probability applied to the attention output and
            appended after the feed-forward net. A falsy value (``0`` /
            ``None``) disables dropout; the layer still works in that case.
        activation: Activation module placed between the two feed-forward
            linear layers. The default instance is shared by every
            ``Encoder`` created without an explicit activation, which is
            harmless for the stateless ``nn.ReLU``.
    """

    def __init__(self, d_model, hidden_size, num_heads, mask=None, dropout=0.1, activation=nn.ReLU()):
        super().__init__()
        self.d_model = d_model
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mask = mask
        self.dropout = dropout
        self.activation = activation
        # TODO(review): mask/dropout are not passed to MultiHeadAttention; the
        # original call had them commented out — confirm its signature.
        self.attention = MultiHeadAttention(d_model, num_heads, hidden_size)
        self.attention_norm = nn.LayerNorm(d_model)

        # Index layout (0: Linear, 1: activation, 2: Linear[, 3: Dropout]) is
        # kept identical so existing state_dict keys still load.
        feed_forward = [
            nn.Linear(d_model, hidden_size),
            activation,
            nn.Linear(hidden_size, d_model),
        ]

        if self.dropout:
            self.attn_dropout = nn.Dropout(dropout)
            feed_forward.append(nn.Dropout(dropout))
        else:
            # Without this, forward() would fail with AttributeError when
            # dropout is disabled.
            self.attn_dropout = nn.Identity()

        self.feed_forward = nn.Sequential(*feed_forward)
        self.norm = nn.LayerNorm(d_model)

    @property
    def attn_droput(self):
        """Read-only alias for the former (misspelled) attribute name."""
        return self.attn_dropout

    def forward(self, x):
        """Run the encoder layer.

        Args:
            x: An indexable of three tensors, passed positionally to the
                attention module as ``attention(x[0], x[1], x[2])``.
                ``x[0]`` is also the residual input of the attention
                sub-layer. Presumably (key, query, value) per earlier debug
                notes — TODO confirm against MultiHeadAttention. The
                attention output must be addable to ``x[0]``.

        Returns:
            Tensor shaped like the attention sub-layer output, after both
            residual + LayerNorm stages.
        """
        attn_out = self.attention(x[0], x[1], x[2])
        attended = self.attention_norm(x[0] + self.attn_dropout(attn_out))
        # The feed-forward residual uses that sub-layer's own input
        # (``attended``): LayerNorm(y + FFN(y)). Adding x[0] here would drop
        # the attention output from the residual path.
        return self.norm(attended + self.feed_forward(attended))
