import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from config.config import ContextEmb, Config, CharMoel
from model.IntNet import IntNet
from model.aelgcn.aelgcn_base import MultiHeadAttention, MultiGraphConvLayer
from model.charbilstm import CharBiLSTM
from model.aelgcn.aelgcn_base import GraphConvLayer
from model.linear_crf_inferencer import LinearCRF


class NNCRF_AELGCN(nn.Module):
    """Sequence labeller: embeddings -> BiLSTM -> graph-convolution stack -> ``LinearCRF``.

    Token features (word embedding, optional character features, optional
    contextual embedding, POS-label embedding and dependency-label embedding)
    are concatenated and encoded by a bidirectional LSTM.  The LSTM states are
    projected to ``gcn_dim`` and passed through ``num_gcn_layer``
    ``GraphConvLayer`` modules; after the first of them, the output is refined
    by ``MultiHeadAttention`` followed by each ``MultiGraphConvLayer``.  The
    outputs of all layers are concatenated, projected, and turned into
    per-label scores that the ``LinearCRF`` inferencer consumes.
    """

    def __init__(self, config: "Config"):
        """Build all sub-modules from ``config`` and move them to ``config.device``."""
        super().__init__()
        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.use_char_model = config.use_char_model
        self.context_emb = config.context_emb

        self.label2idx = config.label2idx
        self.labels = config.idx2labels

        # ``input_size`` accumulates the width of every feature that
        # ``neural_scoring`` concatenates before the LSTM.
        self.input_size = config.embedding_dim

        if self.use_char:
            if self.use_char_model == CharMoel.bilstm:
                self.char_feature = CharBiLSTM(config, bidirectional=False)
                self.input_size += config.charlstm_hidden_dim
            else:
                self.char_feature = IntNet(config).to(self.device)
                self.input_size += self.char_feature.last_dim

        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size

        # Pre-trained word vectors stay trainable (freeze=False).
        self.word_embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(config.word_embedding), freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        self.pos_label_embedding = nn.Embedding(len(config.poslabel2idx), config.pos_emb_size).to(self.device)
        self.input_size += config.pos_emb_size

        self.dep_label_embedding = nn.Embedding(len(config.deplabel2idx), config.dep_emb_size).to(self.device)
        self.input_size += config.dep_emb_size

        # Separate embedding table for the ``edge`` ids given to neural_scoring;
        # index 0 is the padding index.
        self.edge_embeddings = nn.Embedding(
            len(config.deplabel2idx), config.dep_emb_size, padding_idx=0).to(self.device)

        self.num_lstm_layer = config.num_lstm_layer
        self.lstm_hidden_dim = config.hidden_dim

        # Bidirectional: each direction gets half of the hidden size.
        self.lstm = nn.LSTM(self.input_size, self.lstm_hidden_dim // 2,
                            num_layers=config.num_lstm_layer, batch_first=True,
                            bidirectional=True).to(self.device)
        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        self.pooling = config.gcn_pool
        self.gcn_layers = nn.ModuleList()
        self.gcn_drop = nn.Dropout(config.gcn_dropout)
        self.gcn_dim = config.gcn_dim
        self.att_gcn_dropout = config.att_gcn_dropout
        self.att_heads = config.att_heads
        self.att_layers = config.att_layers
        self.dep_embed_dim = config.dep_emb_size
        self.num_gcn_layer = config.num_gcn_layer
        self.input_W_G = nn.Linear(self.lstm_hidden_dim, self.gcn_dim).to(self.device)
        for _ in range(self.num_gcn_layer):
            self.gcn_layers.append(
                GraphConvLayer(self.device, self.gcn_dim, self.dep_embed_dim, self.pooling).to(self.device))

        self.attn = MultiHeadAttention(self.att_heads, self.gcn_dim).to(self.device)
        self.attn_layers = nn.ModuleList()
        # Two attention-driven blocks: the first with half of ``att_layers``
        # sub-layers, the second with all of them.
        for sub_layers in (self.att_layers // 2, self.att_layers):
            self.attn_layers.append(
                MultiGraphConvLayer(self.att_gcn_dropout, self.gcn_dim, sub_layers,
                                    self.att_heads).to(self.device))

        # Input width: the projected LSTM output plus one output per GCN layer.
        self.aggregate_W = nn.Linear(
            self.gcn_dim + self.num_gcn_layer * self.gcn_dim, self.gcn_dim).to(self.device)

        self.hidden2tag = nn.Linear(self.gcn_dim, self.label_size).to(self.device)

        self.inferencer = LinearCRF(label_size=config.label_size, label2idx=config.label2idx,
                                    idx2labels=config.idx2labels).to(self.device)

    def _length_mask(self, word_seq_lens, batch_size, sent_len):
        """Return a (batch_size, sent_len) bool mask that is True at positions below each length."""
        positions = torch.arange(1, sent_len + 1, dtype=torch.long, device=self.device).view(1, sent_len)
        lengths = word_seq_lens.to(self.device).view(batch_size, 1)
        # Broadcasting compares every 1-based position with every sentence length.
        return torch.le(positions, lengths)

    def neural_scoring(self, word_seq_tensor, word_seq_lens, batch_context_emb, char_inputs, char_seq_lens,
                       pos_label_tensor, edge, adj_matrix, dep_label_tensor):
        """Compute per-token label scores for a padded batch.

        Args:
            word_seq_tensor: word ids, (batch, sent_len).
            word_seq_lens: true length of every sentence, one entry per batch row.
            batch_context_emb: contextual embeddings; only read when
                ``self.context_emb`` is not ``ContextEmb.none``.
            char_inputs, char_seq_lens: forwarded to
                ``self.char_feature.get_last_hiddens``; only read when
                ``self.use_char`` is set.
            pos_label_tensor: POS-label ids, (batch, sent_len).
            edge: integer ids looked up in ``self.edge_embeddings``
                (presumably (batch, sent_len, sent_len) -- TODO confirm).
            adj_matrix: adjacency information forwarded to every
                ``GraphConvLayer`` (shape not visible here -- TODO confirm).
            dep_label_tensor: dependency-label ids, (batch, sent_len).

        Returns:
            Tensor of label scores, (batch, sent_len, label_size).
        """
        batch_size = word_seq_tensor.size(0)
        sent_len = word_seq_tensor.size(1)

        # Collect every per-token feature, then concatenate once on the feature axis.
        token_features = [self.word_embedding(word_seq_tensor)]
        if self.use_char:
            token_features.append(self.char_feature.get_last_hiddens(char_inputs, char_seq_lens))
        if self.context_emb != ContextEmb.none:
            token_features.append(batch_context_emb.to(self.device))
        token_features.append(self.pos_label_embedding(pos_label_tensor))
        token_features.append(self.dep_label_embedding(dep_label_tensor))
        word_rep = self.word_drop(torch.cat(token_features, 2))

        # pack_padded_sequence expects sequences sorted by decreasing length;
        # ``recover_idx`` restores the caller's order afterwards.
        sorted_seq_len, perm_idx = word_seq_lens.sort(0, descending=True)
        _, recover_idx = perm_idx.sort(0, descending=False)

        packed_words = pack_padded_sequence(word_rep[perm_idx], sorted_seq_len.cpu(), batch_first=True)
        lstm_out, _ = self.lstm(packed_words, None)
        # total_length keeps the time axis at ``sent_len`` even when the batch is
        # padded wider than its longest sentence, so it stays aligned with the
        # mask and the graph inputs below.
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=sent_len)
        feature_out = self.drop_lstm(lstm_out)[recover_idx]

        edge = edge.to(self.device)
        adj_matrix = adj_matrix.to(self.device)
        weight_adj = self.edge_embeddings(edge)

        gcn_inputs = self.input_W_G(feature_out)
        gcn_outputs = gcn_inputs
        layer_list = [gcn_inputs]

        # Extra axis so the mask can broadcast inside the attention module.
        mask = self._length_mask(word_seq_lens, batch_size, sent_len).unsqueeze(-2)

        for layer_idx in range(self.num_gcn_layer):
            gcn_outputs, weight_adj = self.gcn_layers[layer_idx](weight_adj, gcn_outputs, adj_matrix)
            # The attention-driven blocks are applied only after the first GCN layer.
            if layer_idx == 0:
                for attn_layer in self.attn_layers:
                    attn_tensor = self.attn(gcn_outputs, gcn_outputs, mask)
                    # One adjacency matrix per attention head (split along dim 1).
                    attn_adj_list = [head.squeeze(1) for head in torch.split(attn_tensor, 1, dim=1)]
                    gcn_outputs = attn_layer(attn_adj_list, gcn_outputs)
            # Original note: "remove for catalan" -- presumably this dropout was
            # disabled for the Catalan data set; confirm before changing.
            gcn_outputs = self.gcn_drop(gcn_outputs)
            weight_adj = self.gcn_drop(weight_adj)
            layer_list.append(gcn_outputs)

        aggregate_out = self.aggregate_W(torch.cat(layer_list, dim=-1))
        return self.hidden2tag(aggregate_out)

    def neg_log_obj(self, words, word_seq_lens, batch_context_emb, chars, char_seq_lens, batch_pos_label, tags, dep_label_adj, adj_matrix, dep_label_tensor):
        """Return the training loss: unlabeled score minus the score of the gold ``tags``.

        All arguments except ``tags`` are forwarded to ``neural_scoring``
        (``dep_label_adj`` is passed as its ``edge`` argument).
        """
        features = self.neural_scoring(words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
                                       batch_pos_label, dep_label_adj, adj_matrix, dep_label_tensor)
        all_scores = self.inferencer.calculate_all_scores(features)
        mask = self._length_mask(word_seq_lens, words.size(0), words.size(1))

        unlabeled_score = self.inferencer.forward_unlabeled(all_scores, word_seq_lens)
        labeled_score = self.inferencer.forward_labeled(all_scores, word_seq_lens, tags, mask)
        return unlabeled_score - labeled_score

    def decode(self, batchInput):
        """Decode the best label sequence for every sentence of a batch.

        Args:
            batchInput: 15-element tuple; only the word/char/context tensors,
                ``adj_matrixs``, ``dep_label_adj``, ``batch_dep_label`` and
                ``batch_pos_label`` are used here.

        Returns:
            ``(bestScores, decodeIdx)`` as produced by
            ``self.inferencer.viterbi_decode``, with the first ``length``
            entries of every row of ``decodeIdx`` reversed.
        """
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, \
            adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, \
            tagSeqTensor, batch_dep_label, batch_pos_label = batchInput
        features = self.neural_scoring(wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor,
                                       charSeqLengths, batch_pos_label, dep_label_adj, adj_matrixs,
                                       batch_dep_label)
        all_scores = self.inferencer.calculate_all_scores(features)
        bestScores, decodeIdx = self.inferencer.viterbi_decode(all_scores, wordSeqLengths)
        # Reverse the valid prefix of each row (presumably viterbi_decode emits
        # the path back-to-front -- confirm against LinearCRF).
        for idx in range(wordSeqTensor.size(0)):
            length = wordSeqLengths[idx]
            decodeIdx[idx][:length] = decodeIdx[idx][:length].flip([0])
        return bestScores, decodeIdx
