import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np

class RNNEncoder(nn.Module):
    """Bidirectional LSTM encoder producing token-, sentence- and final-state encodings.

    ``model_opt`` is read for ``"word_dim"`` and ``"enc_rnn_size"`` (both cast with
    ``int``) and ``"gpu"`` (compared against the string ``"True"``).
    """

    def __init__(self, model_opt, vocab_size, type="abs", use_bridge=True):
        super(RNNEncoder, self).__init__()
        self.input_dim = int(model_opt["word_dim"])
        self.hidden_size = int(model_opt["enc_rnn_size"])
        # padding_idx=1: row 1 of the embedding is zero-initialised and not updated
        # (nn.Embedding padding_idx semantics).
        self.word_embeddings = nn.Embedding(vocab_size, self.input_dim, padding_idx=1)
        self.type = type
        # Kept for backward compatibility; padding tensors now follow the device of
        # the encoder states instead of this flag.
        self.gpu = model_opt["gpu"] == "True"

        # Token-level encoder: one bidirectional layer -> 2 * hidden_size features per token.
        self.rnn = nn.LSTM(input_size=self.input_dim, hidden_size=self.hidden_size,
                           num_layers=1, batch_first=True, bidirectional=True)
        # Sentence-level encoder run over the gathered end-of-sentence token states.
        self.sent_rnn = nn.LSTM(input_size=2 * self.hidden_size,
                                hidden_size=self.hidden_size,
                                num_layers=1, batch_first=True, bidirectional=True)

        self.use_bridge = use_bridge
        if self.use_bridge:
            # NOTE(review): the bridge is built for "GRU" (a single linear layer) although
            # self.rnn is an LSTM, so _bridge transforms only the hidden state and drops the
            # cell state (encoder_final becomes a 1-tuple). Left unchanged because fixing it
            # alters the parameter set (checkpoint compatibility) and encoder_final's shape.
            self._initialize_bridge("GRU", self.hidden_size, 1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.

        Replaces the embedding weight with a new ``nn.Parameter`` (which has
        ``requires_grad=True``), so call ``fine_tune_embeddings`` afterwards if needed.

        :param embeddings: pre-trained embeddings, a tensor of shape [vocab_size, word_dim]
        """
        print("Loaded pretrained embeddings for encoder")
        self.word_embeddings.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).

        :param fine_tune: Allow?
        """
        for p in self.word_embeddings.parameters():
            p.requires_grad = fine_tune

    def forward(self, input_text, input_lens, type="abs", sent_end_idx=None):
        """Encode a padded batch of token ids.

        :param input_text: token ids, shape [batch_size, seq_length]
        :param input_lens: per-example lengths, shape [batch_size] (list, numpy array
            or tensor on any device)
        :param type: unused; kept for backward compatibility
        :param sent_end_idx: per-example sentence token counts consumed by
            ``build_sentence_layer``; when None, no sentence encoding is computed
        :return: (encoder_final, memory_bank, sent_emb, max_len)
            - encoder_final: final RNN state(s), each [2, batch, enc_rnn_size]; passed
              through ``_bridge`` when ``use_bridge`` is set
            - memory_bank: all hidden states, [batch, max_len, 2*enc_rnn_size]
            - sent_emb: [batch, n_sent, 2*enc_rnn_size], or None if ``sent_end_idx`` is None
            - max_len: the largest value in ``input_lens``
        """
        input_embedding = self.word_embeddings(input_text)  # [batch, seq_length, word_dim]

        # pack_padded_sequence (with the default enforce_sorted=True) needs lengths in
        # decreasing order: sort the batch here and undo the permutation afterwards.
        if torch.is_tensor(input_lens):
            input_lens = input_lens.detach().cpu().numpy()
        input_lens = np.array(input_lens)
        sort_idx = np.argsort(-input_lens)
        input_lengths = input_lens[sort_idx]
        device = input_embedding.device
        sort_order = torch.as_tensor(sort_idx, dtype=torch.long, device=device)
        restore_order = torch.as_tensor(np.argsort(sort_idx), dtype=torch.long, device=device)

        embedded = input_embedding.index_select(0, sort_order)
        packed = pack_padded_sequence(embedded, input_lengths, batch_first=True)
        del embedded

        outputs, en_hidden = self.rnn(packed)
        del packed

        # Pads to the longest length in the batch: [batch, max_len, 2*enc_rnn_size].
        outputs, _ = pad_packed_sequence(outputs, batch_first=True)

        # Restore the caller's batch order (batch is dim 0 of outputs, dim 1 of states).
        memory_bank = outputs.index_select(0, restore_order)
        if isinstance(en_hidden, tuple):  # LSTM: (h, c)
            encoder_final = tuple(state.index_select(1, restore_order) for state in en_hidden)
        else:
            encoder_final = en_hidden.index_select(1, restore_order)

        sent_emb = None
        if sent_end_idx is not None:
            sent_emb = self.build_sentence_layer(memory_bank, sent_end_idx)

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)  # [2, batch, enc_rnn_size] per state

        return encoder_final, memory_bank, sent_emb, input_lengths[0]

    def build_sentence_layer(self, memory_bank, src_sents):
        """Build sentence-level representations and run them through ``sent_rnn``.

        For each example, the token state at the end of every sentence is taken from
        ``memory_bank``. Entries of ``src_sents`` are treated as per-sentence token
        counts: their running sum minus one is each sentence's last-token position.
        Gathering stops at the first zero count or the first position past the
        sequence. Remaining slots are zero-filled, so each example contributes
        ``len(row)`` sentence slots.

        :param memory_bank: token states, shape [batch, seq_length, 2*enc_rnn_size]
        :param src_sents: one row of sentence token counts per example (assumed
            equal-length rows, e.g. a padded [batch, n_sent] tensor)
        :return: sentence embeddings, shape [batch, n_sent, 2*enc_rnn_size]
        """
        per_example = [self._gather_sentence_states(states, sent_lens)
                       for states, sent_lens in zip(memory_bank, src_sents)]
        sent_inputs = torch.cat(per_example, 0)  # [batch, n_sent, 2*enc_rnn_size]
        sent_output, _ = self.sent_rnn(sent_inputs)
        return sent_output

    @staticmethod
    def _gather_sentence_states(states, sent_lens):
        """Return end-of-sentence rows of ``states`` ([seq, dim]) as [1, len(sent_lens), dim]."""
        seq_len = states.size(0)
        picked = []
        pos = -1
        for length in sent_lens:
            length = int(length)
            if length == 0:  # padding sentence: no more real sentences
                break
            pos += length
            if pos >= seq_len:  # sentence end lies beyond the encoded sequence
                break
            picked.append(states[pos])
        # Zero-fill missing slots; new_zeros matches the states' dtype and device.
        n_missing = len(sent_lens) - len(picked)
        picked.extend(states.new_zeros(states.shape[1:]) for _ in range(n_missing))
        return torch.stack(picked, 0).unsqueeze(0)

    def _initialize_bridge(self, rnn_type, hidden_size, num_layers):
        """Create one Linear(total_hidden_dim, total_hidden_dim) per RNN state.

        :param rnn_type: "LSTM" creates two layers (hidden and cell state); anything
            else creates one
        :param hidden_size: per-direction hidden size
        :param num_layers: number of RNN layers
        """
        # LSTM has hidden and cell state, other only one
        number_of_states = 2 if rnn_type == "LSTM" else 1
        # Total number of states
        self.total_hidden_dim = hidden_size * num_layers

        # Build a linear layer for each
        self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
                                               self.total_hidden_dim,
                                               bias=True)
                                     for _ in range(number_of_states)])

    def _bridge(self, hidden):
        """Forward hidden state(s) through the bridge (Linear followed by ReLU).

        Each state is flattened to [-1, total_hidden_dim], transformed and reshaped
        back to its original size. For a tuple input, one output is produced per
        bridge layer, so with a single ("GRU") bridge only ``hidden[0]`` is returned,
        as a 1-tuple.
        """

        def bottle_hidden(linear, states):
            """
            Transform from 3D to 2D, apply linear and return initial size
            """
            size = states.size()
            result = linear(states.reshape(-1, self.total_hidden_dim))
            return F.relu(result).view(size)

        if isinstance(hidden, tuple):  # LSTM
            outs = tuple([bottle_hidden(layer, hidden[ix])
                          for ix, layer in enumerate(self.bridge)])
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs

class Encoder(nn.Module):
    """Encode context, target abstract and source abstract with one shared RNNEncoder.

    All three inputs go through the same ``text_encoder`` instance, so they share
    its embeddings and RNN weights.
    """

    def __init__(self, model_opt, vocab_size):
        super().__init__()
        self.text_encoder = RNNEncoder(model_opt, vocab_size)

    def forward(self, context, tgt_abstract, src_abstract, context_lengths,
                tgt_abstract_lengths, src_abstract_lengths, context_end_idx,
                tgt_abs_end_idx, src_abs_end_idx):
        """Run the shared encoder over context, target abstract and source abstract, in that order.

        :return: flat 10-tuple of (encoder_final, memory_bank, sent_emb) for the
            context, then the target abstract, then the source abstract, followed
            by the target abstract's max length as reported by the encoder.
        """
        streams = (
            (context, context_lengths, context_end_idx),
            (tgt_abstract, tgt_abstract_lengths, tgt_abs_end_idx),
            (src_abstract, src_abstract_lengths, src_abs_end_idx),
        )
        ctx_out, tgt_out, src_out = [
            self.text_encoder(tokens, lengths, type="abs", sent_end_idx=ends)
            for tokens, lengths, ends in streams
        ]

        ctx_final, ctx_bank, ctx_sents, _ = ctx_out
        tgt_final, tgt_bank, tgt_sents, tgt_max_len = tgt_out
        src_final, src_bank, src_sents, _ = src_out

        return (ctx_final, ctx_bank, ctx_sents,
                tgt_final, tgt_bank, tgt_sents,
                src_final, src_bank, src_sents, tgt_max_len)
