# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...torch_extra.utils import clip_and_renormalize
from ..decoder_cell_base import DecoderCellBase


class CopyNetCell(DecoderCellBase):
    """
    Decoder cell with a CopyNet-style copy mechanism.

    At every step the cell scores each word of the fixed target vocabulary
    (generation mode) and each source position (copy mode), normalizes both
    score sets with a single joint softmax, and scatter-adds the copy
    probabilities onto the (extended) vocabulary ids found in
    ``source_words_ext``. It also computes a "selective read" -- an encoder
    summary over the source positions holding the predicted word -- which is
    fed into the next step's RNN input.

    @see https://arxiv.org/abs/1603.06393
    """

    class Options(DecoderCellBase.Options):
        # Forwarded to `clip_and_renormalize` for the final vocabulary
        # distribution; presumably a probability floor that keeps log-probs
        # finite -- confirm against torch_extra.utils.
        prob_clip_epsilon: float = 1e-8

    def __init__(self, options: Options, input_size, target_vocab):
        # NOTE(review): ('extra_zeros',) presumably declares the extra
        # keyword inputs `forward` accepts; see DecoderCellBase to confirm.
        super().__init__(options, input_size, target_vocab, ('extra_zeros',))

        hidden_size = options.hidden_size

        # RNN input = [word embedding; previous context; previous selective read]
        self.rnn = self.rnn_cell(input_size * 2 + self.word_size, hidden_size)

        # Projects encoder outputs into the hidden space to score copy mode.
        self.copy_linear = nn.Linear(input_size, hidden_size)
        # Generation-mode scores over the fixed target vocabulary.
        self.output_linear = nn.Linear(hidden_size, self.num_words)

    def get_init_state(self, encoder_outputs, encoder_hidden):
        """
        Extend the base initial state with an all-zero selective read.

        The base state is ``(state_0, context_0, coverage_0)``; the selective
        read gets the same shape as ``context_0``.
        """
        state = super().get_init_state(encoder_outputs, encoder_hidden)
        selective_read_0 = torch.zeros_like(state[1])  # same shape as context_0
        return (*state, selective_read_0)

    def forward(self, decoder_state, output_word_t_1, extra_zeros=None):
        """
        Run one decoding step.

        Args:
            decoder_state: tuple ``(state_t_1, context_t_1, coverage_t_1,
                selective_read_t_1)``:
                - state_t_1: previous RNN state (a tuple such as ``(h, c)``
                  for LSTM-style cells, otherwise a single tensor); presumably
                  [batch_size, hidden_size] per tensor
                - context_t_1: [batch_size, input_size]
                - coverage_t_1: [batch_size, max_source_seq_len]
                - selective_read_t_1: [batch_size, input_size]
            output_word_t_1: [batch_size] ids of the previous output word
            extra_zeros: optional [batch_size, max_extra_zero_size] zeros
                appended to the generation distribution, giving slots to
                ids in ``source_words_ext`` that are >= ``num_words``
                (presumably source OOV words)

        Returns:
            attn_dist_t: [batch_size, max_source_seq_len] attention weights
            vocab_dist_t: [batch_size, num_words (+ max_extra_zero_size)]
                combined generation/copy distribution
            read_dist_t: [batch_size, max_source_seq_len] selective-read weights
            decoder_state: ``(state_t, context_t, coverage_t, selective_read_t)``
        """
        assert self.encoder_outputs is not None, 'encoder_outputs is not set'

        state_t_1, context_t_1, coverage_t_1, selective_read_t_1 = decoder_state
        # shape: [batch_size, input_size * 2 + word_size]
        input_t = torch.cat([self.word_embedding(output_word_t_1),
                             context_t_1,
                             selective_read_t_1], dim=1)
        if self.input_dropout is not None:
            input_t = self.input_dropout(input_t)

        # 1. rnn
        state_t = self.rnn(input_t, state_t_1)
        # Tuple-state cells (e.g. LSTM) return (h, c); only h is used for scoring.
        state_h = state_t[0] if isinstance(state_t, tuple) else state_t

        # 2. predict next word y_t
        # 2-1) generation-mode scores, shape: [batch_size, num_words]
        score_g = self.output_linear(state_h)

        # 2-2) copy-mode scores over source positions
        # shape: [batch_size, max_source_seq_len, hidden_size]
        score_c = torch.tanh(self.copy_linear(self.encoder_outputs))
        # shape: [batch_size, max_source_seq_len]
        score_c = score_c.bmm(state_h.unsqueeze(-1)).squeeze(-1)
        # Remove padded source positions from the copy distribution. The mask
        # is built by comparison (a real bool tensor) rather than with
        # `~mask.byte()`: on uint8 tensors `~` is a bitwise NOT (1 -> 254,
        # 0 -> 255), which would mark *every* position as padding, and recent
        # PyTorch rejects uint8 masks in `masked_fill` altogether.
        padding_mask = self.source_words_mask == 0
        score_c = score_c.masked_fill(padding_mask, -1e6)

        # 2-3) one joint softmax so both modes compete for probability mass
        probs = F.softmax(torch.cat([score_g, score_c], dim=1), dim=1)
        prob_g, prob_c = probs[:, :self.num_words], probs[:, self.num_words:]
        if extra_zeros is not None:
            # shape: [batch_size, num_words + max_extra_zero_size]
            prob_g = torch.cat([prob_g, extra_zeros], dim=1)

        # 2-4) add each source position's copy probability to the vocabulary
        # id it holds
        vocab_dist_t = prob_g.scatter_add(1, self.source_words_ext, prob_c)
        vocab_dist_t = clip_and_renormalize(vocab_dist_t, self.options.prob_clip_epsilon)

        # 3. compute `selective_read` to use for predicting next word
        # NOTE: use `output_word_t` instead of `output_word_t_1` -- the read
        # is keyed on this step's predicted word (argmax of vocab_dist_t).
        predicted_word_t = vocab_dist_t.argmax(dim=1, keepdim=True)
        # shape: [batch_size, max_source_seq_len]
        mask = (self.source_words_ext == predicted_word_t).float()
        # L1-normalize copy probabilities over the matching positions; rows
        # with no match stay all-zero because F.normalize clamps the norm by eps.
        read_dist_t = F.normalize(mask * prob_c, p=1, dim=1)
        # shape: [batch_size, input_size]
        selective_read_t = (read_dist_t.unsqueeze(-1) * self.encoder_outputs).sum(dim=1)

        # 4. attention mechanism
        # shape of attn_dist_t: [batch_size, max_source_seq_len]
        attn_dist_t, coverage_t = \
            self.attention(state_t, self.encoder_features, self.source_words_mask, coverage_t_1)
        # next context vector, shape: [batch_size, input_size]
        context_t = (attn_dist_t.unsqueeze(-1) * self.encoder_outputs).sum(dim=1)

        decoder_state = (state_t, context_t, coverage_t, selective_read_t)
        return attn_dist_t, vocab_dist_t, read_dist_t, decoder_state
