# -*- coding: utf-8 -*-

import torch
import torch.nn as nn

from ..common.dataclass_options import OptionsBase, argfield
from ..common.utils import DotDict
from ..torch_extra.utils import simple_collate, simple_decollate
from .beam import Hypothesis
from .decoder_cells import DecoderCellOptions


class SequenceDecoder(nn.Module):
    """Autoregressive decoder that drives a step-wise decoder cell.

    The per-step computation (input handling, attention, output distribution,
    loss) is delegated to ``self.cell``, built from ``Options.cell``. This
    module implements the outer loops:

    * training over the gold target sequence with optional (scheduled)
      teacher forcing,
    * greedy decoding for a whole batch,
    * beam search for a single example (batch size 1).
    """

    class Options(OptionsBase):
        max_decode_steps: int = argfield(256, active_time='both')
        min_decode_steps: int = argfield(1, active_time='both')

        eval_mode: str = argfield('greedy', choices=['beam', 'greedy'],
                                  predict_default='beam', active_time='both')
        beam_size: int = argfield(5, active_time='both')

        # 'none': always feed gold words during training;
        # 'constant': feed the gold word with probability `teacher_forcing`
        #             (a value >= 1 means always gold);
        # 'linear': start at 1 and decay in `after_epoch_hook`.
        teacher_forcing_schedule: str = argfield('none', choices=['none', 'linear', 'constant'])
        teacher_forcing: float = 0.5

        cell: DecoderCellOptions

        def create(self, target_vocab, input_size):
            return SequenceDecoder(self, target_vocab, input_size)

    def __init__(self, options, target_vocab, input_size):
        super().__init__()

        self.options = options
        # Size of the fixed target vocabulary; ids >= num_words are treated as
        # temporary OOV ids (see restore_oov_words_to_unk).
        self.num_words = len(target_vocab)

        self.unk_id = target_vocab.unk_id
        self.sos_id = target_vocab.sos_id
        self.eos_id = target_vocab.eos_id

        self.cell = options.cell.create(input_size, target_vocab)

        schedule = self.teacher_forcing_schedule = options.teacher_forcing_schedule
        # None means "always feed gold words"; otherwise it is the per-example
        # probability of feeding the gold word instead of the model's prediction.
        self.teacher_forcing = None
        if schedule == 'constant':
            if options.teacher_forcing < 1:
                self.teacher_forcing = options.teacher_forcing
        elif schedule == 'linear':
            self.teacher_forcing = 1

    def after_epoch_hook(self, step, max_steps, _, logger):
        """Linearly decay the teacher-forcing ratio (only for the 'linear' schedule).

        The ratio becomes ``1 - step / max_steps`` and is logged.
        """
        if self.teacher_forcing_schedule == 'linear':
            self.teacher_forcing = 1 - step / max_steps
            logger.info('Teacher forcing reduce to %.4f', self.teacher_forcing)

    def restore_oov_words_to_unk(self, word_ids):
        """Replace, in place, every id >= ``num_words`` in ``word_ids`` with ``unk_id``."""
        # NOTE restore temporary OOV words to unk_id
        word_ids[word_ids >= self.num_words] = self.unk_id

    def _forward_train(self, inputs, state_t, extra_inputs):
        """Run the decoder over the gold target sequence and compute the loss.

        When training, words of the gold sentence are fed as inputs (teacher
        forcing):
            input sequence:  SOS w1 w2 ... wn
            output sequence: w1 w2 ... wn EOS
        If a teacher-forcing ratio is set, each example independently receives
        the previous step's predicted word instead, with probability
        ``1 - ratio``.

        Returns:
            DotDict with ``loss`` (scalar), ``correct`` (number of target
            positions whose argmax prediction equals the target) and ``total``
            (sum of ``inputs.decoder_lengths``).
        """
        step_losses = []
        vocab_dists = []

        # List of: sentence_size * [batch_size]
        decoder_inputs = torch.unbind(inputs.decoder_inputs, dim=1)
        # IGNORE_INDEX should be set to -100
        decoder_targets = torch.unbind(inputs.decoder_targets, dim=1)

        teacher_forcing = self.teacher_forcing
        output_word_t_1 = None
        for gold_output_word_t_1, target_words in zip(decoder_inputs, decoder_targets):
            if teacher_forcing is None or output_word_t_1 is None:
                output_word_t_1 = gold_output_word_t_1  # use gold standard predictions
            else:
                # Predicted ids may be temporary OOV ids; map them back to unk
                # before feeding them to the cell, as the decoding paths do.
                self.restore_oov_words_to_unk(output_word_t_1)
                # Per example, take the gold word with probability `teacher_forcing`.
                # torch.where keeps positions aligned (masked_scatter_ would consume
                # the gold words sequentially and hand them to the wrong examples).
                noise = torch.rand(target_words.size(0), device=target_words.device)
                output_word_t_1 = torch.where(noise <= teacher_forcing,
                                              gold_output_word_t_1, output_word_t_1)

            step_output = self.cell.forward(state_t, output_word_t_1, *extra_inputs)

            # The 3rd element of the pre-step cell state is passed to compute_loss
            # as the previous coverage (presumably the coverage vector - TODO confirm).
            old_coverage = state_t[2]
            # <attn-dist>, <vocab-dist>, <cell-dependent-output>, <cell-state>
            attn_dist_t, vocab_dist_t, _, state_t = step_output
            loss = self.cell.compute_loss(attn_dist_t, vocab_dist_t, old_coverage, target_words)

            output_word_t_1 = vocab_dist_t.argmax(dim=1)

            step_losses.append(loss)
            vocab_dists.append(vocab_dist_t)

        # Sum step losses along time (each step loss is assumed to be
        # per-example), normalize by target length, then average over the batch.
        loss = torch.stack(step_losses, 1).sum(dim=1)
        # TODO: average is good ??
        loss = (loss / inputs.decoder_lengths.float()).mean()

        pred_targets = torch.stack(vocab_dists, 1).argmax(dim=2)
        correct = (inputs.decoder_targets == pred_targets).float().sum()
        total = inputs.decoder_lengths.float().sum()

        return DotDict(loss=loss, correct=correct, total=total)

    def _forward_decode(self, inputs, state_t, extra_inputs, return_step_outputs=False):
        """Greedy decoding for a whole batch.

        Every example starts from SOS and the argmax word of each step is fed
        back. Decoding stops after ``max_decode_steps`` steps or once every
        example has emitted EOS; words produced after an example's EOS are kept.

        Returns:
            DotDict with ``targets`` of shape [batch_size, num_steps] holding the
            raw argmax ids (which may include temporary OOV ids) and, if
            ``return_step_outputs`` is true, ``step_outputs`` (the per-step cell
            outputs merged with ``simple_collate``).
        """
        output_words_list = []
        step_outputs = [] if return_step_outputs else None

        device = self.cell.encoder_outputs.device
        batch_size = self.cell.encoder_outputs.size(0)

        # shape: [batch_size]
        output_word_t_1 = torch.full([batch_size], self.sos_id, dtype=torch.long, device=device)

        # finished[b] becomes True once example b has emitted EOS.
        finished = torch.zeros_like(output_word_t_1, dtype=torch.bool)
        for _ in range(self.options.max_decode_steps):
            step_output = self.cell.forward(state_t, output_word_t_1, *extra_inputs)

            if return_step_outputs:
                step_outputs.append(step_output)

            # <attn-dist>, <vocab-dist>, <cell-dependent-output>, <cell-state>
            _, vocab_dist_t, _, state_t = step_output
            output_word_t_1 = vocab_dist_t.argmax(dim=1)

            # NOTE: store a copy - the ids are rewritten in place below.
            output_words_list.append(output_word_t_1.clone().detach())

            finished |= output_word_t_1 == self.eos_id
            if finished.all():
                break

            self.restore_oov_words_to_unk(output_word_t_1)

        ret = DotDict(targets=torch.stack(output_words_list, dim=1))
        if return_step_outputs:
            ret.step_outputs = simple_collate(step_outputs)
        return ret

    def _forward_decode_beam(self, inputs, encoder_outputs, state_t, extra_inputs):
        """Beam search for a single example (batch size 1).

        At each step the cell runs over all live hypotheses at once. The
        encoder outputs, source ids/mask and extra inputs are expanded along
        dim 0 to the number of live hypotheses. Each hypothesis is extended with
        its top ``2 * beam_size`` words (fewer if the vocabulary is smaller).
        The candidates are sorted (descending, using Hypothesis's ordering) and
        scanned in that order:
        hypotheses ending in EOS go to ``results`` once ``step >=
        min_decode_steps`` (earlier ones are discarded), the others become the
        next live beams, until ``beam_size`` of either are collected. Search
        stops after ``max_decode_steps`` steps or when ``beam_size`` results
        exist.

        Returns:
            List of Hypothesis objects sorted in descending order; the live
            beams are returned instead if no hypothesis finished.
        """
        device = encoder_outputs.device

        options = self.options
        max_decode_steps = options.max_decode_steps
        min_decode_steps = options.min_decode_steps
        beam_size = options.beam_size

        beams = [Hypothesis([self.sos_id], [0], state_t)]
        results = []

        def _expand(tensor):
            # Repeat a batch-size-1 tensor (as a view) once per live hypothesis.
            if torch.is_tensor(tensor):
                return tensor.expand(len(beams), *tensor.shape[1:])
            return tensor

        step = 0
        while step < max_decode_steps and len(results) < beam_size:
            # shape: [current_beam_size]
            output_word_t_1 = torch.tensor([hyp.latest_word_id for hyp in beams], device=device)
            self.restore_oov_words_to_unk(output_word_t_1)

            state_t = simple_collate([hyp.decoder_state for hyp in beams], mode='cat')

            # expand outputs, labels and mask of encoder
            self.cell.set_encoder_outputs(_expand(encoder_outputs),
                                          _expand(inputs.source_words_ext),
                                          _expand(inputs.source_words_mask))

            _, vocab_dist_t, _, state_t = \
                self.cell.forward(state_t, output_word_t_1, *map(_expand, extra_inputs))

            # torch.topk fails if k exceeds the vocabulary size, so clamp it.
            num_candidates = min(beam_size * 2, vocab_dist_t.size(-1))
            topk_log_probs, topk_ids = torch.topk(vocab_dist_t.log(), num_candidates)
            # Convert to Python once instead of calling .item() per element.
            topk_log_probs = topk_log_probs.tolist()
            topk_ids = topk_ids.tolist()

            # Extend each hypothesis and collect them
            new_beams = []
            for hypothesis, decoder_state, word_ids, log_probs in zip(
                    beams, simple_decollate(state_t, mode='split'), topk_ids, topk_log_probs):
                for word_id, log_prob in zip(word_ids, log_probs):
                    new_beams.append(hypothesis.extend(word_id, log_prob, decoder_state))
            new_beams.sort(reverse=True)
            # Refill the same list object in place (the _expand closure reads len(beams)).
            beams.clear()

            for hypothesis in new_beams:
                # If this hypothesis is sufficiently long, put in results.
                # Otherwise discard.
                if hypothesis.latest_word_id == self.eos_id:
                    if step >= min_decode_steps:
                        results.append(hypothesis)
                else:
                    beams.append(hypothesis)

                if len(beams) == beam_size or len(results) == beam_size:
                    break

            step += 1

        if not results:
            results = beams

        results.sort(reverse=True)
        return results

    def forward(self, inputs, encoder_outputs, encoder_hidden, return_step_outputs=False):
        """Initialise the cell and dispatch to training, greedy decoding or beam search.

        ``inputs`` must provide ``source_words_ext`` and ``source_words_mask``.
        The keys in ``self.cell.extra_keys`` are read with ``inputs.get``, so
        missing ones are passed as None. Training also reads ``decoder_inputs``,
        ``decoder_targets`` and ``decoder_lengths``. Beam search requires
        batch size 1.
        """
        extra_inputs = [inputs.get(key) for key in self.cell.extra_keys]
        state_t = self.cell.get_init_state(encoder_outputs, encoder_hidden)

        self.cell.set_encoder_outputs(encoder_outputs,
                                      inputs.source_words_ext, inputs.source_words_mask)

        if self.training:
            return self._forward_train(inputs, state_t, extra_inputs)
        elif self.options.eval_mode == 'greedy':
            return self._forward_decode(inputs, state_t, extra_inputs,
                                        return_step_outputs=return_step_outputs)
        else:
            assert encoder_outputs.size(0) == 1, 'batch_size should be 1 during beam search'
            return self._forward_decode_beam(inputs, encoder_outputs, state_t, extra_inputs)
