import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils import batched_index_select
from transformers import BertTokenizer, BertModel
import sys

class GRU(nn.Module):
    """Bidirectional GRU encoder whose forward and backward halves are summed.

    Args:
        e_dim: Size of each input feature vector.
        h_dim: Hidden size per direction; also the feature size of the
            summed outputs returned by ``forward``.
        n_layers: Number of stacked recurrent layers.
        dropout: Dropout probability. Stored as ``self.dropout`` and, when
            ``n_layers`` is not 1, also used between recurrent layers.
        cell: Recurrent cell type. Only ``"gru"`` is supported.

    Raises:
        ValueError: If ``cell`` is anything other than ``"gru"``.
    """

    def __init__(self, e_dim, h_dim, n_layers, dropout, cell="gru"):
        super().__init__()
        if cell != "gru":
            # Fail at construction time instead of leaving ``self.rnn``
            # undefined and crashing later inside ``forward``.
            raise ValueError(f"unsupported cell type: {cell!r}")

        self.h_dim = h_dim
        self.dropout = nn.Dropout(dropout)

        # nn.GRU warns when inter-layer dropout is requested for a single
        # layer, so only forward the probability for stacked configurations.
        inter_layer_dropout = 0.0 if n_layers == 1 else dropout
        self.rnn = nn.GRU(
            e_dim,
            h_dim,
            n_layers,
            dropout=inter_layer_dropout,
            bidirectional=True,
        )

    def forward(self, _input, text_length, hidden=None):
        """Encode a padded batch of sequences.

        Args:
            _input: Tensor laid out as [batch_size, sent_length, emb_size].
            text_length: True length of every sequence in the batch (tensor
                or list). The batch does not need to be sorted by length.
            hidden: Optional initial hidden state handed to ``nn.GRU``
                ([n_layers * 2, batch_size, h_dim]); ``None`` means zeros.

        Returns:
            A pair ``(padded, final)``:
            ``padded`` is [max(text_length), batch_size, h_dim], the sum of
            the forward and backward outputs of the last layer;
            ``final`` is [batch_size, h_dim], the sum of the last layer's
            forward and backward final hidden states.
        """
        # pack_padded_sequence requires the lengths to live on the CPU.
        if torch.is_tensor(text_length):
            text_length = text_length.cpu()

        _in = _input.transpose(0, 1)  # [sent_length, batch_size, emb_size]
        # enforce_sorted=False lets PyTorch sort/unsort internally, so
        # batches that are not ordered by decreasing length are accepted.
        packed_in = nn.utils.rnn.pack_padded_sequence(
            _in, text_length, enforce_sorted=False
        )
        packed_output, hidden = self.rnn(packed_in, hidden)

        # padded: [sent_length, batch_size, 2 * h_dim]
        # hidden: [n_layers * 2, batch_size, h_dim]
        padded, _ = nn.utils.rnn.pad_packed_sequence(packed_output)
        padded = padded[:, :, :self.h_dim] + padded[:, :, self.h_dim:]

        # The last two entries of ``hidden`` are the forward and backward
        # states of the top layer; indices 0 and 1 would be the bottom layer
        # whenever n_layers > 1.
        return padded, hidden[-2] + hidden[-1]

class BERT(nn.Module):
    """Classifier that runs a bidirectional GRU over BERT token states.

    The pretrained ``bert-base-chinese`` encoder is evaluated under
    ``torch.no_grad()`` in ``forward``, so no gradients flow into it from
    this module; the GRU, the hidden projection and the output layer are
    built from the sizes given in ``opt``.

    Args:
        opt: Configuration object; ``h_dim``, ``n_layers``, ``dropout`` and
            ``n_cls`` are read from it.
        tokenizer: Tokenizer whose ``pad_token_id`` marks padding positions.
    """

    def __init__(self, opt, tokenizer):
        super().__init__()
        # NOTE: the embedding matrix keeps the checkpoint's size; no
        # resize_token_embeddings call is made for the given tokenizer.
        self.enc = BertModel.from_pretrained('bert-base-chinese')
        self.pad_id = tokenizer.pad_token_id

        bert_width = self.enc.config.to_dict()['hidden_size']
        self.rnn = GRU(bert_width, opt.h_dim, opt.n_layers, opt.dropout, cell="gru")

        head_layers = [
            nn.Linear(opt.h_dim, opt.h_dim),
            nn.ReLU(),
            nn.Dropout(opt.dropout),
        ]
        self.fc = nn.Sequential(*head_layers)
        self.out = nn.Linear(opt.h_dim, opt.n_cls)
        self.dropout = nn.Dropout(opt.dropout)

    def create_mask(self, text):
        """Return a boolean tensor that is True wherever ``text`` is not padding."""
        return text != self.pad_id

    def forward(self, batch):
        """Compute class scores for a batch.

        Args:
            batch: Object exposing ``text`` (a ``(token_ids, lengths)`` pair)
                plus ``pos_s`` and ``pos_e`` index tensors (presumably span
                start/end positions; confirm against the data pipeline).

        Returns:
            A pair ``(_out, representation)``: the output of ``self.out`` and
            the ``self.fc`` features it was computed from.
        """
        token_ids, lengths = batch.text
        attention = self.create_mask(token_ids)

        # The encoder runs without gradient tracking.
        with torch.no_grad():
            # [batch_size, sent_length, hidden_size]
            bert_states = self.enc(token_ids, attention_mask=attention)[0]

        # The GRU returns [sent_length, batch_size, hidden_size]; lengths are
        # moved to the CPU before being handed over.
        rnn_states, _ = self.rnn(bert_states, lengths.cpu())
        # Back to batch-first layout for the index lookups below.
        batch_first = rnn_states.permute(1, 0, 2).contiguous()

        start_vecs = batched_index_select(batch_first, batch.pos_s)
        end_vecs = batched_index_select(batch_first, batch.pos_e)

        representation = self.fc(start_vecs + end_vecs)
        return self.out(representation), representation
