import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from emb2emb.utils import Namespace
from rnn_decoder import RNNDecoder
from torch.nn.parameter import Parameter
import random
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from autoencoder import Encoder

class SumModule(nn.Module):
    """Layer that collapses one axis of its input by summation.

    Handy as the last stage of an ``nn.Sequential`` where a reduction is
    needed but only modules are accepted.
    """

    def __init__(self, dim = -1):
        """Remember which axis ``forward`` should reduce (last axis by default)."""
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` summed along ``self.dim``; that axis is dropped."""
        reduced = torch.sum(x, dim=self.dim)
        return reduced

# h_final = \sum_t (exp(tanh(h_t@W + b)@u) * h_t)
#           / (\sum_k exp(tanh(h_k@W + b)@u))
class LearnedContextAttentionModule(nn.Module):
    """Pool a sequence of vectors into one vector with learned attention.

    Each time step ``h_t`` gets the score ``tanh(h_t @ W + b) @ u``; the
    scores are softmax-normalised over the time axis and used as weights
    for a weighted sum of the ``h_t``.

    Parameters are initialised so that ``W`` is the identity, ``b`` is zero
    and ``u`` is the constant ``1 / input_size``, i.e. the initial score of
    a step is the mean of ``tanh(h_t)``.
    """

    def __init__(self, input_size):
        """Create the attention parameters for vectors of size ``input_size``."""
        super(LearnedContextAttentionModule, self).__init__()
        self.input_size = input_size
        self.u = Parameter(torch.ones(input_size, 1)/input_size)
        self.w = Parameter(torch.eye(input_size))
        self.b = Parameter(torch.zeros(input_size))

    def forward(self, x):
        """Return the attention-weighted sum of ``x`` over its time axis.

        Args:
            x: tensor of shape (batch_size, T, d) with ``d == input_size``.

        Returns:
            Tensor of shape (batch_size, d).
        """
        # Broadcasting matmul applies the shared (d, d) / (d, 1) parameters
        # to every batch element without materialising per-batch copies.
        projected = torch.tanh(torch.matmul(x, self.w) + self.b)  # (b, T, d)
        scores = torch.matmul(projected, self.u)                   # (b, T, 1)
        # Normalise over the time axis so the weights of each sequence sum to 1.
        a = F.softmax(scores, dim=1)
        # (b, d, T) @ (b, T, 1) -> (b, d, 1) -> (b, d)
        return torch.bmm(torch.transpose(x, 1, 2), a).squeeze(-1)

class BERTEncoder(Encoder):
    """Encoder that pools BERT's per-token outputs into one vector per input.

    A pretrained ``BertModel`` produces one hidden state per token; a
    bottleneck selected by ``config.bert_bottleneck`` reduces them to a
    single vector of size ``hidden_size``:

    * ``"LSTM"``    - bidirectional LSTM, mean of the two final hidden states
    * ``"sum"``     - two-layer MLP applied per token, then a sum over tokens
    * ``"context"`` - learned-context attention pooling

    Optionally the vector is treated as the input of a variational layer
    (``config.variational``) and/or projected onto the unit sphere
    (``config.unit_sphere``).
    """

    def __init__(self, config):
        """Build the encoder from ``config``.

        Attributes read from ``config``: ``teacher_forcing_ratio``,
        ``gaussian_noise_std``, ``teacher_forcing_batchwise``, ``device``,
        ``max_sequence_len``, ``vocab_size``, ``variational``, ``eos_idx``,
        ``sos_idx``, ``bert_location``, ``finetune_bert``, ``unit_sphere``
        and ``bert_bottleneck``.
        """
        super(BERTEncoder, self).__init__(config)

        # Stored for completeness; not read by the code in this class.
        self.teacher_forcing_ratio = config.teacher_forcing_ratio
        self.gaussian_noise_std = config.gaussian_noise_std
        self.teacher_forcing_batchwise = config.teacher_forcing_batchwise

        self.config = config
        self.device = config.device

        self.max_sequence_len = config.max_sequence_len

        # plus 1 to make the 0th word a "padding" one.
        self.vocab_size = config.vocab_size+1

        self.variational = config.variational

        self.eos_idx = config.eos_idx
        self.sos_idx = config.sos_idx

        self.bert = BertModel.from_pretrained(config.bert_location)
        # option not to finetune bert
        if not config.finetune_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        # restrict to unit sphere?
        self.unit_sphere = config.unit_sphere

        # Expose BERT's own word-embedding table and sizes under the
        # attribute names used on this encoder.
        self.embedding = self.bert.embeddings.word_embeddings
        self.input_size = self.bert.config.hidden_size
        self.bert_hidden_size = self.bert.config.hidden_size

        # Future iterations, do fancy things with the hidden size? to go from bert internal to that
        self.hidden_size = self.bert_hidden_size

        self.bottleneck = config.bert_bottleneck

        # NOTE: an unrecognised bottleneck name creates no ``bert_bottleneck``
        # module here; ``encode`` reports it with a ValueError.
        if self.bottleneck == "LSTM":
            self.bert_bottleneck = nn.LSTM(
                input_size=self.bert_hidden_size,
                hidden_size=self.bert_hidden_size,
                num_layers=1,
                bidirectional=True,
                batch_first=True
            )
        elif self.bottleneck == "sum":
            self.bert_bottleneck = nn.Sequential(
                nn.Linear(self.input_size, self.input_size),
                nn.ReLU(),
                nn.Linear(self.input_size, self.input_size),
                SumModule(dim = 1)
            )
        elif self.bottleneck == "context":
            self.bert_bottleneck = LearnedContextAttentionModule(self.hidden_size)

        if config.variational:
            self.hidden2mean = nn.Linear(self.hidden_size, self.hidden_size)
            self.hidden2logv = nn.Linear(self.hidden_size, self.hidden_size)

    def encode(self, x, lengths, train=False, reparameterize = True):
        """Encode a batch of inputs into fixed-size vectors.

        Args:
            x: input passed straight to BERT (assumed to be a (batch, T)
                tensor of token ids - TODO confirm against callers).
            lengths: currently unused.
            train: when true and the encoder is variational, the mean and
                log-variance are returned alongside the encoding.
            reparameterize: when the encoder is variational, sample
                ``mean + eps * std`` if true, otherwise return the mean.

        Returns:
            ``h`` of shape (batch, hidden_size), or the tuple
            ``(h, mean, logv)`` when ``train`` and ``self.variational`` are
            both true.

        Raises:
            ValueError: if ``self.bottleneck`` is not one of ``"LSTM"``,
                ``"sum"`` or ``"context"``.
        """
        # NOTE(review): ``lengths`` is ignored and no attention mask is passed
        # to BERT, so padded positions (if any) take part in both BERT's
        # attention and the pooling below - confirm this is intended.
        bert_embeddings = self.bert(x)[0]
        # (batch, T, dim)

        if self.bottleneck == "LSTM":
            _, (h_n, _) = self.bert_bottleneck(bert_embeddings)
            # Mean over all hiddens
            h = h_n.mean(dim=0)
        elif self.bottleneck in ("sum", "context"):
            h = self.bert_bottleneck(bert_embeddings)
        else:
            raise ValueError(
                "Unknown bert_bottleneck {!r}; expected 'LSTM', 'sum' or "
                "'context'.".format(self.bottleneck)
            )

        if self.variational:
            mean = self.hidden2mean(h)
            logv = self.hidden2logv(h)
            if reparameterize:
                std = torch.exp(0.5 * logv)
                # randn_like keeps the noise on the same device and dtype as
                # the activations, even if the module was moved or cast after
                # construction.
                h = torch.randn_like(std) * std + mean
            else:
                h = mean

        if self.unit_sphere:
            # L2-normalise; F.normalize clamps the norm with a small epsilon
            # so an all-zero vector yields zeros instead of NaN.
            h = F.normalize(h, p=2, dim=-1)

        # (batch, hidden_size)
        if train and self.variational:
            return h, mean, logv
        else:
            return h
