# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..decoder_cell_base import DecoderCellBase


class SimpleCell(DecoderCellBase):
    """Decoder cell that performs one decoding step with attention.

    Each step projects the previous word embedding together with the previous
    context vector, advances the recurrent cell, attends over the encoder
    outputs and finally produces a probability distribution over the target
    vocabulary.

    The helpers used here (``rnn_cell``, ``word_embedding``, ``attention``,
    ``input_dropout``, ``word_size``, ``num_words``, ``encoder_outputs``,
    ``encoder_features``, ``source_words_mask``) are not defined in this class
    and are expected to be supplied by ``DecoderCellBase``.
    """

    class Options(DecoderCellBase.Options):
        # Adds no settings of its own on top of the base options.
        pass

    def __init__(self, options: Options, input_size, target_vocab):
        """Build the recurrent cell and the three linear projections.

        ``options.hidden_size`` is the recurrent state width and
        ``options.attention_size`` is the width of the recurrent cell input.
        ``input_size`` is used as the width of the context vector that gets
        concatenated with the word embedding and with the recurrent state.
        """
        super().__init__(options, input_size, target_vocab)

        state_width = options.hidden_size
        rnn_input_width = options.attention_size

        # NOTE: creation order of the sub-modules is kept stable on purpose;
        # it determines parameter registration (and seeded initialisation).
        self.rnn = self.rnn_cell(rnn_input_width, state_width)

        # Projections around the recurrent cell ##############################
        # [word embedding ; previous context] -> recurrent cell input
        self.input_linear = nn.Linear(self.word_size + input_size, rnn_input_width)
        # [recurrent state ; current context] -> output features
        self.output_linear = nn.Linear(state_width + input_size, state_width)
        # output features -> vocabulary scores
        self.pred_linear = nn.Linear(state_width, self.num_words)

    def forward(self, decoder_state, output_word_t_1, **_kwargs):
        """Run a single decoding step.

        Args:
            decoder_state: triple ``(state_t_1, context_t_1, coverage_t_1)``
                from the previous step. Per the original author's notes,
                ``context_t_1`` is ``[batch_size, input_size]`` and
                ``coverage_t_1`` is ``[batch_size, max_source_seq_len]``;
                ``state_t_1`` is whatever ``self.rnn`` accepts as its state
                (presumably ``hidden_size`` wide -- TODO confirm).
            output_word_t_1: previously emitted words, noted by the original
                author as ``[batch_size]``.
            **_kwargs: accepted and ignored.

        Returns:
            A 4-tuple ``(attn_dist_t, vocab_dist_t, None, next_state)`` where
            ``next_state`` is ``(state_t, context_t, coverage_t)``. The third
            element is always ``None``.
        """
        assert self.encoder_outputs is not None, 'encoder_outputs is not set'

        prev_state, prev_context, prev_coverage = decoder_state

        # Previous word embedding joined with the previous context, then
        # projected down to the recurrent cell's input width.
        prev_word_embedded = self.word_embedding(output_word_t_1)
        rnn_input = self.input_linear(torch.cat([prev_word_embedded, prev_context], dim=1))
        if self.input_dropout is not None:
            rnn_input = self.input_dropout(rnn_input)

        # May be a single tensor or a tuple such as (h, c), depending on the cell.
        state_t = self.rnn(rnn_input, prev_state)

        # Attention weights over source positions plus the updated coverage.
        attn_dist_t, coverage_t = self.attention(
            state_t, self.encoder_features, self.source_words_mask, prev_coverage
        )

        # Weight every encoder output by its attention score and sum over
        # dim 1 to obtain the next context vector.
        context_t = (attn_dist_t.unsqueeze(-1) * self.encoder_outputs).sum(dim=1)

        # Only the first component of a tuple state feeds the output layer.
        hidden_t = state_t[0] if isinstance(state_t, tuple) else state_t

        # $$P_{vocab} = \mathrm{softmax}(V'(V[s_t;h*_t] + b) + b')$$
        output_features = self.output_linear(torch.cat([hidden_t, context_t], dim=1))
        vocab_scores = self.pred_linear(output_features)
        vocab_dist_t = F.softmax(vocab_scores, dim=1)

        next_decoder_state = (state_t, context_t, coverage_t)
        return attn_dist_t, vocab_dist_t, None, next_decoder_state
