import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils import *
from transformers import AutoTokenizer, AutoModel


class BiLSTM(nn.Module):
    """(Bi-)LSTM sentence encoder over padded, variable-length sequences."""
    def __init__(self, input_size, hidden_size, num_layers, bi):
        """Bi-LSTM Encoder

        Args:
            input_size: (int) vocab word2vec dim
            hidden_size: (int) hidden size in Bi-LSTM
            num_layers: (int) num_layers in Bi-LSTM
            bi: (boolean) Bi-direction
        """
        super(BiLSTM, self).__init__()

        # init
        ## Bi-LSTM
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bi = bi

        # models
        self.rnn = nn.LSTM(self.input_size,
                           self.hidden_size,
                           num_layers=self.num_layers,
                           batch_first=True,
                           bidirectional=self.bi)

    def forward(self, *input):
        """Encode a padded batch.

        Positional inputs:
            x_emb: (Tensor) padded inputs, (batch_size, max_len, input_size) (batch_first=True)
            x_len: (Tensor) true sequence lengths, (batch_size,)
            return_type: (str)
                'mean_pooling' / 'encode' -> (batch_size, num_directions * hidden_size),
                    the mean of the LSTM outputs over the valid steps of each sequence;
                anything else (e.g. 'all_return') -> padded per-step outputs of shape
                    (batch_size, max(x_len), num_directions * hidden_size), zeros at padded steps.

        Returns:
            (Tensor) as described for ``return_type``.
        """
        x_emb, x_len, return_type = input

        # pack_padded_sequence needs the lengths on the CPU, and pad_packed_sequence
        # documents total_length as a plain int (a CUDA tensor here would force a sync).
        lengths = x_len.cpu()
        total_length = int(lengths.max())

        x_packed = nn.utils.rnn.pack_padded_sequence(x_emb, lengths, batch_first=True, enforce_sorted=False)
        out_packed, _ = self.rnn(x_packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True, total_length=total_length)

        if return_type == 'mean_pooling' or return_type == 'encode':
            # Padded steps are zero-filled by pad_packed_sequence, so sum / length is the
            # mean over the valid steps. Move the divisor next to the activations.
            denom = x_len.float().unsqueeze(-1).to(out.device)
            out = out.sum(dim=1).div(denom)  # (batch, num_directions * hidden_size)
        return out


class Encoder(nn.Module):
    """Sentence encoder backed either by BERT ('bert') or by an embedding layer + BiLSTM ('lstm')."""
    def __init__(self, config, cropus,  n_seen_class, encoder_type, hidden_size, num_layers, bi, freeze_emb, tao_cos):
        """Build the selected encoder.

        Args:
            config: (Dict) configuration dict; the 'lstm' path reads
                config['dataset']['vocab']['vocab size'], config['dataset']['vocab']['word2vec dim']
                and config['dataset']['pretrain']
            cropus: (Object) corpus object; only ``get_wordembedding()`` is called, and only
                when config['dataset']['pretrain'] is truthy
            n_seen_class: (int) not used by this class
            encoder_type: (str) 'bert' or 'lstm'
            hidden_size: (int) hidden size in Bi-LSTM
            num_layers: (int) num_layers in Bi-LSTM
            bi: (boolean) Bi-direction
            freeze_emb: (bool) freeze the pretrained word embedding matrix or not
            tao_cos: (float) initial value of the learnable ``tao_cos`` parameter
                (not read by this class's forward passes)
        """
        super(Encoder, self).__init__()

        # init
        self.encoder_type = encoder_type
        self.hidden_size = hidden_size
        self.freeze_emb = freeze_emb
        self.tao_cos = nn.Parameter(torch.tensor(tao_cos))

        if self.encoder_type == 'lstm':
            self.vocab_size = config['dataset']['vocab']['vocab size']
            self.input_size = config['dataset']['vocab']['word2vec dim']
            self.num_layers = num_layers
            self.bi = bi
            if config['dataset']['pretrain']:
                wv_tensor = cropus.get_wordembedding()
                self.embedding = nn.Embedding.from_pretrained(wv_tensor, freeze=self.freeze_emb, padding_idx=0)
            else:
                self.embedding = nn.Embedding(num_embeddings=self.vocab_size,
                                              embedding_dim=self.input_size,
                                              padding_idx=0)
            self.bilstm = BiLSTM(self.input_size,
                                 self.hidden_size,
                                 self.num_layers,
                                 self.bi)

        elif self.encoder_type == 'bert':
            self.bert = AutoModel.from_pretrained('bert-base-uncased')
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

            self.bert_size = self.bert.embeddings.word_embeddings.embedding_dim
            # Not used by the BERT forward path (its output width is bert_size).
            self.hidden_size = hidden_size * 2 if bi else hidden_size

    def bert_forward(self, *input):
        """Encode a batch of raw inputs with BERT.

        Positional inputs:
            x: batch passed straight to the tokenizer (presumably a list of strings — TODO confirm)
            x_len: (Tensor) (batch,) divisor used for mean pooling.
                NOTE(review): this is whatever length the caller supplies (likely a word
                count), not the number of sub-word tokens — confirm against the caller.
            return_type: 'mean_pooling' / 'encode' -> (batch, bert_size);
                'all_return' -> (batch, seq_len - 1, bert_size), token states without [CLS];
                anything else -> the raw BERT model output.

        Padding positions are zeroed using the tokenizer's attention mask: BERT emits
        non-zero hidden states for [PAD] tokens, which would otherwise leak into the
        pooled sum and look like real tokens to downstream consumers. This matches the
        zero padding of the LSTM path. The [SEP] token is still included.
        """
        x, x_len, return_type = input
        # truncation=True: over-long inputs are cut to the model's max length instead of crashing.
        batch_input = self.tokenizer(x, return_tensors="pt", padding=True, truncation=True)
        batch_input = {k: v.to(self.bert.device) for k, v in batch_input.items()}
        output = self.bert(**batch_input)

        if return_type not in ('mean_pooling', 'encode', 'all_return'):
            return output

        # Drop [CLS] and zero out padding positions.
        hidden = output.last_hidden_state[:, 1:, :]
        mask = batch_input['attention_mask'][:, 1:].unsqueeze(-1).to(hidden.dtype)
        hidden = hidden * mask

        if return_type == 'all_return':
            return hidden

        denom = x_len.float().unsqueeze(-1).to(hidden.device)
        return hidden.sum(dim=1).div(denom)

    def lstm_forward(self, *input):
        """Embed token ids and run the BiLSTM.

        Positional inputs:
            x: (Tensor) padded token ids, (batch_size, max_len)
            x_len: (Tensor) true lengths, (batch_size,)
            return_type: forwarded unchanged to ``BiLSTM.forward``
        """
        x, x_len, return_type = input

        # Embed
        x_emb = self.embedding(x)

        # BiLSTM
        return self.bilstm(x_emb, x_len, return_type)

    def forward(self, *input):
        """Dispatch to the BERT or LSTM path; inputs are ``(x, x_len, return_type)``."""
        x, x_len, return_type = input
        if self.encoder_type == 'bert':
            return self.bert_forward(x, x_len, return_type)
        return self.lstm_forward(x, x_len, return_type)


class Generator(nn.Module):
    """Parameters Generator Network for transfer.

    Maps the per-prototype shift produced by prototype adaptation
    (``after_protos - protos``) to ``d_r`` direction vectors.
    """
    def __init__(self, emb_size, d_r, alpha):
        """
        Args:
            emb_size: (int) prototype dimension
            d_r: (int) number of direction vectors produced by ``forward``
            alpha: (float) stored on the module; not read by ``forward``
        """
        super(Generator, self).__init__()

        self.emb_size = emb_size
        self.alpha = alpha
        # Registration order (W3, W1, W2) is kept so init RNG consumption is unchanged.
        self.W3 = nn.Linear(emb_size, emb_size, bias=False)
        self.W1 = nn.Parameter(torch.rand(emb_size, emb_size))
        self.W2 = nn.Parameter(torch.rand(d_r, emb_size))
        self._reset_parmaters()

    def _reset_parmaters(self):
        """Re-initialise every parameter (W3.weight, W1, W2) with Kaiming-normal.

        The misspelled name is kept for backward compatibility.
        """
        for _, param in self.named_parameters():
            torch.nn.init.kaiming_normal_(param)

    def forward(self, *input):
        """Generate direction vectors from the prototype shifts.

        Positional inputs (2-D tensors, one prototype per row, width emb_size):
            memory_protos, novel_protos: prototypes before adaptation
            after_memory_protos, after_novel_protos: prototypes after adaptation

        Returns:
            v: (Tensor) (d_r, emb_size). With d = W3(after - before), rows stacked
               [memory; novel], v = softmax(W2 @ relu(W1 @ d.T), dim=-1) @ d. The softmax
               runs over prototypes.
            loss_r: always the integer 0.
        """
        memory_protos, novel_protos, after_memory_protos, after_novel_protos = input
        # Anomaly detection is a process-global autograd switch that slows every backward
        # pass; enable it from the training script when debugging, not in a forward pass.

        after_protos = torch.cat([after_memory_protos, after_novel_protos], 0)
        protos = torch.cat([memory_protos, novel_protos], 0)
        delta = self.W3(after_protos - protos)  # (N, emb_size)

        hidden = torch.relu(self.W1.matmul(delta.t()))           # (emb_size, N)
        weights = F.softmax(self.W2.matmul(hidden), dim=-1)       # (d_r, N)
        v = weights.matmul(delta)                                 # (d_r, emb_size)

        return v, 0



class AdaptNet(nn.Module):
    """Attention Adaptation Network: attention-pools query token features."""

    def __init__(self, emb_size):
        super().__init__()
        self.emb_size = emb_size

    def forward(self, *input):
        """Pool per-token query features into one vector per query.

        Positional inputs:
            querys_w: (Tensor) token features; sizes 0 and 1 are read as (batch, seq_len),
                and the features are flattened to rows of width emb_size for ``cos_sim``
            querys_len: not used here
            memory_protos, novel_protos: only their lengths are used, to rescale ``beta``
            after_memory_protos, after_novel_protos: not used here
            delta: direction vectors compared against every token via ``cos_sim``
            beta: scale; multiplied by (n_memory + n_novel) / n_memory * 10

        Each token is scored by its maximum ``cos_sim`` against the rows of ``delta``
        (presumably cos_sim returns a (tokens, n_delta) matrix — TODO confirm), then
        scaled by beta. Scores that are exactly 0 are set to -1e9 before the softmax over
        the sequence. This is intended to drop zero padding vectors, assuming cos_sim
        yields 0 for them (TODO confirm).

        Returns:
            (Tensor) (batch, querys_w.size(2)): attention-weighted sum of token features.
        """
        (querys_w, querys_len, memory_protos, novel_protos,
         after_memory_protos, after_novel_protos, delta, beta) = input
        batch, length, _feat = querys_w.size(0), querys_w.size(1), querys_w.size(2)

        n_memory = len(memory_protos)
        n_all = n_memory + len(novel_protos)
        beta = beta * n_all / n_memory * 10

        flat_tokens = querys_w.reshape(-1, self.emb_size)
        token_sims = cos_sim(flat_tokens, delta).reshape(batch, length, -1)
        logits = token_sims.max(-1).values * beta
        logits.masked_fill_(logits == 0, -1e9)

        weights = F.softmax(logits, dim=-1).unsqueeze(-1)        # (batch, seq_len, 1)
        pooled = querys_w.transpose(1, 2).matmul(weights)       # (batch, feat, 1)
        return pooled.squeeze(-1)


class AdaptMetaNet(nn.Module):
    """Encoder plus learnable seen-class prototypes, with optional adaptation modules.

    ``forward`` dispatches on its last positional argument:

    * ``'encode'``: ``(x, x_len, 'encode')`` -> encoder output.
    * ``'transfer'``: ``(novel_protos, memory_y, novel_y, 'transfer')`` -> 7-tuple
      ``(after_protos, memory_protos, novel_protos, after_memory_protos,
      after_novel_protos, v, loss_r)``. ``v`` and ``loss_r`` are 0 when feat_adapt is off.
    * ``'adapt'``: ``(querys_x, querys_len, memory_protos, novel_protos,
      after_memory_protos, after_novel_protos, v, 'adapt')`` -> query representations.

    Sub-modules are created according to
    ``config['arch_step2']['ablation']['proto_adapt']`` (self-attention over prototypes)
    and ``config['arch_step2']['ablation']['feat_adapt']`` (Generator + AdaptNet).
    """
    def __init__(self, config, cropus, n_seen_class, encoder_type, hidden_size, num_layers, dropout, bi, tao_cos, tao_attn, tao_attn_gene, d_r, alpha, freeze_emb):
        super(AdaptMetaNet, self).__init__()

        # init
        self.config = config
        self.encoder_type = encoder_type
        ## embedding
        self.vocab_size = config['dataset']['vocab']['vocab size']
        self.input_size = config['dataset']['vocab']['word2vec dim']
        self.freeze_emb = freeze_emb

        ## BiLSTM
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout  # stored only; not applied anywhere in this class
        self.bi = bi

        self.num_classes = n_seen_class

        self.tao_cos = nn.Parameter(torch.tensor(tao_cos))
        self.tao_attn = nn.Parameter(torch.tensor(tao_attn))
        self.tao_attn_gene = nn.Parameter(torch.tensor(tao_attn_gene))
        self.d_r = d_r
        self.alpha = alpha

        # models
        self.encoder = Encoder(config, cropus, n_seen_class, encoder_type, hidden_size, num_layers, bi, freeze_emb, tao_cos)

        if self.encoder_type == 'lstm':
            self.emb_size = self.hidden_size * 2 if self.bi else self.hidden_size
        elif self.encoder_type == 'bert':
            self.bert_size = self.encoder.bert.embeddings.word_embeddings.embedding_dim
            self.emb_size = self.bert_size

        # One learnable prototype vector per seen class.
        self.class_proto_list = [nn.Parameter(torch.randn(self.emb_size)) for _ in range(self.num_classes)]
        self.class_protos = nn.ParameterList(self.class_proto_list)

        if self.config['arch_step2']['ablation']['proto_adapt']:
            self_attn_layer = nn.TransformerEncoderLayer(d_model=self.emb_size, nhead=8, dim_feedforward=2048)
            self.self_attn_layer = nn.TransformerEncoder(self_attn_layer, num_layers=1)
        if self.config['arch_step2']['ablation']['feat_adapt']:
            self.generator = Generator(self.emb_size, d_r, alpha)
            self.adaptnet = AdaptNet(self.emb_size)

        self._reset_parameters()

    @torch.no_grad()
    def _reset_parameters(self):
        """Kaiming-normal init of each 1-D class prototype.

        Each prototype is temporarily viewed as (1, emb_size) because kaiming init needs
        at least 2 dimensions, then squeezed back in place.
        """
        for i, param in enumerate(self.class_protos):
            nn.init.kaiming_normal_(param.unsqueeze_(0)).squeeze_()

    def forward(self, *input):
        """Dispatch on the mode string passed as the last positional argument (see class doc).

        Raises:
            ValueError: if the mode is not 'encode', 'transfer' or 'adapt'.
        """
        mode = input[-1]
        if mode == 'encode':
            x, x_len = input[:-1]
            return self.encoder(x, x_len, mode)
        if mode == 'transfer':
            novel_protos, memory_y, novel_y = input[:-1]
            return self._transfer(novel_protos, memory_y, novel_y)
        if mode == 'adapt':
            (querys_x, querys_len, memory_protos, novel_protos,
             after_memory_protos, after_novel_protos, v) = input[:-1]
            return self._adapt(querys_x, querys_len, memory_protos, novel_protos,
                               after_memory_protos, after_novel_protos, v)
        raise ValueError(f"Unknown AdaptMetaNet forward mode: {mode!r}")

    def _transfer(self, novel_protos, memory_y, novel_y):
        """Gather memory prototypes, stack them with the novel ones, and optionally adapt them."""
        ablation = self.config['arch_step2']['ablation']

        # Copy the stored class prototypes into a buffer with one row per memory+novel label.
        # Assigning Parameters into it keeps gradients flowing back to class_protos.
        n_rows = memory_y.size(0) + novel_y.size(0)
        src_protos = torch.zeros(n_rows, self.emb_size, device=novel_protos.device)
        for i, proto in enumerate(self.class_protos):
            src_protos[i] = proto
        memory_protos = src_protos[memory_y]

        protos = torch.cat([memory_protos, novel_protos], 0)

        if ablation['proto_adapt']:
            # The TransformerEncoder uses its default (seq, batch, emb) layout here, so the
            # prototypes form one sequence of length N with batch size 1. Residual add on top.
            after_protos = protos + self.self_attn_layer(protos.unsqueeze(1)).squeeze(1)
        else:
            after_protos = protos

        n_memory = len(memory_protos)
        after_memory_protos = after_protos[:n_memory]
        after_novel_protos = after_protos[n_memory:]

        if ablation['feat_adapt']:
            v, loss_r = self.generator(memory_protos, novel_protos, after_memory_protos, after_novel_protos)
        else:
            v, loss_r = 0, 0
        return after_protos, memory_protos, novel_protos, after_memory_protos, after_novel_protos, v, loss_r

    def _adapt(self, querys_x, querys_len, memory_protos, novel_protos, after_memory_protos, after_novel_protos, v):
        """Encode queries: attention-pooled via AdaptNet when feat_adapt is on, else mean-pooled."""
        if self.config['arch_step2']['ablation']['feat_adapt']:
            querys_w = self.encoder(querys_x, querys_len, 'all_return')
            beta = cos_mmd(after_memory_protos, after_novel_protos)
            return self.adaptnet(querys_w, querys_len, memory_protos, novel_protos,
                                 after_memory_protos, after_novel_protos, v, beta)
        return self.encoder(querys_x, querys_len, 'mean_pooling')
