from pathlib import Path
from typing import Dict
import time

import torch
from torch import nn

from page.const import *
from page.config import ModelConfig
from .equation import *
from .text import TextModel
from page.util import filter_dict_by_keys

# Identifiers of the supported equation-model variants.
# Solver keeps the identifier of the variant in use in its `eqn_model_type` attribute.
BAT = 'bat'  # Baseline Albert-Transformer
PAT = 'pat'  # Pointer Albert-Transformer
OAT = 'oat'  # Operation Albert-Transformer
OPAT = 'opat'  # Operation-pointer Albert-Transformer

# Maps each variant identifier to the equation-model class that implements it.
# (The classes are presumably provided by the wildcard import of `.equation` above.)
SUBMODULE_TYPES = {
    BAT: BaselineAlbertTransformer,
    PAT: PointerAlbertTransformer,
    OAT: OperationAlbertTransformer,
    OPAT: OperationPointerAlbertTransformer
}


class Solver(nn.Module):
    """
    Model that operates TextModel and EquationModels.
    """

    def __init__(self, text_model: TextModel, **kwargs):
        """
        Build a solver from a text model and one equation model.

        :param TextModel text_model: Model that encodes the problem text.
        :param kwargs:
            Candidate equation models, keyed by model type (a key of SUBMODULE_TYPES).
            The first entry whose key is registered and whose value is an instance of the class
            registered for that key is used; all other entries are ignored.
        :raises ValueError: If no entry of kwargs provides a usable equation model.
        """
        super().__init__()

        # Set sub-model instances.
        self.text_model = text_model
        for key, module in kwargs.items():
            expected_type = SUBMODULE_TYPES.get(key)
            if expected_type is not None and isinstance(module, expected_type):
                self.eqn_model = module
                self.eqn_model_type = key
                break
        else:
            # Every other method reads eqn_model / eqn_model_type, so fail here with a clear message
            # rather than with an AttributeError at the first use.
            raise ValueError('Solver requires an equation model passed under one of the keys %s, '
                             'but received keyword arguments %s.'
                             % (sorted(SUBMODULE_TYPES), sorted(kwargs)))

    @property
    def required_field(self):
        """
        Select the field constant that corresponds to the type of the equation model.

        :return: TUPLE_PTR for OPAT, TUPLE_GEN for OAT, TOKEN_PTR for PAT, and TOKEN_GEN for any other type.
        """
        field_by_type = {
            OPAT: TUPLE_PTR,
            OAT: TUPLE_GEN,
            PAT: TOKEN_PTR,
        }
        # Any type without an entry above falls back to plain token generation.
        return field_by_type.get(self.eqn_model_type, TOKEN_GEN)

    @property
    def is_tuple_type(self):
        """
        Tell whether the equation model is neither a BAT nor a PAT model.

        :rtype: bool
        :return: False when the model type is BAT or PAT, True otherwise.
        """
        model_type = self.eqn_model_type
        return not (model_type == BAT or model_type == PAT)

    def _forward_training(self, **kwargs):
        """
        Run the text model and the equation model, then collect the training outputs.

        :param kwargs:
            Batch inputs. All of them are passed to the text model; only the entries selected by
            filter_dict_by_keys ('var_token', 'num_token', IN_EQN, IN_ENUM, IN_EVAR) are additionally
            passed to the equation model.
        :return:
            The equation model outputs whose key starts with 'Train_', plus 'total_loss', which is the sum
            of every kept output whose key ends with '/loss' (0 when there is no such output).
        """
        # Encode the problem text; the encoder outputs become keyword arguments of the equation model.
        encoded = self.text_model(**kwargs)

        # Hand the equation-related inputs over to the equation model together with the encoded text.
        eqn_inputs = filter_dict_by_keys(kwargs, 'var_token', 'num_token', IN_EQN, IN_ENUM, IN_EVAR)
        outputs = self.eqn_model(**encoded, **eqn_inputs)

        # Drop everything that is not a training output.
        train_keys = [name for name in outputs if name.startswith('Train_')]
        metrics = filter_dict_by_keys(outputs, *train_keys)

        # Accumulate all losses into a single value.
        total_loss = 0
        for name, value in metrics.items():
            if name.endswith('/loss'):
                total_loss += value
        metrics['total_loss'] = total_loss

        return metrics

    def _forward_testing(self, beam=3, max_len=128, function_arities=None, eqn_sgn_id=-2, **kwargs):
        """
        Encode the problem text and generate equations with beam search.

        :param int beam: Number of beams. 3 by default.
        :param int max_len: Maximum length of the generated sequence. 128 by default.
        :param function_arities: Passed to _generate_memories (used when is_tuple_type is True).
        :param int eqn_sgn_id: Passed to _generate_seq2seq as eq_sgn_id (used when is_tuple_type is False).
        :param kwargs: Inputs of the text model.
        :rtype: torch.Tensor
        :return:
            Generated result, padded with PAD_ID along dimension 2 up to max_len and moved to the device
            of the IN_TXT output of the text model.
        """
        # Encode the problem texts; the generators below consume the encoder outputs.
        encoded = self.text_model(**kwargs)

        if self.is_tuple_type:
            generated = self._generate_memories(encoded, max_len=max_len, beam=beam,
                                                function_arities=function_arities)
        else:
            generated = self._generate_seq2seq(encoded, max_len=max_len, beam=beam, eq_sgn_id=eqn_sgn_id)

        # To ensure having the same size between data-parallel execution, pad it to maximum size.
        current_len = generated.shape[2]
        if current_len < max_len:
            padded_shape = list(generated.shape)
            padded_shape[2] = max_len

            padded = torch.full(padded_shape, fill_value=PAD_ID, dtype=torch.long)
            padded[:, :, :current_len] = generated.cpu()
            generated = padded

        return generated.to(encoded[IN_TXT].device)

    def _generate_memories(self, text: Dict[str, torch.Tensor], max_len=128, beam=3, function_arities=None):
        """
        Generate equations as sequences of (function, arguments) tuples using beam search.

        :param Dict[str,torch.Tensor] text:
            Outputs of the text model. text[IN_TXT] provides the batch size and the device on which the
            equation model is evaluated; text[IN_TNUM] is also read unless required_field is TUPLE_GEN.
            The whole dictionary is passed to the equation model as keyword arguments.
        :param int max_len: Maximum length of the generated sequence. 128 by default.
        :param int beam: Number of beams. 3 by default.
        :param function_arities:
            Optional mapping from a function index to the number of arguments that function takes.
            Functions that are not listed take `self.eqn_model.max_arity` arguments.
        :rtype: torch.Tensor
        :return:
            Long Tensor of generated tuples, kept on the CPU. Shape [B, M, T, 1+2A] where B = batch size,
            M = number of beams, T = length of sequence and A = maximum arity. Each tuple holds the
            function index followed by (argument type, argument index) pairs; unused slots are PAD_ID.
        """
        # Constants
        batch_sz = text[IN_TXT].shape[0]
        batch_range = range(batch_sz)
        device = text[IN_TXT].device
        arity = self.eqn_model.max_arity

        if self.required_field == TUPLE_GEN:
            # Treat all arguments as constants
            num_range = lambda n: 1 <= n < 1 + NUM_MAX
            con_range = lambda n: n == 0 or 1 + NUM_MAX + MEM_MAX <= n
            num_offset = mem_offset = con_offset = 0
        else:
            # Argument indices are laid out as [constants | numbers | memories].
            con_offset = 0
            num_offset = self.eqn_model.constant_word_size
            mem_offset = num_offset + text[IN_TNUM].shape[1]

            con_range = lambda n: n < num_offset
            num_range = lambda n: num_offset <= n < mem_offset

        function_arities = {} if function_arities is None else function_arities

        # Prepare inputs.
        # Initially, we'll start with only one beam, [B, M=1, T=1, 1+2A].
        init = [FUN_NEW_EQN_ID] + [PAD_ID] * (2 * arity)
        result = torch.tensor([[[init]] for _ in batch_range], dtype=torch.long)

        # Prepare beam score storage. [B, M=1]
        beamscores = torch.zeros(batch_sz, 1)

        # Prepare exit indicator
        all_exit = False
        seq_len = 1
        while seq_len < max_len and not all_exit:
            # Compute scores
            scores = self.eqn_model(**text, equation=result.to(device))

            # Retrieve score of the last token. [B, M, T, ?] -> [B, M, ?]
            scores = {key: score[:, :, -1].cpu().detach() for key, score in scores.items()}

            # Probability score for each beam & function words. [B, M, V] + [B, M, 1] = [B, M, V]
            beam_function_score = scores['func'] + beamscores.unsqueeze(-1)

            # Prepare next results
            next_beamscores = torch.zeros(batch_sz, beam)
            next_result = torch.full((batch_sz, beam, seq_len + 1, 1 + 2 * arity), fill_value=PAD_ID, dtype=torch.long)

            beam_range = range(beam_function_score.shape[1])
            fun_range = range(beam_function_score.shape[2])
            for i in batch_range:
                # Compute scores for (Token, Number, Variable) combinations. We will add all log probabilities.
                # Each candidate is (score, index of the previous beam, function index, argument indices).
                score_i = []
                for m in beam_range:
                    last_item = result[i, m, -1, 0].item()
                    after_last = last_item in {PAD_ID, FUN_END_EQN_ID}

                    if after_last:
                        # Ignore all tokens other than EOS if previous token was a PAD or EOS.
                        # A finished beam is only extended with padding and keeps its score.
                        score_i.append((beamscores[i, m].item(), m, PAD_ID, []))
                    else:
                        # Scores of the functions that take arguments, keyed by function index.
                        fun_scores = {}
                        for f in fun_range:
                            fun_score = beam_function_score[i, m, f].item()

                            if f < len(FUN_TOKENS):
                                if f == FUN_END_EQN_ID and last_item == FUN_NEW_EQN_ID:
                                    # Don't permit sequence like [__NEW_EQN, __END_EQN]
                                    continue
                                # BOS, EOS, NEW_VAR token does not require any arguments.
                                score_i.append((fun_score, m, f, []))
                            else:
                                fun_scores[f] = fun_score

                        # Combine argument log-probabilities with function word log-probability.
                        arg_beams = [(0.0, [])]
                        for a in range(arity):
                            # Get top-k result
                            score_ia, index_ia = scores['arg%s' % a][i, m].topk(beam)
                            score_ia = score_ia.tolist()
                            index_ia = index_ia.tolist()

                            # Compute beam*beam combination and preserve only top-beam results.
                            arg_beams = [(s_prev + s_a, arg_prev + [arg_a])
                                         for s_prev, arg_prev in arg_beams
                                         for s_a, arg_a in zip(score_ia, index_ia)]
                            arg_beams = sorted(arg_beams, key=lambda t: t[0], reverse=True)[:beam]

                            for f, s_f in fun_scores.items():
                                # Append function tuples that match current arity.
                                if function_arities.get(f, arity) == a + 1:
                                    score_i += [(s_f + s_args, m, f, args) for s_args, args in arg_beams]

                # Prepare beam tracking. Scores[i] originally have shape [M, T] -> [M * T] after flattening.
                beam_registered = set()
                for score, prevbeam, func, args in sorted(score_i, key=lambda t: t[0], reverse=True):
                    if len(beam_registered) == beam:
                        # If beam is full, exit loop.
                        break

                    beam_signature = (prevbeam, func, *args)
                    if beam_signature in beam_registered:
                        continue

                    newbeam = len(beam_registered)
                    next_beamscores[i, newbeam] = score

                    # Copy tokens
                    next_result[i, newbeam, :-1] = result[i, prevbeam]
                    new_tokens = [func]
                    for a in args:
                        # Translate the flat argument index into an (argument type, index within type) pair.
                        if con_range(a):
                            new_tokens += [ARG_CON_ID, a - con_offset]
                        elif num_range(a):
                            new_tokens += [ARG_NUM_ID, a - num_offset]
                        else:
                            new_tokens += [ARG_MEM_ID, a - mem_offset]
                    # next_result lives on the CPU, so build the new tuple there as well; creating it on
                    # the model device would only add a device-to-host copy for every registered beam.
                    new_tokens = torch.as_tensor(new_tokens, dtype=torch.long)
                    next_result[i, newbeam, -1, :new_tokens.shape[0]] = new_tokens

                    # Assign beam information
                    beam_registered.add(beam_signature)

            # Copy score information
            beamscores = next_beamscores

            # Stop as soon as every beam of every batch item ends with a padding or an end-of-equation.
            last_tokens = next_result[:, :, -1, 0]
            all_exit = ((last_tokens == PAD_ID) | (last_tokens == FUN_END_EQN_ID)).all().item()

            result = next_result
            seq_len += 1

        return result

    def _generate_seq2seq(self, text: Dict[str, torch.Tensor], max_len=128, beam=3, eq_sgn_id=-2):
        """
        Generate equations token by token using beam search.

        :param Dict[str,torch.Tensor] text:
            Outputs of the text model. text[IN_TXT] provides the batch size and the device on which the
            equation model is evaluated. The whole dictionary is passed to the equation model as
            keyword arguments.
        :param int max_len: Maximum length of sequence. 128 by default.
        :param int beam: Number of beams. 3 by default.
        :param int eq_sgn_id:
            Index of the fixed token that is flagged in the last column of the result.
            -2 by default, which never equals a token index, so no token is flagged.
        :rtype: torch.Tensor
        :return:
            Long Tensor of generated tokens, kept on the CPU. Shape [B, M, T, 3] where B = batch size,
            M = number of beams, and T = length of sequence. The last dimension holds
            (token type, token index, flag), where the flag is 1 only for a fixed token whose index
            equals eq_sgn_id. Positions after the end of an equation hold (PAD_ID, PAD_ID, 0).
        """
        # If we're evaluating this module, we need to generate outputs
        batch_sz = text[IN_TXT].shape[0]
        batch_range = range(batch_sz)
        device = text[IN_TXT].device

        # Prepare inputs.
        # Initially, we'll start with only one beam, [B, M=1, T=1, 3].
        result = torch.tensor([[[[TOK_TOK_ID, SEQ_NEW_EQN_ID, 0]]] for _ in batch_range], dtype=torch.long)

        # Prepare beam score storage. [B, M=1]
        beamscores = torch.zeros(batch_sz, 1)
        # Token indices below this size are fixed tokens; the remaining indices refer to numbers.
        fixed_token_sz = self.eqn_model.token_vocab_size

        # Prepare exit indicator
        all_exit = False
        seq_len = 1

        while seq_len < max_len and not all_exit:
            # Compute scores
            # Retrieve score of the last token. [B, M, T, ?] -> [B, M, ?]
            scores = self.eqn_model(**text, equation=result.to(device))
            scores = scores['token'][:, :, -1].cpu().detach()

            # Probability score for each beam & token. [B, M, V] + [B, M, 1] = [B, M, V]
            beam_token_score = scores + beamscores.unsqueeze(-1)

            # Prepare next results
            next_beamscores = torch.zeros(batch_sz, beam)
            next_result = torch.full((batch_sz, beam, seq_len + 1, 3), fill_value=PAD_ID, dtype=torch.long)

            beam_range = range(beam_token_score.shape[1])
            token_range = range(beam_token_score.shape[2])
            for i in batch_range:
                # Compute scores for (Token, Number, Variable) combinations. We will add all log probabilities.
                # Each candidate is (score, index of the previous beam, token type, token index, flag).
                score_i = []
                for m in beam_range:
                    last_type, last_item, _ = result[i, m, -1].tolist()
                    after_last = last_type == PAD_ID or (last_type == TOK_TOK_ID and last_item == SEQ_END_EQN_ID)

                    if after_last:
                        # Ignore all tokens other than EOS if previous token was a PAD or EOS.
                        # A finished beam is only extended with padding and keeps its score.
                        score_i.append((beamscores[i, m].item(), m, PAD_ID, PAD_ID, 0))
                    else:
                        for v in token_range:
                            if v == SEQ_END_EQN_ID and (last_type == TOK_TOK_ID and last_item == SEQ_NEW_EQN_ID):
                                # Don't permit sequence like [__NEW_EQN, __END_EQN]
                                continue

                            token_score = beam_token_score[i, m, v].item()
                            if v < fixed_token_sz:
                                # Fixed tokens
                                score_i.append((token_score, m, TOK_TOK_ID, v, int(v == eq_sgn_id)))
                            else:
                                # Numbers
                                score_i.append((token_score, m, TOK_NUM_ID, v - fixed_token_sz, 0))

                # Prepare beam tracking. Scores[i] originally have shape [M, T] -> [M * T] after flattening.
                beam_registered = set()
                for score, prevbeam, typ, token, is_eq_sign in sorted(score_i, key=lambda t: t[0], reverse=True):
                    if len(beam_registered) == beam:
                        # If beam is full, exit loop.
                        break

                    beam_signature = (prevbeam, typ, token)
                    if beam_signature in beam_registered:
                        continue

                    newbeam = len(beam_registered)
                    next_beamscores[i, newbeam] = score

                    # Copy tokens
                    next_result[i, newbeam, :-1] = result[i, prevbeam]
                    next_result[i, newbeam, -1, 0] = typ
                    next_result[i, newbeam, -1, 1] = token
                    next_result[i, newbeam, -1, 2] = is_eq_sign

                    # Assign beam information
                    beam_registered.add(beam_signature)

            # Copy score information
            beamscores = next_beamscores

            # Stop as soon as every beam of every batch item ends with a padding or an end-of-equation.
            # The type and the token index are columns of the LAST dimension, so both are selected with
            # full indexing; indexing the first dimension instead would pick batch items 0 and 1, which
            # fails for a batch of size 1 and compares unrelated values otherwise.
            last_types = next_result[:, :, -1, 0]
            last_items = next_result[:, :, -1, 1]
            all_exit = ((last_types == PAD_ID) |
                        ((last_types == TOK_TOK_ID) & (last_items == SEQ_END_EQN_ID))).all().item()

            result = next_result
            seq_len += 1

        return result

    def forward(self, execute=None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Do forward pass.
        .. see::
            _forward_training or _forward_testing for further detail.
        """
        if self.training:
            return self._forward_training(**kwargs)
        else:
            with torch.no_grad():
                if execute is None:
                    return self._forward_testing(**kwargs)
                else:
                    return getattr(self, execute)(**kwargs)

    def save_pretrained(self, save_directory: str):
        """
        Save the text model and the equation model into the same directory.

        :param str save_directory: String that represents path to the directory where this will be saved.
        """
        # The text model is saved first, followed by the equation model.
        for attribute in ('text_model', 'eqn_model'):
            getattr(self, attribute).save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, config: ModelConfig):
        """
        Build a solver whose sub-models are loaded through their own `from_pretrained` methods.

        :param ModelConfig config:
            Configuration handed over to the sub-models. `config.model_type` selects the equation model
            class and must be one of the keys of SUBMODULE_TYPES.
        :return: A new instance of `cls` that holds the loaded text model and equation model.
        :raises KeyError: If `config.model_type` is not a key of SUBMODULE_TYPES.
        """
        # Load text model
        text_model = TextModel.from_pretrained(config)

        # Read equation models
        model_type = config.model_type
        eqn_model = SUBMODULE_TYPES[model_type].from_pretrained(config)

        # Instantiate through cls rather than Solver, so that a subclass gets an instance of itself.
        return cls(text_model, **{model_type: eqn_model})


# Only Solver is exported by a wildcard import of this module.
__all__ = ['Solver']
