import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.global_attention import GlobalAttention
from models.stacked_rnn import StackedGRU, StackedLSTM

def rnn_factory(rnn_type, **kwargs):
    """Build a recurrent layer from ``torch.nn`` by class name.

    Args:
        rnn_type (str): name of a recurrent class in ``torch.nn``, e.g.
            ``"GRU"`` or ``"LSTM"``. ``"SRU"`` is recognised but no SRU
            implementation is available in this module.
        **kwargs: forwarded unchanged to the ``torch.nn`` constructor.

    Returns:
        tuple: ``(rnn, no_pack_padded_seq)``. ``no_pack_padded_seq`` tells the
        caller whether the layer cannot consume a ``PackedSequence``; it is
        always ``False`` for the ``torch.nn`` layers built here.

    Raises:
        ValueError: if ``rnn_type`` is ``"SRU"``.
        AttributeError: if ``rnn_type`` is not a name in ``torch.nn``.
    """
    if rnn_type == "SRU":
        # Previously this branch only set the flag and then returned an
        # unassigned ``rnn`` (UnboundLocalError). Fail with a clear message.
        raise ValueError("SRU is not supported: no SRU implementation is available.")
    rnn = getattr(nn, rnn_type)(**kwargs)
    # torch.nn recurrent layers all accept PackedSequence inputs.
    return rnn, False

class Decoder(nn.Module):
    """Input-feeding recurrent decoder with multi-source attention.

    At every target step the decoder feeds ``[word embedding; previous decoder
    output]`` into a ``StackedGRU`` cell. When ``attentional`` is true, the
    cell output attends separately over a context memory bank and a
    target-abstract memory bank. The two attention vectors and a projected
    citation vector are mixed with learned softmax weights
    (``fusion_linear_1``/``fusion_linear_2``). The mix is then combined with
    the cell output through ``linear_out``. When copy attention is enabled, a
    per-step copy feature vector is also produced.

    Decoding state (``"hidden"``, ``"input_feed"``, ``"coverage"``) lives in
    ``self.state``. It must be initialised with :meth:`init_state` before
    :meth:`forward` is called.
    """

    def __init__(self, model_opt, vocab_size,
                 attn_type="general", attn_func="softmax",
                 coverage_attn=False, num_layers=1,
                 reuse_copy_attn=False, copy_attn_type="mlp",
                 attentional=True, bidirectional_encoder=True):
        """
        Args:
            model_opt (dict): options. The keys read here are
                ``"dec_rnn_size"``, ``"node_dim"``, ``"word_dim"`` and
                ``"dec_dropout"`` (converted with ``int``/``float``), plus
                ``"copy_attn"``. Copy attention is enabled only when
                ``"copy_attn"`` equals the string ``"True"``.
            vocab_size (int): number of rows in the target embedding table.
            attn_type (str): forwarded to both ``GlobalAttention`` modules.
            attn_func (str): forwarded to both ``GlobalAttention`` modules.
            coverage_attn (bool): forwarded to ``GlobalAttention``. No
                coverage vector is accumulated by the forward pass here.
            num_layers (int): number of stacked GRU layers.
            reuse_copy_attn (bool): when true, no separate copy layers are built.
            copy_attn_type (str): must not be ``"none"``/``None`` when copy
                attention is enabled; otherwise unused here.
            attentional (bool): whether to build and use the attention layers.
            bidirectional_encoder (bool): whether :meth:`init_state` merges
                interleaved forward/backward encoder states.

        Raises:
            ValueError: for incompatible option combinations. These are:
                coverage, copy, or reused copy attention without attention;
                and copy attention with type none.
        """
        super(Decoder, self).__init__()

        self.attentional = attentional
        self.bidirectional_encoder = bidirectional_encoder
        self.hidden_size = int(model_opt["dec_rnn_size"])
        # The citation vector is 2 * node_dim wide (checked in _run_forward_pass).
        self.citation_dim = int(model_opt["node_dim"]) * 2
        self.embeddings = nn.Embedding(vocab_size, int(model_opt["word_dim"]), padding_idx=0)
        dropout = float(model_opt["dec_dropout"])
        self.dropout = nn.Dropout(dropout)
        # Input feeding: each step consumes [embedding; previous decoder output].
        self._input_size = self.embeddings.embedding_dim + self.hidden_size
        # Decoder state, filled by init_state() and updated by forward().
        self.state = {}

        # NOTE: submodules are created in the original order so parameter
        # initialisation and state_dict ordering stay unchanged.
        self.rnn = self._build_rnn("GRU",
                                   input_size=self._input_size,
                                   hidden_size=self.hidden_size,
                                   num_layers=num_layers,
                                   dropout=dropout)

        # Combines [fused attention context; rnn output] back to hidden_size.
        self.linear_out = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=True)
        self.citation_linear = nn.Linear(self.citation_dim, self.hidden_size)
        # Two-layer scorer producing one weight for each of the 3 context vectors.
        self.fusion_linear_1 = nn.Linear(self.hidden_size * 3, self.hidden_size)
        self.fusion_linear_2 = nn.Linear(self.hidden_size, 3)

        # Set up the standard attention.
        self._coverage = coverage_attn
        if not self.attentional:
            if self._coverage:
                raise ValueError("Cannot use coverage term with no attention.")
            self.attn = None
            # _run_forward_pass reads these; define them so the
            # non-attentional path does not fail with AttributeError.
            self.context_attn = None
            self.tgt_abs_attn = None
        else:
            self.context_attn = GlobalAttention(
                self.hidden_size, coverage=coverage_attn,
                attn_type=attn_type, attn_func=attn_func
            )
            self.tgt_abs_attn = GlobalAttention(
                self.hidden_size, coverage=coverage_attn,
                attn_type=attn_type, attn_func=attn_func
            )
            # NOTE(review): not used by the forward pass in this file;
            # presumably kept so existing checkpoints still load.
            self.attn_linear_out = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=True)

        if model_opt["copy_attn"] == "True" and not reuse_copy_attn:
            if copy_attn_type == "none" or copy_attn_type is None:
                raise ValueError(
                    "Cannot use copy_attn with copy_attn_type none")
            # Copy features are a sum of linear projections of the per-step
            # inputs (see _run_forward_pass); self.copy_attn is just a flag.
            self.copy_attn = True
            self.copy_abs_attn = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            self.copy_context_attn = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            self.copy_hidden_attn = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            self.copy_input_attn = nn.Linear(int(model_opt["word_dim"]), self.hidden_size, bias=False)
            self.copy_citation = nn.Linear(self.citation_dim, self.hidden_size, bias=True)
        else:
            self.copy_attn = None

        self._reuse_copy_attn = reuse_copy_attn and model_opt["copy_attn"] == 'True'
        if self._reuse_copy_attn and not self.attentional:
            raise ValueError("Cannot reuse copy attention with no attention.")
        # The copy features use the attention context vectors, which only
        # exist when attention is enabled.
        if self.copy_attn and not self.attentional:
            raise ValueError("Cannot use copy attention with no attention.")

    def _build_rnn(self, rnn_type, input_size,
                   hidden_size, num_layers, dropout):
        """Return a stacked recurrent cell that supports input feeding.

        Raises:
            ValueError: for ``"SRU"`` (incompatible with input feeding) or
                for any type other than ``"GRU"``/``"LSTM"``.
        """
        # Explicit raise instead of assert: asserts are stripped under -O.
        if rnn_type == "SRU":
            raise ValueError("SRU doesn't support input feed! "
                             "Please set -input_feed 0!")
        if rnn_type == "GRU":
            stacked_cell = StackedGRU
        elif rnn_type == "LSTM":
            stacked_cell = StackedLSTM
        else:
            raise ValueError("Unsupported decoder rnn_type: %r" % (rnn_type,))
        return stacked_cell(num_layers, input_size, hidden_size, dropout)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.

        :param embeddings: pre-trained embeddings
        """
        print("Loaded pretrained embeddings for decoder")
        self.embeddings.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).

        :param fine_tune: Allow?
        """
        for p in self.embeddings.parameters():
            p.requires_grad = fine_tune

    def init_state(self, src, memory_bank, encoder_final):
        """Initialize decoder state with the last state of the encoder.

        Args:
            src: unused; kept for interface compatibility.
            memory_bank: unused; kept for interface compatibility.
            encoder_final (Tensor or tuple of Tensor): final encoder hidden
                state(s), ``(layers*directions, batch, dim)``. A tuple is
                treated as LSTM ``(h, c)``.
        """
        def _fix_enc_hidden(hidden):
            # The encoder hidden is  (layers*directions) x batch x dim.
            # We need to convert it to layers x batch x (directions*dim).
            if self.bidirectional_encoder:
                hidden = torch.cat([hidden[0:hidden.size(0):2],
                                    hidden[1:hidden.size(0):2]], 2)
            return hidden

        if isinstance(encoder_final, tuple):  # LSTM
            self.state["hidden"] = tuple(_fix_enc_hidden(enc_hid)
                                         for enc_hid in encoder_final)
        else:  # GRU
            self.state["hidden"] = (_fix_enc_hidden(encoder_final), )

        # Init the input feed with zeros: (1, batch_size, hidden_size).
        batch_size = self.state["hidden"][0].size(1)
        self.state["input_feed"] = self.state["hidden"][0].new_zeros(
            batch_size, self.hidden_size).unsqueeze(0)
        self.state["coverage"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, dim)`` to every state tensor along the batch dim (1)."""
        self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
        self.state["input_feed"] = fn(self.state["input_feed"], 1)
        if self.state.get("coverage") is not None:
            self.state["coverage"] = fn(self.state["coverage"], 1)

    def detach_state(self):
        """Detach the hidden state and input feed from the autograd graph."""
        self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
        self.state["input_feed"] = self.state["input_feed"].detach()

    def forward(self, tgt, context_memory_bank, tgt_abs_memory_bank, context_memory_lengths,
                tgt_abs_memory_lengths, tgt_abs_salience, tgt_abs_end_idx, citation_output):
        """Decode ``tgt`` with teacher forcing and update ``self.state``.

        Args:
            tgt (LongTensor): sequences of padded tokens
                ``(tgt_len, batch, 1)``; dim 2 is squeezed before embedding.
            context_memory_bank (FloatTensor): vectors from the context encoder
                ``(batch, src_len, hidden)``.
            tgt_abs_memory_bank (FloatTensor): vectors from the
                target-abstract encoder ``(batch, abs_len, hidden)``.
            context_memory_lengths (LongTensor): the padded source context
                lengths ``(batch,)``.
            tgt_abs_memory_lengths (LongTensor): lengths for
                ``tgt_abs_memory_bank``, forwarded to its attention.
            tgt_abs_salience (Tensor): predicted sentence salience of the
                target abstract ``(batch, tgt_abs_sent_num)``. It is multiplied
                element-wise with the target-abstract attention, so its last
                dim must broadcast with ``abs_len``.
            tgt_abs_end_idx (Tensor): 2-D tensor; only its rank is checked here.
            citation_output (FloatTensor): ``(batch, citation_dim)`` with
                ``citation_dim == 2 * node_dim``.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):
            * dec_outs: output from the decoder (after attn)
              ``(tgt_len, batch, hidden)``.
            * attns: per-step tensors stacked along dim 0. ``"tgt_std"`` and
              ``"context_std"`` (attention distributions) are present when
              attentional. ``"copy"`` (copy features) is present when copy
              attention is enabled.
        """
        dec_state, dec_outs, attns = self._run_forward_pass(tgt, context_memory_bank, tgt_abs_memory_bank,
                                                            context_memory_lengths, tgt_abs_memory_lengths,
                                                            tgt_abs_salience, tgt_abs_end_idx, citation_output)

        # Update the state with the result.
        if not isinstance(dec_state, tuple):
            dec_state = (dec_state,)
        self.state["hidden"] = dec_state
        self.state["input_feed"] = dec_outs[-1].unsqueeze(0)
        self.state["coverage"] = None
        # Only a non-empty coverage history can provide a last step.
        if attns.get("coverage"):
            self.state["coverage"] = attns["coverage"][-1].unsqueeze(0)

        # Concatenate the per-step lists along a new leading (time) dimension.
        # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list
        #       (in particular in case of SRU); in 0.4, SRU returns a tensor
        #       that shouldn't be stacked.
        if isinstance(dec_outs, list):
            dec_outs = torch.stack(dec_outs)

            for k in attns:
                if isinstance(attns[k], list):
                    attns[k] = torch.stack(attns[k])
        return dec_outs, attns

    def _run_forward_pass(self, tgt, context_memory_bank, tgt_abs_memory_bank, context_memory_lengths,
                          tgt_abs_memory_lengths, tgt_abs_salience, tgt_abs_end_idx, citation_output):
        """Unroll the decoder over ``tgt`` one target step at a time.

        Arguments are as in :meth:`forward`.

        Returns:
            tuple: ``(dec_state, dec_outs, attns)``. ``dec_state`` is the final
            hidden state returned by ``self.rnn``. ``dec_outs`` is a list with
            one ``(batch, hidden)`` tensor per step. ``attns`` maps keys to
            per-step lists, and only keys that are filled are created.

        Raises:
            ValueError: if batch sizes, the citation width, or the rank of
                ``tgt_abs_end_idx`` are inconsistent.
        """
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        _, _, dim = context_memory_bank.size()
        _, citation_dim = citation_output.size()
        if input_feed_batch != tgt_batch:
            raise ValueError("Batch size mismatch: input feed has %d, tgt has %d"
                             % (input_feed_batch, tgt_batch))
        if citation_dim != self.citation_dim:
            raise ValueError("citation_output has width %d, expected %d"
                             % (citation_dim, self.citation_dim))
        if tgt_abs_end_idx.dim() != 2:
            raise ValueError("tgt_abs_end_idx must be 2-D, got %d-D"
                             % tgt_abs_end_idx.dim())
        # END Additional args check.

        dec_outs = []
        attns = {}
        if self.attentional:
            attns["tgt_std"] = []
            attns["context_std"] = []
        if self.copy_attn:
            attns["copy"] = []
        # NOTE: no coverage vector is computed below, and reused copy
        # attention is not implemented, so no "coverage" or empty "copy" lists
        # are created. Empty lists would make forward() fail on
        # torch.stack / [-1].

        tgt = tgt.squeeze(2)
        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]

        # Step-invariant terms: computed once instead of once per target step.
        citation_c = None
        if self.attentional:
            citation_c = self.citation_linear(torch.tanh(citation_output)).unsqueeze(1)  # [batch, 1, hidden]
        copy_citation_term = self.copy_citation(citation_output) if self.copy_attn else None

        # Input feed concatenates the previous decoder output with the
        # embedding at every time step.
        target_l = 1
        for emb_t in emb.split(1):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)
            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            # rnn_output: [batch, dec_hidden]
            if self.attentional:
                rnn_output = rnn_output.unsqueeze(1)  # [batch, 1, dec_hidden]
                context_p_attn, context_c = self.context_attn(
                    rnn_output,
                    context_memory_bank,
                    memory_lengths=context_memory_lengths)
                tgt_abs_p_attn, tgt_abs_c = self.tgt_abs_attn(
                    rnn_output,
                    tgt_abs_memory_bank,
                    memory_lengths=tgt_abs_memory_lengths)

                # Re-weight the target-abstract attention by the predicted
                # salience and renormalise with softmax.
                # NOTE(review): softmax is applied to (attention * salience),
                # i.e. to already-normalised weights rather than logits.
                # Zero-weight positions still receive exp(0) mass, and that
                # includes masked padding if GlobalAttention zeroes it. This
                # re-weighted distribution is only recorded in attns;
                # tgt_abs_c above comes from the unweighted attention.
                # Confirm both are intended.
                tgt_abs_p_attn = torch.mul(tgt_abs_p_attn.squeeze(0), tgt_abs_salience)
                tgt_abs_p_attn = F.softmax(tgt_abs_p_attn, -1)

                # Score the three context vectors and take their weighted sum.
                fusion_input = torch.cat([context_c, tgt_abs_c, citation_c], 2)  # [batch, 1, 3*hidden]
                fusion_attn = F.softmax(
                    self.fusion_linear_2(torch.tanh(self.fusion_linear_1(fusion_input))), dim=-1)  # [batch, 1, 3]
                fused_c = torch.bmm(fusion_attn, torch.cat([context_c, tgt_abs_c, citation_c], 1))  # [batch, 1, hidden]

                concat_c = torch.cat([fused_c, rnn_output], 2).view(tgt_batch * target_l, dim * 2)
                decoder_output = self.linear_out(concat_c).view(tgt_batch, target_l, dim).squeeze(1)

                attns["tgt_std"].append(tgt_abs_p_attn)
                attns["context_std"].append(context_p_attn.squeeze(0))
            else:
                decoder_output = rnn_output
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            dec_outs.append(decoder_output)
            if self.copy_attn:
                # Copy features: [batch_size, dec_size]. __init__ guarantees
                # attention is enabled, so context_c / tgt_abs_c exist here.
                copy_attn = self.copy_abs_attn(tgt_abs_c.squeeze(1)) + self.copy_context_attn(context_c.squeeze(1)) \
                            + self.copy_input_attn(emb_t) + self.copy_hidden_attn(rnn_output.squeeze(1)) \
                            + copy_citation_term
                attns["copy"].append(copy_attn)

        return dec_state, dec_outs, attns