import torch
import torch.nn as nn
import random
import itertools
import math
import numbers
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig


class Classifier(torch.nn.Module):
    '''
    General classifier MLP with LeakyReLU activations and dropout.

    Layer order: Linear -> LeakyReLU -> Linear -> Dropout -> LeakyReLU ->
    Linear -> Dropout -> LeakyReLU -> Linear -> LogSoftmax.
    The output is log-probabilities over ``output_dim`` classes, normalised
    over the last dimension.

    Args:
        input_dim: width of the input features (and of every hidden layer).
        output_dim: number of output classes.
    '''

    def __init__(self, input_dim, output_dim):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, input_dim), nn.LeakyReLU(), nn.Linear(input_dim, input_dim),
                                 torch.nn.Dropout(p=0.1, inplace=False),
                                 nn.LeakyReLU(), nn.Linear(input_dim, input_dim),
                                 torch.nn.Dropout(p=0.1, inplace=False),
                                 nn.LeakyReLU(), nn.Linear(input_dim, output_dim),
                                 # Explicit dim: the implicit-dim fallback is deprecated and
                                 # normalises over dim 0 for 3-D inputs.
                                 nn.LogSoftmax(dim=-1))

    def forward(self, input):
        """Return class log-probabilities for ``input`` (features on the last dim)."""
        return self.net(input)

    def weights_init(self):
        """Xavier-uniform initialise the weights of every Linear layer in ``self.net``.

        Activation, dropout and log-softmax modules have no ``weight``, so they
        are skipped. Biases are left untouched.
        """
        for module in self.net:
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)


class EncoderRNN(nn.Module):
    """Bidirectional, batch-first GRU encoder over token ids with a caller-supplied hidden state.

    NOTE: another class named ``EncoderRNN`` is defined later in this module.
    That later definition rebinds the name, so once the module has been
    imported, ``EncoderRNN`` refers to it rather than to this class.

    Args:
        args: config object; reads ``dropout``, ``number_of_layers``,
            ``batch_size`` and ``device``.
        input_size: vocabulary size for the embedding table.
        hidden_size: embedding width and GRU hidden width.
        number_of_layers: number of stacked GRU layers.
    """

    def __init__(self, args, input_size, hidden_size, number_of_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.args = args
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(
            hidden_size,
            hidden_size,
            num_layers=number_of_layers,
            dropout=args.dropout,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, input, hidden):
        """Embed ``input`` and run the GRU from ``hidden``; return ``(output, hidden)``."""
        return self.gru(self.embedding(input), hidden)

    def initHidden(self):
        """Return a zero initial GRU state sized from ``self.args``."""
        # NOTE(review): the layer count is read from args.number_of_layers, not from
        # the number_of_layers passed to __init__ -- confirm the two always agree.
        num_directions = 2
        shape = (num_directions * self.args.number_of_layers, self.args.batch_size, self.hidden_size)
        return torch.zeros(*shape, device=self.args.device)


class FF(nn.Module):
    """Feed-forward stack followed by a linear output head.

    Each of the ``num_layers`` hidden blocks is
    ``[LayerNorm] -> Linear -> activation -> Dropout``; LayerNorm is included
    only when ``layer_norm`` is true.

    Args:
        dim_input: width of the input features.
        dim_hidden: width of every hidden block's output.
        dim_output: width of the final output.
        num_layers: number of hidden blocks; with fewer than 1, the input
            feeds the output head directly.
        activation: ``'tanh'`` or ``'relu'``; any other value raises KeyError.
        dropout_rate: dropout probability applied after each activation.
        layer_norm: prepend a LayerNorm to every hidden block.
        residual_connection: add each block's input to its output; this
            requires ``dim_hidden == dim_input`` (asserted).
    """

    def __init__(self, dim_input, dim_hidden, dim_output, num_layers,
                 activation='relu', dropout_rate=0, layer_norm=False,
                 residual_connection=False):
        super().__init__()
        assert (not residual_connection) or (dim_hidden == dim_input)
        self.residual_connection = residual_connection

        self.stack = nn.ModuleList()
        for idx in range(num_layers):
            in_features = dim_hidden if idx > 0 else dim_input
            modules = [nn.LayerNorm(in_features)] if layer_norm else []
            modules += [
                nn.Linear(in_features, dim_hidden),
                {'tanh': nn.Tanh, 'relu': nn.ReLU}[activation](),
                nn.Dropout(dropout_rate),
            ]
            self.stack.append(nn.Sequential(*modules))

        head_in = dim_hidden if num_layers >= 1 else dim_input
        self.out = nn.Linear(head_in, dim_output)

    def forward(self, x):
        """Run ``x`` through the hidden blocks (optionally residual), then the output head."""
        for block in self.stack:
            update = block(x)
            x = x + update if self.residual_connection else update
        return self.out(x)


class EncoderBert(nn.Module):
    """Encode a tokenized batch with pretrained ``bert-base-uncased``, one vector per example.

    Each token's ``last_hidden_state`` is passed through a
    Linear -> LeakyReLU -> Linear head of width ``args.hidden_size``. The
    projected token vectors are then summed over dim 1, which is presumably
    the sequence dimension (BERT's ``last_hidden_state`` is documented as
    (batch, seq_len, hidden)).

    Args:
        args: config object; ``args.hidden_size`` sets the output width.
    """

    def __init__(self, args):
        super(EncoderBert, self).__init__()
        self.args = args
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.net = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.args.hidden_size), nn.LeakyReLU(),
                                 nn.Linear(self.args.hidden_size, self.args.hidden_size))

    def forward(self, inputs):
        """Return the pooled encoding of ``inputs``.

        Args:
            inputs: mapping of keyword arguments for the BERT model, e.g. a
                tokenizer's output. If it has an ``attention_mask`` entry
                (assumed shape (batch, seq_len), 1 = real token), masked
                positions are excluded from the sum.

        Returns:
            The projected token vectors summed over the sequence dimension,
            presumably of shape (batch, args.hidden_size).
        """
        outputs = self.bert(**inputs)
        output = self.net(outputs.last_hidden_state)
        # The projection head maps padding states to non-zero vectors. Without
        # masking, each example's encoding would depend on how much padding
        # its batch happens to need.
        mask = inputs.get('attention_mask')
        if mask is not None:
            output = output * mask.unsqueeze(-1).to(output.dtype)
        return torch.sum(output, dim=1)


class EncoderRNN(nn.Module):
    """Bidirectional, batch-first GRU encoder that pools its final hidden states per example.

    Args:
        args: config object providing ``hidden_size``, ``number_of_tokens``
            (embedding vocabulary size), ``number_of_layers``, ``dropout``,
            ``batch_size`` and ``device``.
    """

    def __init__(self, args):
        super(EncoderRNN, self).__init__()
        self.hidden_size = args.hidden_size
        self.args = args
        self.embedding = nn.Embedding(args.number_of_tokens, args.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True, bidirectional=True,
                          num_layers=args.number_of_layers,
                          dropout=args.dropout)

    def forward(self, inputs):
        """Encode ``inputs['input_ids']``, a batch-first tensor of token ids.

        Returns:
            The final GRU hidden states summed over all layers and both
            directions, shape (batch, hidden_size).
        """
        input_ids = inputs['input_ids']
        # Size the initial state from the actual batch. The last batch of an
        # epoch can be smaller than args.batch_size, and nn.GRU rejects a
        # hidden state whose batch dimension does not match the input.
        encoder_hidden = self.initHidden(input_ids.size(0))
        embedded = self.embedding(input_ids)
        _, hidden = self.gru(embedded, encoder_hidden)
        # hidden: (2 * number_of_layers, batch, hidden_size) -> (batch, hidden_size)
        return torch.sum(hidden, dim=0)

    def initHidden(self, batch_size=None):
        """Return a zero GRU state of shape (2 * number_of_layers, batch_size, hidden_size).

        Args:
            batch_size: batch dimension of the state; defaults to
                ``args.batch_size`` for backward compatibility.
        """
        if batch_size is None:
            batch_size = self.args.batch_size
        return torch.zeros(2 * self.args.number_of_layers, batch_size, self.hidden_size,
                           device=self.args.device)


class EncoderMLP(nn.Module):
    """Two-layer tanh MLP encoder that keeps the feature width at ``args.hidden_size``.

    The hidden layer is three times wider than the input and output.

    Args:
        args: config object; ``args.hidden_size`` is both the input and output width.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        width = args.hidden_size
        wide = width * 3
        self.mlp = nn.Sequential(
            nn.Linear(width, wide), nn.Tanh(),
            nn.Linear(wide, width), nn.Tanh(),
        )

    def forward(self, inputs):
        """Cast ``inputs`` to float32 and run the MLP on it."""
        return self.mlp(inputs.float())


# Registry mapping an encoder-type key to its encoder class. This is evaluated
# after both EncoderRNN definitions above, so "RNN" maps to the later,
# args-only EncoderRNN.
encoder_dic = dict(RNN=EncoderRNN, BERT=EncoderBert, MLP=EncoderMLP)

if __name__ == '__main__':
    # Smoke test: run pretrained BERT on a small batch of sentences.
    # BertModel and torch are already imported at the top of the module.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # The two sentences tokenize to different lengths, so padding=True is
    # needed to stack them into one rectangular "pt" tensor.
    inputs = tokenizer(["Hello, my dog is cute", "Hello, my dog is how is your life going ?"],
                       padding=True, return_tensors="pt")
    model = BertModel.from_pretrained('bert-base-uncased')
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)
