# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..common.dataclass_options import OptionsBase, argfield
from ..torch_extra.utils import clip_and_renormalize


class AdditiveAttention(nn.Module):
    """Additive attention over encoder positions with an optional coverage term.

    Args:
        attention_size: width of the shared space every term is projected into.
        decoder_size: feature size of the decoder state given to ``forward``
            (after concatenation, when the state is passed as a tuple).
        input_size: feature size of the encoder outputs fed to ``attn_W_h``.
        use_coverage: when true, a coverage term is added to the scores and
            the attention distribution is accumulated into the coverage vector.
    """

    def __init__(self, attention_size, decoder_size, input_size, use_coverage=False):
        super().__init__()

        self.use_coverage = use_coverage
        self.attention_size = attention_size

        # Attention ###########################################################
        # Equation (11):
        # $$ e^t_i = v^T \tanh(W_h h_i + W_s s_t + w_c c^t_i + b_{attn}) $$
        # $s_t$ is the decoder state, $c^t_i$ is the coverage vector.
        # $W_h h_i$ does not depend on the decoding step, so callers apply
        # `attn_W_h` once to get $W_h [h_0;h_1;...;h_{n-1}]$ and pass the
        # result to `forward` as `encoder_features`.

        self.attn_v = nn.Linear(self.attention_size, 1, bias=False)
        # `attn_w_c` only exists when coverage is enabled.
        if use_coverage:
            self.attn_w_c = nn.Linear(1, self.attention_size, bias=False)

        self.attn_W_h = nn.Linear(input_size, self.attention_size, bias=False)
        # The only biased projection: its bias plays the role of $b_{attn}$.
        self.attn_W_s = nn.Linear(decoder_size, self.attention_size)

    def forward(self, decoder_state, encoder_features, seq_mask, coverage):
        """Compute the attention distribution for one decoding step.

        decoder_state: [batch_size, decoder_size], or a tuple of tensors that
                       is concatenated along dim 1 into that shape
        encoder_features: precomputed $W_h [h_0;h_1;...;h_{n-1}]$
                          [batch_size, max_source_seq_len, attention_size]
        seq_mask: [batch_size, max_source_seq_len]; multiplied into the
                  softmax output, so it must be a numeric (float) mask
        coverage: [batch_size, max_source_seq_len], or None. None is treated
                  as an all-zero coverage vector.

        Returns ``(attn_dist, coverage)``. ``attn_dist`` is
        [batch_size, max_source_seq_len]. ``coverage`` is the input coverage
        plus ``attn_dist`` when ``use_coverage`` is set, otherwise the
        ``coverage`` argument is returned untouched.
        """
        # A tuple state (e.g. the (h, c) pair DecoderCellBase builds for LSTM
        # cells) is flattened into a single feature vector per batch item.
        if isinstance(decoder_state, tuple):
            decoder_state = torch.cat(decoder_state, dim=1)

        # The decoder term is identical for every source token, so it is
        # computed once and broadcast over the sequence axis.
        # shape: [batch_size, 1, attention_size]
        decoder_features = self.attn_W_s(decoder_state).unsqueeze(1)

        # $W_h h_i + W_s s_t + b_{attn}$
        # shape: [batch_size, max_source_seq_len, attention_size]
        all_features = encoder_features + decoder_features

        if self.use_coverage and coverage is not None:
            # $W_h h_i + W_s s_t + w_c c^t_i + b_{attn}$
            # [batch_size, max_source_seq_len, 1] x [1, attention_size]
            coverage_features = self.attn_w_c(coverage.unsqueeze(-1))
            all_features = all_features + coverage_features

        # $ e^t_i = v^T \tanh(W_h h_i + W_s s_t + w_c c^t_i + b_{attn}) $
        # shape: [batch_size, max_source_seq_len]
        attn_e = self.attn_v(torch.tanh(all_features)).squeeze(-1)

        # Softmax runs over all positions, then masked-out positions are
        # zeroed, so the rows no longer sum to one and need re-normalizing.
        attn_dist = F.softmax(attn_e, dim=1) * seq_mask
        # NOTE(review): presumably clips at 1e-8 and rescales each row to sum
        # to 1 -- confirm against torch_extra.utils.clip_and_renormalize.
        attn_dist = clip_and_renormalize(attn_dist, 1e-8)  # re-normalize

        if self.use_coverage:
            # A missing coverage vector means "nothing attended yet"; this
            # matches the skipped coverage term above (w_c has no bias, so a
            # zero coverage contributes nothing) and avoids `None + tensor`.
            coverage = attn_dist if coverage is None else coverage + attn_dist

        return attn_dist, coverage


class DecoderCellBase(nn.Module):
    """Shared pieces of an attention-based decoder cell.

    Holds the target word embedding, the additive attention module, optional
    projections from the encoder summary to the initial recurrent state, and
    the per-step loss. No recurrent cell is instantiated here: `rnn_cell`
    only records which cell class (`nn.GRUCell` or `nn.LSTMCell`) was chosen.
    """

    class Options(OptionsBase):
        # Size of the target word embedding.
        word_size: int
        # Width of the space the attention scores are computed in.
        attention_size: int = 256
        # Hidden size of the recurrent cell.
        hidden_size: int = 256

        # Enable the coverage term in attention and the coverage loss.
        use_coverage: bool = False
        # Multiplier applied to the coverage loss before it is added.
        coverage_loss_weight: float = 1.0

        # Project the encoder summary into the initial state; otherwise the
        # initial state is all zeros.
        project_encoder_hidden: bool = True

        # Dropout probability; 0 disables the dropout module entirely.
        input_dropout: float = 0.3

        rnn_type: str = argfield('gru', choices=['gru', 'lstm'])

    def __init__(self, options: Options, input_size, target_vocab, extra_keys=()):
        """Build the embedding, attention and initial-state projections.

        Args:
            options: hyper-parameters, see `Options`.
            input_size: feature size of the encoder outputs.
            target_vocab: target vocabulary; must support `len()` and expose
                `pad_id`.
            extra_keys: stored as-is on the instance; not used in this class.
        """
        super().__init__()

        self.extra_keys = extra_keys

        self.options = options

        self.input_size = input_size
        self.num_words = len(target_vocab)
        self.word_size = options.word_size
        self.hidden_size = hidden_size = options.hidden_size

        # Per-batch encoder state, filled in by `set_encoder_outputs`.
        # Declared here so the attributes always exist on the instance.
        self.encoder_outputs = None
        self.encoder_features = None
        self.source_words_ext = None
        self.source_words_mask = None

        # `rnn_cell` is the cell *class*. `state_size` is the width of the
        # state vector handed to attention: an LSTM state is the (h, c) pair
        # concatenated, hence twice the hidden size.
        if options.rnn_type == 'gru':
            self.rnn_cell = nn.GRUCell
            self.state_size = hidden_size
        else:
            self.rnn_cell = nn.LSTMCell
            self.state_size = hidden_size * 2

        self.word_embedding = nn.Embedding(self.num_words, options.word_size,
                                           padding_idx=target_vocab.pad_id)

        self.attention = AdditiveAttention(options.attention_size, self.state_size, input_size,
                                           options.use_coverage)

        # Optional projections of the encoder summary to the initial state.
        # The cell-state projection is only needed for LSTM cells.
        self.project_h = self.project_c = None
        if options.project_encoder_hidden:
            self.project_h = nn.Linear(input_size, hidden_size)
            if self.rnn_cell is nn.LSTMCell:
                self.project_c = nn.Linear(input_size, hidden_size)

        if options.input_dropout != 0:
            self.input_dropout = nn.Dropout(options.input_dropout, inplace=True)
        else:
            self.input_dropout = None

    def set_encoder_outputs(self, encoder_outputs, source_words_ext, source_words_mask):
        """Cache the encoder results that stay fixed for a whole decode.

        encoder_outputs: [batch_size, max_source_seq_len, input_size]
        source_words_ext: stored unchanged for later use
                          (NOTE(review): presumably source token ids in an
                          extended vocabulary -- confirm with the subclasses)
        source_words_mask: [batch_size, max_source_seq_len]; converted to float
        """
        self.encoder_outputs = encoder_outputs
        # $W_h h_i$ is the same at every decoding step, so it is computed once.
        # shape: [batch_size, max_source_seq_len, attention_size]
        self.encoder_features = self.attention.attn_W_h(encoder_outputs)
        self.source_words_ext = source_words_ext
        self.source_words_mask = source_words_mask.float()

    def get_init_state(self, encoder_outputs, encoder_hidden):
        """Build the initial recurrent state, context vector and coverage.

        encoder_outputs: only its leading dimensions are used, to shape the
                         initial coverage vector
        encoder_hidden: [batch_size, input_size] encoder summary

        Returns ``(state_0, context_0, coverage_0)``. ``state_0`` is a tensor
        for GRU cells and an ``(h, c)`` tuple otherwise; ``context_0`` is
        zeros of shape [batch_size, input_size]; ``coverage_0`` is zeros
        shaped like ``encoder_outputs`` without its last dimension, or None
        when coverage is disabled.
        """
        batch_size = encoder_hidden.size(0)
        # Follow the encoder's dtype as well as its device, so the zero
        # tensors also fit models that do not run in the default float32.
        like_encoder = {'device': encoder_hidden.device, 'dtype': encoder_hidden.dtype}

        if self.project_h is not None:
            state_h = self.project_h(encoder_hidden)
        else:
            state_h = torch.zeros(batch_size, self.hidden_size, **like_encoder)

        if self.rnn_cell is nn.GRUCell:
            state_0 = state_h
        else:
            # The cell state is built only when it is needed, and as its own
            # tensor rather than an alias of `state_h`.
            if self.project_c is not None:
                state_c = self.project_c(encoder_hidden)
            else:
                state_c = torch.zeros(batch_size, self.hidden_size, **like_encoder)
            state_0 = state_h, state_c

        context_0 = torch.zeros(batch_size, self.input_size, **like_encoder)
        coverage_0 = None
        if self.options.use_coverage:
            coverage_0 = torch.zeros(*encoder_outputs.shape[:-1], **like_encoder)

        return state_0, context_0, coverage_0

    def compute_loss(self, attn_dist, vocab_dist, old_coverage, target_words):
        """Per-example loss for one decoding step.

        attn_dist: attention distribution of this step
        vocab_dist: probabilities (not log-probabilities) over output words
        old_coverage: coverage vector from before this step; only read when
                      coverage is enabled
        target_words: gold word ids; -100 marks positions to ignore

        Returns the unreduced loss, with ignored positions set to zero.
        """
        # NOTE(review): an exact zero in `vocab_dist` gives log(0) = -inf
        # here; confirm the caller keeps the distribution strictly positive.
        loss = F.nll_loss(vocab_dist.log(), target_words, reduction='none')
        if self.options.use_coverage:
            # Penalize attending again to positions that are already covered.
            coverage_loss = torch.min(old_coverage, attn_dist).sum(dim=1)
            # Out-of-place add: leaves the tensor returned by nll_loss intact.
            loss = loss + coverage_loss * self.options.coverage_loss_weight

        # nll_loss already yields 0 where the target is -100 (its default
        # ignore_index), but the coverage loss is added at every position,
        # so the mask is still required when coverage is enabled.
        return loss * (target_words != -100).float()
