# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...torch_extra.utils import clip_and_renormalize
from ..decoder_cell_base import DecoderCellBase


class PointerGeneratorCell(DecoderCellBase):
    """
    Decoder cell of the pointer-generator network.

    @see https://arxiv.org/pdf/1704.04368.pdf
    Get To The Point: Summarization with Pointer-Generator Networks

    At every step the cell mixes a generation distribution over the target
    vocabulary with a copy distribution over source positions, weighted by the
    generation probability P_gen.
    """

    class Options(DecoderCellBase.Options):
        # Epsilon passed to `clip_and_renormalize` on the final distribution.
        prob_clip_epsilon: float = 1e-8

    def __init__(self, options: Options, input_size, target_vocab):
        """
        options: cell options; `hidden_size` and `attention_size` are read
            here, `prob_clip_epsilon` is read in `forward`
        input_size: feature size of the encoder outputs, i.e. of the context
            vector
        target_vocab: target vocabulary, forwarded to `DecoderCellBase`
            (presumably used there to build `word_embedding` / `num_words`;
            confirm in the base class)
        """
        # NOTE(review): ('extra_zeros',) presumably declares the extra per-step
        # input accepted by `forward` -- confirm in DecoderCellBase.
        super().__init__(options, input_size, target_vocab, ('extra_zeros',))

        hidden_size = options.hidden_size
        attention_size = options.attention_size

        # Submodules are created in a fixed order: reordering them would change
        # random initialization under a fixed seed and the order of parameters().
        self.rnn = self.rnn_cell(attention_size, hidden_size)

        # Decoder #############################################################
        # Projects [word embedding; previous context] to the RNN input size.
        self.input_linear = nn.Linear(self.word_size + input_size, attention_size)
        # V in Equation (4): [s_t; h^*_t] -> hidden_size
        self.output_linear = nn.Linear(hidden_size + input_size, hidden_size)
        # V' in Equation (4): hidden_size -> target vocabulary logits
        self.pred_linear = nn.Linear(hidden_size, self.num_words)

        # Pointer-generator ###################################################
        # Equation (8):
        # $$
        # P_{gen} = \sigma(w^T_{h^*} h^*_t + w^T_s s_t + w^T_x x_t + b_{ptr})
        # $$
        # $h^*_t$ is the context vector, $s_t$ the decoder state and $x_t$ the
        # decoder input.

        # context_t (input_size) + every RNN state tensor (self.state_size)
        # + input_t (attention_size); must match the concatenation in `forward`.
        dim = input_size + attention_size + self.state_size
        self.pointer_gen_linear = nn.Linear(dim, 1)

    def forward(self, decoder_state, output_word_t_1, extra_zeros=None, **_kwargs):
        """
        Run one decoding step.

        decoder_state: tuple `(state_t_1, context_t_1, coverage_t_1)`
            state_t_1: previous RNN state; a [batch_size, hidden_size] tensor,
                or a tuple whose first element has that shape (e.g. LSTM (h, c))
            context_t_1: [batch_size, input_size]
            coverage_t_1: [batch_size, max_source_seq_len]; passed straight to
                `self.attention`
        output_word_t_1: [batch_size] ids of the previously emitted words
        extra_zeros: [batch_size, max_extra_zero_size] or None; appended to the
            vocabulary distribution to hold probability mass for source OOV words
        _kwargs: ignored

        Returns `(attn_dist_t, vocab_dist_t, pointer_gen_t, (state_t, context_t, coverage_t))`:
            attn_dist_t: [batch_size, max_source_seq_len], attention already
                scaled by (1 - P_gen)
            vocab_dist_t: [batch_size, num_words (+ max_extra_zero_size)], final
                distribution after `clip_and_renormalize`
            pointer_gen_t: [batch_size, 1], P_gen
            (state_t, context_t, coverage_t): decoder state for the next step

        Raises RuntimeError if `encoder_outputs` has not been set.
        """
        if self.encoder_outputs is None:
            # Explicit check instead of `assert`, which `python -O` strips.
            raise RuntimeError('encoder_outputs is not set')

        state_t_1, context_t_1, coverage_t_1 = decoder_state

        # RNN input: [embedding(last word); previous context], projected.
        # shape: [batch_size, attention_size]
        input_t = torch.cat([self.word_embedding(output_word_t_1), context_t_1], dim=1)
        input_t = self.input_linear(input_t)
        if self.input_dropout is not None:
            input_t = self.input_dropout(input_t)
        state_t = self.rnn(input_t, state_t_1)  # e.g. (h_t, c_t) for an LSTM

        # `self.attention` returns the attention weights and the updated coverage.
        # shape of attn_dist_t: [batch_size, max_source_seq_len]
        attn_dist_t, coverage_t = \
            self.attention(state_t, self.encoder_features, self.source_words_mask, coverage_t_1)
        # Context vector = attention-weighted sum of the encoder outputs.
        # shape: [batch_size, max_source_seq_len, input_size]
        context_t = attn_dist_t.unsqueeze(-1) * self.encoder_outputs
        # shape: [batch_size, input_size]
        context_t = context_t.sum(dim=1)  # next context vector

        # Normalize the RNN state to a tuple; states[0] plays s_t in Equation (4).
        states = state_t if isinstance(state_t, tuple) else (state_t,)

        # Equation (4):
        # $$P_{vocab} = \mathrm{softmax}(V'(V[s_t;h^*_t] + b) + b')$$
        output_t = torch.cat([states[0], context_t], dim=1)
        # shape: [batch_size, hidden_size]
        output_t = self.output_linear(output_t)
        # shape: [batch_size, num_words]
        vocab_dist_t = F.softmax(self.pred_linear(output_t), dim=1)

        # Equation (8): P_gen from [context; all RNN state tensors; decoder input].
        pointer_gen_t = torch.cat([context_t, *states, input_t], dim=1)
        # shape of $P_{gen}$: [batch_size, 1]
        pointer_gen_t = torch.sigmoid(self.pointer_gen_linear(pointer_gen_t))

        # Final distribution:
        # P(w) = P_gen * P_vocab(w) + (1 - P_gen) * sum of attention on positions holding w
        vocab_dist_t = pointer_gen_t * vocab_dist_t
        # NOTE(review): the attn_dist_t returned below is this *scaled* attention,
        # while coverage_t comes from `self.attention` (presumably accumulated
        # from the unscaled weights). If a caller computes the coverage loss from
        # the returned attn_dist_t, it differs from the paper's unscaled
        # formulation -- confirm this is intended.
        attn_dist_t = (1 - pointer_gen_t) * attn_dist_t

        if extra_zeros is not None:
            # shape: [batch_size, num_words + max_extra_zero_size]
            vocab_dist_t = torch.cat([vocab_dist_t, extra_zeros], dim=1)

        # OOV words in a source input are temporarily assigned IDs at or above
        # the size of the target vocabulary; the probability of generating them
        # lives in the trailing extra positions of `vocab_dist_t`. Source words
        # that are inside the vocabulary instead get an extra bonus on top of
        # their generation probability.

        # NOTE: source and target vocab should be the same
        # Equivalent to: vocab_dist_t[i][source_words_ext[i][j]] += attn_dist_t[i][j]
        # Every id in source_words_ext must be < vocab_dist_t.size(1); without
        # extra_zeros that means all source ids must be in-vocabulary.
        vocab_dist_t = vocab_dist_t.scatter_add(1, self.source_words_ext, attn_dist_t)

        vocab_dist_t = clip_and_renormalize(vocab_dist_t, self.options.prob_clip_epsilon)

        return attn_dist_t, vocab_dist_t, pointer_gen_t, (state_t, context_t, coverage_t)
