import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils import batched_index_select

class Encoder(nn.Module):
    """Bidirectional GRU encoder over embedded token ids.

    The forward and backward halves of the GRU output are summed, so every
    returned feature has size ``h_dim`` rather than ``2 * h_dim``.

    Args:
        vocab: Vocabulary object. ``len(vocab)`` sets the embedding table
            size; if it exposes a non-``None`` ``vectors`` attribute, those
            values are copied into the embedding weights.
        e_dim: Embedding size.
        h_dim: GRU hidden size (per direction).
        n_layers: Number of stacked GRU layers.
        dropout: Probability used for ``self.dropout`` and, when ``n_layers``
            is not 1, for the GRU's inter-layer dropout.
        cell: Recurrent cell type. Only ``"gru"`` is supported.

    Raises:
        ValueError: If ``cell`` is not ``"gru"``.
    """

    def __init__(self, vocab, e_dim, h_dim, n_layers, dropout, cell="gru"):
        super().__init__()
        # Fail fast: without this check an unsupported cell would leave
        # ``self.rnn`` undefined and only surface as an AttributeError in
        # ``forward``.
        if cell != "gru":
            raise ValueError(
                f"unsupported cell type {cell!r}; only 'gru' is implemented"
            )

        self.emb = nn.Embedding(len(vocab), e_dim)
        # Tolerate vocab objects that do not carry a ``vectors`` attribute.
        pretrained = getattr(vocab, "vectors", None)
        if pretrained is not None:
            print("pre-trained word embedding loaded")
            self.emb.weight.data.copy_(pretrained)
        self.h_dim = h_dim
        # NOTE: this module is created here but is not applied in ``forward``.
        self.dropout = nn.Dropout(dropout)

        # Inter-layer dropout is only requested for stacked GRUs; a single
        # layer keeps nn.GRU's default of 0.0.
        rnn_dropout = 0.0 if n_layers == 1 else dropout
        self.rnn = nn.GRU(
            e_dim, h_dim, n_layers, dropout=rnn_dropout, bidirectional=True
        )

    def forward(self, text, text_length, hidden=None):
        """Encode a batch of padded token sequences.

        Args:
            text: Token ids, [batch_size, sent_length]; transposed to
                time-major before embedding.
            text_length: Length of each sequence in the batch, as accepted by
                ``pack_padded_sequence``. The batch does not need to be
                sorted by length.
            hidden: Accepted for interface compatibility; it is not passed to
                the GRU.

        Returns:
            A tuple ``(padded, final)`` where ``padded`` is
            [sent_length, batch_size, h_dim] (sum of both directions of the
            top layer's outputs) and ``final`` is [batch_size, h_dim] (sum of
            the top layer's final forward and backward states).
        """
        embedded = self.emb(text.transpose(0, 1))  # [sent_length, batch_size, emb_size]
        # enforce_sorted=False lets callers pass batches in any order; for
        # already-sorted batches the result is unchanged.
        packed_in = nn.utils.rnn.pack_padded_sequence(
            embedded, text_length, enforce_sorted=False
        )
        packed_output, final_states = self.rnn(packed_in)

        # packed_output (after padding): [sent_length, batch_size, 2 * h_dim]
        # final_states: [n_layers * 2, batch_size, h_dim]
        padded, _ = nn.utils.rnn.pad_packed_sequence(packed_output)
        padded = padded[:, :, :self.h_dim] + padded[:, :, self.h_dim:]

        # The last two entries are the top layer's forward/backward states.
        # Indexing [0] and [1] would pick the first layer when n_layers > 1.
        return padded, final_states[-2, :, :] + final_states[-1, :, :]

class BIGRU(nn.Module):
    """Classifier built on top of the bidirectional GRU ``Encoder``.

    Encoder states are gathered with ``batched_index_select`` at
    ``batch.pos``, passed through a Linear -> ReLU -> Dropout block, and then
    projected to ``opt.n_cls`` outputs.

    Args:
        opt: Options object providing ``e_dim``, ``h_dim``, ``n_layers``,
            ``dropout`` and ``n_cls``.
        vocab: Vocabulary handed to ``Encoder``.
        cell: Accepted for interface compatibility. NOTE(review): it is not
            forwarded; the encoder is always built with ``cell="gru"``.
    """

    def __init__(self, opt, vocab, cell="gru"):
        super().__init__()
        self.enc = Encoder(
            vocab,
            opt.e_dim,
            opt.h_dim,
            opt.n_layers,
            opt.dropout,
            cell="gru",
        )
        hidden_projection = nn.Linear(opt.h_dim, opt.h_dim)
        self.fc = nn.Sequential(
            hidden_projection,
            nn.ReLU(),
            nn.Dropout(opt.dropout),
        )
        self.out = nn.Linear(opt.h_dim, opt.n_cls)
        # NOTE: created here but not applied in ``forward``.
        self.dropout = nn.Dropout(opt.dropout)

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: Object with ``text`` (a ``(tokens, lengths)`` pair) and
                ``pos`` (indices handed to ``batched_index_select``;
                presumably the positions to classify -- confirm against
                ``utils``).

        Returns:
            A tuple ``(logits, features)``: the output-layer scores and the
            ``fc`` representation they were computed from.
        """
        tokens, lengths = batch.text
        # Lengths are moved to the CPU before being handed to the encoder.
        states, _ = self.enc(tokens, lengths.cpu())  # [sent_length, batch_size, hidden_size]
        # Switch to batch-major layout before gathering.
        batch_major = states.permute(1, 0, 2).contiguous()
        selected = batched_index_select(batch_major, batch.pos)
        features = self.fc(selected)
        logits = self.out(features)
        return logits, features
