import logging

import torch
import torch.nn as nn

from models.embedding_models.bert_embedding_model import BertEmbedModel
from models.embedding_models.word_char_embedding_model import WordCharEmbedModel
from modules.seq2seq_encoders.seq2seq_bilstm_encoder import BiLSTMEncoder
from modules.decoders.seq_decoder import SeqSoftmaxDecoder
from models.ent_models.cnn_ent_model import CNNEntModel

from modules.seq2seq_encoders.seq2seq_gcn_encoder import GCNEncoder
from utils.prune import pruning_with_entities

logger = logging.getLogger(__name__)


class PipelineEntModel(nn.Module):
    """This class utilizes pipeline method to handle
    entity recognition task, firstly detecting entity
    spans then using CNN for entity spans typing.

    After entity extraction, an optional GCN encoder refines the
    sequence representations that are handed on through `batch_inputs`.
    """
    def __init__(self, cfg, vocab):
        """This function constructs `PipelineEntModel` components and
        sets `PipelineEntModel` parameters

        Arguments:
            cfg {object} -- config parameters for constructing multiple models,
                read via attribute access (e.g. `cfg.lstm_layers`)
            vocab {Vocabulary} -- vocabulary
        """

        super().__init__()
        self.vocab = vocab
        self.lstm_layers = cfg.lstm_layers
        self.gcn_layers = cfg.gcn_layers
        self.dp_prune = cfg.dp_prune

        # Encoding: any value other than 'word_char' selects the BERT embedder.
        if cfg.embedding_model == 'word_char':
            self.embedding_model = WordCharEmbedModel(cfg, vocab)
        else:
            self.embedding_model = BertEmbedModel(cfg, vocab)

        # Tracks the feature size produced by the encoder stack built so far.
        self.encoder_output_size = self.embedding_model.get_hidden_size()

        if self.lstm_layers > 0:
            self.seq_encoder = BiLSTMEncoder(input_size=self.encoder_output_size,
                                             hidden_size=cfg.lstm_hidden_unit_dims,
                                             num_layers=cfg.lstm_layers,
                                             dropout=cfg.dropout)
            self.encoder_output_size = self.seq_encoder.get_output_dims()

        if self.gcn_layers > 0:
            # The GCN is sized on the (optional) LSTM output, so `forward`
            # must feed it the LSTM output as well.
            self.gcn_encoder = GCNEncoder(input_size=self.encoder_output_size,
                                          num_layers=cfg.gcn_layers,
                                          gcn_dropout=cfg.gcn_dropout,
                                          adj_dropout=cfg.adj_dropout,
                                          prune=cfg.prune,
                                          min_threshold=cfg.min_threshold,
                                          max_threshold=cfg.max_threshold)
            self.encoder_output_size = self.gcn_encoder.get_output_dims()

        # NOTE(review): both decoders below are sized with the post-GCN width,
        # yet `forward` applies them to the pre-GCN representations. This is
        # only consistent if the GCN keeps the feature size unchanged -- confirm.

        # Entity span detection
        self.ent_span_decoder = SeqSoftmaxDecoder(hidden_size=self.encoder_output_size,
                                                  label_size=self.vocab.get_vocab_size('entity_span_labels'))

        # Entity span type prediction
        self.cnn_ent_model = CNNEntModel(cfg, vocab, self.encoder_output_size)

    def forward(self, batch_inputs):
        """This function propagates forwardly

        Besides returning `results`, it writes `ent_span_preds` and the final
        `seq_encoder_reprs` back into `batch_inputs`.

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: ent_span_loss, sequence_label_preds,
                ent_loss, all_ent_preds
        """

        results = {}

        batch_seq_entity_span_labels = batch_inputs['entity_span_labels']
        batch_seq_tokens_lens = batch_inputs['tokens_lens']
        batch_seq_tokens_mask = batch_inputs['tokens_mask']

        # The embedding model stores its output in batch_inputs['seq_encoder_reprs']
        self.embedding_model(batch_inputs)
        batch_seq_tokens_encoder_repr = batch_inputs['seq_encoder_reprs']

        # Optional BiLSTM on top of the embeddings
        if self.lstm_layers > 0:
            batch_seq_encoder_repr = self.seq_encoder(batch_seq_tokens_encoder_repr,
                                                      batch_seq_tokens_lens).contiguous()
        else:
            batch_seq_encoder_repr = batch_seq_tokens_encoder_repr

        batch_inputs['seq_encoder_reprs'] = batch_seq_encoder_repr

        # Predict entity tagging sequences, and record their loss
        entity_span_outputs = self.ent_span_decoder(batch_seq_encoder_repr, batch_seq_tokens_mask,
                                                    batch_seq_entity_span_labels)

        batch_inputs['ent_span_preds'] = entity_span_outputs['predict']
        results['ent_span_loss'] = entity_span_outputs['loss']
        results['sequence_label_preds'] = entity_span_outputs['predict']

        # Predict entity span type, and obtain {entity span: entity label} dictionary
        ent_model_outputs = self.cnn_ent_model(batch_inputs)
        ent_preds = self.get_ent_preds(batch_inputs, ent_model_outputs)

        results['ent_loss'] = entity_span_outputs['loss'] + ent_model_outputs['ent_loss']
        results['all_ent_preds'] = ent_preds

        # Add GCNEncoder after entity extraction
        if self.gcn_layers > 0:
            if self.dp_prune:
                batch_seq_adj = self.dynamic_pruning(batch_inputs, ent_preds)
            else:
                batch_seq_adj = batch_inputs['adj_fw']
            # Feed the encoder output (post-LSTM when an LSTM is configured):
            # the GCN input size was derived from that representation, and
            # using the raw embeddings here would bypass the LSTM entirely.
            batch_seq_encoder_repr = self.gcn_encoder(batch_seq_encoder_repr, batch_seq_adj)
        batch_inputs['seq_encoder_reprs'] = batch_seq_encoder_repr

        return results

    def get_ent_preds(self, batch_inputs, ent_model_outputs):
        """This function gets entity predictions from entity model outputs

        `ent_model_outputs['ent_preds']` is a flat sequence covering the
        candidate spans of all sentences in order; it is sliced back into
        per-sentence groups using the lengths of `batch_inputs['all_candi_ents']`.
        Spans whose predicted label is 'None' are dropped.

        Arguments:
            batch_inputs {dict} -- batch input data
            ent_model_outputs {dict} -- entity model outputs

        Returns:
            list -- entity predictions, one {span: label} dict per sentence
        """

        ent_preds = []
        candi_ent_cnt = 0
        for ents in batch_inputs['all_candi_ents']:
            cur_ents_num = len(ents)
            cur_preds = ent_model_outputs['ent_preds'][candi_ent_cnt:candi_ent_cnt + cur_ents_num]
            ent_pred = {}
            for ent, pred in zip(ents, cur_preds):
                ent_pred_label = self.vocab.get_token_from_index(pred.item(), 'span2ent')
                if ent_pred_label != 'None':
                    ent_pred[ent] = ent_pred_label
            ent_preds.append(ent_pred)
            candi_ent_cnt += cur_ents_num

        return ent_preds

    def get_hidden_size(self):
        """This function returns sentence encoder representation tensor size

        Returns:
            int -- sequence encoder output size
        """

        return self.encoder_output_size

    def get_ent_span_feature_size(self):
        """This function returns entity span feature size

        Returns:
            int -- entity span feature size
        """
        return self.cnn_ent_model.get_ent_span_feature_size()

    def dynamic_pruning(self, batch_inputs, ent_preds):
        """This function returns the dynamically pruned structures

        The pruning is applied to a copy of `batch_inputs['adj_fw']`, so the
        adjacency stored in the batch is left untouched. Sentences without
        any predicted entity keep their original adjacency matrix.

        Arguments:
            batch_inputs {dict} -- batch input data; reads `adj_fw`, `head`,
                `children` and `tokens_lens`
            ent_preds {list} -- per-sentence {(start, end): label} predictions

        Returns:
            torch.tensor -- dp structures
        """
        # Work on a clone: writing into the caller's tensor would permanently
        # alter the batch and compound the pruning if the batch is reused.
        batch_seq_adj = batch_inputs['adj_fw'].clone()
        batch_size, max_len, _ = batch_seq_adj.size()

        for idx in range(batch_size):
            # Token positions covered by predicted spans; `range` makes the
            # span end exclusive.
            nodes = [i for ent_start, ent_end in ent_preds[idx] for i in range(ent_start, ent_end)]
            if not nodes:
                continue

            dp_adj = pruning_with_entities(nodes, batch_seq_adj[idx], batch_inputs['head'][idx],
                                           batch_inputs['children'][idx], batch_inputs['tokens_lens'][idx],
                                           max_len, prune=1)
            batch_seq_adj[idx] = torch.tensor(dp_adj, dtype=batch_seq_adj.dtype, device=batch_seq_adj.device)

        return batch_seq_adj
