import torch
import torch.nn as nn
import numpy as np

class Encoder(nn.Module):
    """Bidirectional GRU encoder over padded, variable-length token sequences.

    Args:
        emb: embedding module applied to token ids (shared with the caller).
        e_dim: embedding size fed into the GRU.
        h_dim: hidden size of each GRU direction.
        n_layers: number of stacked GRU layers.
        dropout: inter-layer dropout; only used when ``n_layers > 1``.
        cell: recurrent cell type. Only ``"gru"`` is implemented.

    Raises:
        ValueError: if ``cell`` is not ``"gru"``.
    """

    def __init__(self, emb, e_dim, h_dim, n_layers, dropout, cell="gru"):
        super().__init__()
        self.emb = emb
        self.h_dim = h_dim

        if cell != "gru":
            # An unknown cell used to leave ``self.rnn`` unset and only fail
            # later, inside forward(), with an AttributeError.
            raise ValueError(f"unsupported cell type {cell!r}; only 'gru' is implemented")
        # nn.GRU applies dropout only between stacked layers (and warns if it
        # is non-zero for a single layer), so pass it only when n_layers > 1.
        self.rnn = nn.GRU(
            e_dim,
            h_dim,
            n_layers,
            dropout=dropout if n_layers > 1 else 0.0,
            bidirectional=True,
        )

    def forward(self, text, text_length, hidden=None):
        """Encode a padded batch of token ids.

        Args:
            text: token ids laid out ``[batch_size, sent_length]``; they are
                transposed to time-first before embedding.
            text_length: true length of each sequence (CPU lengths, as
                required by ``pack_padded_sequence``); need not be sorted.
            hidden: optional initial GRU state
                ``[n_layers * 2, batch_size, h_dim]``; zeros when ``None``.

        Returns:
            output: ``[max_length, batch_size, h_dim]``, the forward and
                backward direction outputs summed. Positions past a
                sequence's length are zero (padding of ``pad_packed_sequence``).
            hidden: final GRU state with dim 0 squeezed. The GRU is
                bidirectional, so dim 0 is ``n_layers * 2 >= 2`` and the
                squeeze is a no-op: shape ``[n_layers * 2, batch_size, h_dim]``.
        """
        embedded = self.emb(text.transpose(0, 1))  # [sent_length, batch_size, e_dim]
        packed_in = nn.utils.rnn.pack_padded_sequence(embedded, text_length, enforce_sorted=False)
        # ``hidden`` used to be accepted but silently ignored; forward it.
        packed_output, hidden = self.rnn(packed_in, hidden)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output)
        # output: [max_length, batch_size, 2 * h_dim]; sum the two directions.
        return output[:, :, :self.h_dim] + output[:, :, self.h_dim:], hidden.squeeze(0)

class Attention(nn.Module):
    """Additive (tanh) attention that scores every time step of a sequence.

    For each step ``h_t`` the score is ``v . tanh(W h_t + b)``; scores are
    normalised with a softmax over time. No padding mask is applied, so every
    position of the input receives a weight.

    Args:
        hidden: feature size of the vectors being scored.
    """

    def __init__(self, hidden):
        super().__init__()
        self.hidden = hidden
        self.attn = nn.Linear(hidden, hidden)
        # Learned scoring vector, initialised uniformly in [0, 1).
        self.v = nn.Parameter(torch.rand(hidden))

    def forward(self, encoder_outputs):
        """Return attention weights over time.

        Args:
            encoder_outputs: rank-3 tensor laid out
                ``[sent_length, batch_size, hidden]``.

        Returns:
            ``[batch_size, sent_length]`` weights; each row sums to 1.
        """
        n_batch = encoder_outputs.shape[1]

        batch_major = encoder_outputs.permute(1, 0, 2)          # [B, S, H]
        projected = torch.tanh(self.attn(batch_major))          # [B, S, H]
        keys = projected.permute(0, 2, 1)                       # [B, H, S]

        # One copy of the scoring vector per batch row: [B, 1, H].
        query = self.v.repeat(n_batch, 1).unsqueeze(1)

        scores = torch.bmm(query, keys).squeeze(1)              # [B, S]
        return nn.functional.softmax(scores, dim=1)

class BiATT(nn.Module):
    """Classifier over three fields of a batch: ``pre``, ``re`` and ``pos``.

    ``pre`` and ``pos`` are each encoded by their own bidirectional encoder
    and pooled by their own attention module; ``re`` is embedded directly.
    The three vectors are concatenated and fed through an MLP head. All
    three fields share one embedding table.

    Args:
        opt: options object read for ``e_dim``, ``h_dim``, ``n_layers``,
            ``dropout`` and ``n_cls``.
        vocab: vocabulary providing ``len()`` and ``.vectors`` (optional
            pre-trained embedding matrix, or ``None``).
        cell: recurrent cell type forwarded to both encoders.
    """

    def __init__(self, opt, vocab, cell="gru"):
        super().__init__()
        self.emb = nn.Embedding(len(vocab), opt.e_dim)
        if vocab.vectors is not None:
            print("pre-trained word embedding loaded")
            self.emb.weight.data.copy_(vocab.vectors)
        # ``cell`` used to be ignored here ("gru" was hard-coded).
        self.pre_enc = Encoder(self.emb, opt.e_dim, opt.h_dim, opt.n_layers, opt.dropout, cell=cell)
        self.pos_enc = Encoder(self.emb, opt.e_dim, opt.h_dim, opt.n_layers, opt.dropout, cell=cell)
        self.pre_att = Attention(opt.h_dim)
        self.pos_att = Attention(opt.h_dim)
        self.fc = nn.Sequential(
            nn.Linear(opt.h_dim * 2 + opt.e_dim, opt.h_dim),
            nn.ReLU(),
            nn.Dropout(opt.dropout),
        )
        self.out = nn.Linear(opt.h_dim, opt.n_cls)

    def _pool(self, encoder, attention, tokens, lengths):
        """Encode one field and return its attention-weighted sum ``[batch, h_dim]``."""
        enc_out, _ = encoder(tokens, lengths.cpu())  # [max_len, batch, h_dim]
        weights = attention(enc_out)  # [batch, max_len]
        # The attention also scores padded positions: their encoder output is
        # zero, but tanh(attn(0)) is not, so they would absorb probability
        # mass. Restrict the distribution to real tokens and renormalise.
        # This equals a softmax over the valid positions only.
        positions = torch.arange(weights.size(1), device=weights.device)
        valid = positions.unsqueeze(0) < lengths.to(weights.device).unsqueeze(1)
        weights = weights * valid
        weights = weights / weights.sum(dim=1, keepdim=True)
        # [batch, 1, max_len] x [batch, max_len, h_dim] -> [batch, h_dim]
        return torch.bmm(weights.unsqueeze(1), enc_out.permute(1, 0, 2)).squeeze(1)

    def forward(self, batch):
        """Return ``(logits [batch, n_cls], representation [batch, h_dim])``.

        ``batch.pre``, ``batch.pos`` and ``batch.re`` are each unpacked as
        ``(tokens, lengths)``.
        """
        pre, pre_length = batch.pre
        pre_weighted = self._pool(self.pre_enc, self.pre_att, pre, pre_length)
        pos, pos_length = batch.pos
        # Previously this branch reused ``self.pre_att``, leaving ``pos_att`` unused.
        pos_weighted = self._pool(self.pos_enc, self.pos_att, pos, pos_length)
        re_tokens, _ = batch.re
        # NOTE(review): squeeze(0) assumes each ``re`` field is a single token,
        # so the embedding collapses to [batch, e_dim]; longer spans would
        # break the concatenation below. TODO confirm against the data.
        re_emb = self.emb(re_tokens.transpose(0, 1)).squeeze(0)
        dout = torch.cat((pre_weighted, re_emb, pos_weighted), 1)
        representation = self.fc(dout)
        _out = self.out(representation)

        return _out, representation
