from pathlib import Path
from typing import Dict, Tuple

import torch
from torch import nn
from transformers.modeling_bert import gelu_new as gelu_bert

from page.config import ModelConfig
from page.const import *
from .attention import *
from .embed import PositionalEncoding
from .layer import *
from .loss import loss_and_accuracy
from .mask import mask_forward
from .util import *


class DecoderModel(nn.Module):
    """
    Model for equation generation/classification (Abstract class)

    Sub-classes have to implement :meth:`_forward_single` and :meth:`_build_target_dict`.
    This class keeps the configuration, exposes some of its values as read-only properties,
    and provides checkpoint I/O, weight initialization and the train/eval dispatch of :meth:`forward`.
    """

    def __init__(self, config: ModelConfig):
        """
        Initiate Equation Builder instance

        :param ModelConfig config:
            Configuration of this model. The properties of this class read the attributes ``embedding_dim``,
            ``hidden_dim``, ``num_decoder_layers``, ``num_decoder_heads``, ``num_pointer_heads``,
            ``layernorm_eps`` and ``init_factor`` from it. Checkpoint I/O additionally uses ``chkpt_path`` and
            ``save_pretrained`` of the configuration.
        """
        super().__init__()
        # Save configuration.
        self.config = config

    @property
    def embedding_dim(self) -> int:
        """
        :rtype: int
        :return: Dimension of embedding vector
        """
        return self.config.embedding_dim

    @property
    def hidden_dim(self) -> int:
        """
        :rtype: int
        :return: Dimension of hidden vector.
        """
        return self.config.hidden_dim

    @property
    def num_hidden_layers(self) -> int:
        """
        :rtype: int
        :return: Number of repetition for applying the same transformer layer (``config.num_decoder_layers``)
        """
        return self.config.num_decoder_layers

    @property
    def init_factor(self) -> float:
        """
        :rtype: float
        :return: Standard deviation of normal distribution that will be used for initializing weights.
        """
        return self.config.init_factor

    @property
    def layernorm_eps(self) -> float:
        """
        :rtype: float
        :return: Epsilon to avoid zero-division in LayerNorm.
        """
        return self.config.layernorm_eps

    @property
    def num_heads(self) -> int:
        """
        :rtype: int
        :return: Number of heads in a transformer layer (``config.num_decoder_heads``).
        """
        return self.config.num_decoder_heads

    @property
    def num_pointer_heads(self) -> int:
        """
        :rtype: int
        :return: Number of heads in the last pointer layer.
        """
        return self.config.num_pointer_heads

    def save_pretrained(self, save_directory: str):
        """
        Save current state of Equation Builder.

        The configuration is written by ``config.save_pretrained`` and the state dictionary of this module is
        stored as ``<ClassName>.pt`` inside the same directory.

        :param str save_directory: String that represents path to the directory where this will be saved.
        """
        self.config.save_pretrained(save_directory)
        # Write state dictionary
        torch.save(self.state_dict(), Path(save_directory, '%s.pt' % self.__class__.__name__))

    @classmethod
    def from_pretrained(cls, config: ModelConfig):
        """
        Build a model from the configuration and restore its weights when a checkpoint path is configured.

        :param ModelConfig config:
            Configuration of the model. When ``config.chkpt_path`` is not None, ``<ClassName>.pt`` in that
            directory is loaded into the freshly built model.
        :return: The (possibly restored) model instance.
        """
        model = cls(config)
        if config.chkpt_path is not None:
            # NOTE: torch.load un-pickles the file, so checkpoints must come from a trusted source.
            # Storages are mapped to the CPU so that a checkpoint written on a GPU can be restored on a host
            # without that device; load_state_dict copies the values into the parameters of the model anyway.
            state = torch.load(Path(config.chkpt_path, '%s.pt' % cls.__name__), map_location='cpu')
            model.load_state_dict(state)

        return model

    def _init_weights(self, module: nn.Module):
        """
        Initialize weights

        Linear, Embedding and MultiheadAttention modules get normally distributed weights (std = init_factor)
        and zero biases. LayerNorm modules are reset to the identity function. Other modules are untouched.

        :param nn.Module module: Module to be initialized.
        :raises NotImplementedError: When a parameter name contains neither 'weight' nor 'bias'.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, nn.MultiheadAttention)):
            # nn.Linear has 'weight' and 'bias', nn.Embedding has 'weight',
            # and nn.MultiheadAttention has *_weight and *_bias
            for name, param in module.named_parameters():
                if 'weight' in name:
                    param.data.normal_(mean=0.0, std=self.init_factor)
                elif 'bias' in name:
                    param.data.zero_()
                else:
                    raise NotImplementedError("This case is not considered!")
        elif isinstance(module, nn.LayerNorm):
            # Initialize layer normalization as an identity function.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _forward_single(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Compute the outputs of the model for a single input. Must be implemented by sub-classes.
        """
        raise NotImplementedError()

    def _build_target_dict(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Build the target tensors used for training, keyed like the outputs of :meth:`_forward_single`.
        Must be implemented by sub-classes.
        """
        raise NotImplementedError()

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Run the model.

        In training mode, the outputs of :meth:`_forward_single` are compared with the targets built by
        :meth:`_build_target_dict` (after ``shift_target``), and the values returned by ``loss_and_accuracy``
        are collected with the prefix ``Train_<key>``. Otherwise, :meth:`_forward_single` is applied through
        ``apply_across_dim`` on dimension 1 without tracking gradients.

        :return: Dictionary of the collected results.
        """
        result = {}
        if self.training:
            # Forward single beam input.
            output = self._forward_single(**kwargs)
            # Build targets. Targets never need gradients.
            with torch.no_grad():
                targets = self._build_target_dict(**kwargs)

            for key in targets:
                result.update(loss_and_accuracy(output[key], shift_target(targets[key]), prefix='Train_%s' % key))
        else:
            with torch.no_grad():
                # Text-related inputs are declared as shared keys of apply_across_dim
                # (presumably they are shared by all beams on dimension 1 -- TODO confirm in util).
                result.update(apply_across_dim(self._forward_single, dim=1,
                                               shared_keys={IN_TXT, IN_TPAD, IN_TNUM, IN_TNPAD}, **kwargs))

        return result


class OperationDecoderModel(DecoderModel):
    """
    Common part of the operation-based decoders.

    Each decoding step is described by one function word followed by (type, value) pairs of its arguments.
    This class embeds those steps, runs the shared transformer layer and predicts the function word.
    Argument embedding and argument prediction are left to sub-classes, which also initialize the weights.
    """

    def __init__(self, config):
        """
        Build the layers shared by all operation-based decoders.

        :param config: Model configuration (see DecoderModel).
        """
        super().__init__(config)
        hidden = self.hidden_dim
        eps = self.layernorm_eps

        # ----- Embedding layers -----
        # (1) Function word embedding with its positional encoding
        self.function_word_embedding = nn.Embedding(self.function_word_size, hidden)
        self.function_pos_embedding = PositionalEncoding(hidden, fixed_embedding=True)
        # Initial value of the two learnable scaling factors below.
        self.degrade_factor = self.embedding_dim ** 0.5

        # (2) Argument embeddings: only the type embedding lives here, word embeddings belong to sub-classes.
        self.argument_type_embedding = nn.Embedding(3, hidden)

        # Learnable scales for the function embedding and for the argument type embedding.
        self.function_pos_factor = nn.Parameter(torch.tensor(self.degrade_factor), requires_grad=True)
        self.argument_ctx_factor = nn.Parameter(torch.tensor(self.degrade_factor), requires_grad=True)

        self.function_norm = nn.LayerNorm(hidden, eps=eps)
        self.argument_norm = nn.LayerNorm(hidden, eps=eps)

        # (3) Projection of the concatenated (function + arguments) embedding to one hidden vector.
        self.word_to_hidden = nn.Linear(hidden * (self.max_arity + 1), hidden)

        # ----- Transformer layer -----
        # A single layer that is applied repeatedly while decoding.
        self.shared_layer = TransformerLayer(config)

        # ----- Generator layer -----
        self.function_out = nn.Linear(hidden, self.function_word_size)
        self.softmax = LogSoftmax(dim=-1)

        # Argument outputs and weight initialization are handled by sub-classes.

    @property
    def function_word_size(self):
        """
        :return: ``config.function_word_size``, used as size of the function embedding and function output.
        """
        return self.config.function_word_size

    @property
    def argument_word_size(self):
        """
        :return: ``config.argument_word_size``.
        """
        return self.config.argument_word_size

    @property
    def constant_word_size(self):
        """
        :return: ``config.constant_word_size``.
        """
        return self.config.constant_word_size

    @property
    def max_arity(self):
        """
        :return: Largest 'arity' among the entries of OPERATORS, or 2 when OPERATORS is empty.
        """
        return max((spec['arity'] for spec in OPERATORS.values()), default=2)

    def _build_argument_embed(self, ids: torch.Tensor, mem_pos: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Build the embedding of argument values. Must be implemented by sub-classes.

        :param torch.Tensor ids: The same ids that are passed to :meth:`_build_decoder_input`.
        :param torch.Tensor mem_pos: Output of the positional encoding for the current length.
        :param torch.Tensor nums: Number representations passed to :meth:`_build_decoder_input`.
        """
        raise NotImplementedError()

    def _build_decoder_input(self, ids: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed the given operation ids into the input of the transformer layer.

        :param torch.Tensor ids: Operation ids. Shape [B, T, 1 + 2*A], where A = Max arity.
        :param torch.Tensor nums: Number representations, forwarded to :meth:`_build_argument_embed`.
        :return: Output of ``word_to_hidden`` applied on the flattened embedding [B, T, (1+A)E].
        """
        function_ids = ids.select(dim=-1, index=0)
        # Argument types are stored on the odd positions of the last dimension.
        argument_type_ids = ids[:, :, 1::2]

        # Function embedding [B, T, E] and positions [T, E]
        func_vec = get_embedding_without_pad(self.function_word_embedding, function_ids)
        positions = self.function_pos_embedding(ids.shape[1])

        # Scale, add positions and normalize. Then make it [B, T, 1, E].
        func_vec = self.function_norm(func_vec * self.function_pos_factor + positions.unsqueeze(0))
        func_vec = func_vec.unsqueeze(2)

        # Argument embedding [B, T, A, E]: scaled type embedding plus the value embedding of the sub-class.
        arg_vec = get_embedding_without_pad(self.argument_type_embedding, argument_type_ids)
        arg_vec = arg_vec * self.argument_ctx_factor
        arg_vec += self._build_argument_embed(ids, positions, nums)
        arg_vec = self.argument_norm(arg_vec)

        # [B, T, 1+A, E] -> [B, T, (1+A)E]
        merged = torch.cat([func_vec, arg_vec], dim=2).contiguous()
        return self.word_to_hidden(merged.flatten(start_dim=2))

    def _build_decoder_context(self, embedding: torch.Tensor, embedding_pad: torch.Tensor = None,
                               text: torch.Tensor = None, text_pad: torch.Tensor = None) -> torch.Tensor:
        """
        Apply the shared transformer layer ``num_hidden_layers`` times.

        :param torch.Tensor embedding: Decoder input. Shape [B, T, H].
        :param torch.Tensor embedding_pad: Passed to the layer as ``target_ignorance_mask``.
        :param torch.Tensor text: Passed to the layer as ``memory``.
        :param torch.Tensor text_pad: Passed to the layer as ``memory_ignorance_mask``.
        :return: Output of the last application of the shared layer.
        """
        step_mask = mask_forward(embedding.shape[1]).to(embedding.device)

        hidden = embedding
        for _ in range(self.num_hidden_layers):
            hidden = self.shared_layer(target=hidden, memory=text,
                                       target_attention_mask=step_mask,
                                       target_ignorance_mask=embedding_pad,
                                       memory_ignorance_mask=text_pad)

        return hidden

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None, text_num: torch.Tensor = None,
                        equation: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Decode the given operations and predict the next function word.

        :return:
            Dictionary with 'func' (log-softmax over function words), '_out' (decoder hidden states [B, T, H])
            and '_not_usable' (boolean mask [B, T]). The last two entries are consumed by sub-classes.
        """
        function_word = equation.select(dim=2, index=0)
        is_padding = function_word == PAD_ID
        embedded = self._build_decoder_input(ids=equation, nums=text_num)

        # A step is not usable when it is padding, or when it produces the equality.
        # 'function_word' is the input while this mask describes the 1-step shifted output, hence the offset 1.
        not_usable = is_padding.clone()
        next_is_equality = function_word[:, 1:] == FUN_EQ_SGN_ID
        not_usable[:, :-1].masked_fill_(next_is_equality, True)

        # Decoder output: [B, T, H]
        hidden = self._build_decoder_context(embedding=embedded, embedding_pad=is_padding,
                                             text=text, text_pad=text_pad)

        func_output = self.function_out(hidden)

        if not self.training:
            # NEW_EQN is never generated.
            func_output[:, :, FUN_NEW_EQN_ID] = NEG_INF
            # END_EQN is available only when the input is EQ_SGN, i.e. right after an equation is formed.
            func_output[:, :, FUN_END_EQN_ID].masked_fill_(function_word != FUN_EQ_SGN_ID, NEG_INF)

        # Sub-classes pop '_out' and '_not_usable' to predict arguments.
        return {'func': self.softmax(func_output), '_out': hidden, '_not_usable': not_usable}


class OperationAlbertTransformer(OperationDecoderModel):
    """
    Operation-based decoder that predicts every argument by classification over the argument vocabulary.
    """

    def __init__(self, config):
        """
        Build the argument embedding and one classifier per argument position, then initialize all weights.

        :param config: Model configuration (see DecoderModel).
        """
        super().__init__(config)

        # ----- Embedding layers -----
        # Function embedding, norms and the embedding-to-hidden projection come from the super-class.
        # Argument word embedding. Order of the vocabulary: Const(+ Number), Memory
        self.argument_word_embedding = nn.Embedding(self.argument_word_size, self.hidden_dim)

        # ----- Generator layer -----
        # self.function_out comes from the super-class. Arguments are predicted by classification.
        self.argument_out = nn.ModuleList([
            nn.ModuleDict({
                '0_out': nn.Linear(self.hidden_dim, self.argument_word_size)
            }) for _ in range(self.max_arity)
        ])

        # ----- Initialize weights -----
        with torch.no_grad():
            # Initialize Linear, LayerNorm, Embedding
            self.apply(self._init_weights)

    def _build_argument_embed(self, ids: torch.Tensor, mem_pos: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed the argument values, which are stored on positions 2, 4, ... of the last dimension of ids.

        :param torch.Tensor ids: Operation ids. Shape [B, T, 1 + 2*A].
        :param torch.Tensor mem_pos: Not used by this model.
        :param torch.Tensor nums: Not used by this model.
        :return: Argument word embedding. Shape [B, T, A, E].
        """
        return get_embedding_without_pad(self.argument_word_embedding, ids[:, :, 2::2])

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None,
                        text_num: torch.Tensor = None, text_numpad: torch.Tensor = None,
                        equation: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Predict the function word and all arguments.

        When the model is not in training mode, scores of number/memory tokens that cannot be used are set to
        NEG_INF before the softmax. Masks are truncated to the number of addressable slots, so inputs with more
        than NUM_MAX numbers or more than MEM_MAX steps do not cause a shape mismatch.

        :param torch.Tensor text_numpad: Padding mask of numbers, [B, N]. Required when not training.
        :return: Dictionary with 'func' and 'arg0', 'arg1', ... (log-softmax scores).
        """
        result = super()._forward_single(text, text_pad, text_num, equation)

        # Hidden state: [B, T, H]
        output = result.pop('_out')
        # Usability Mask: [B, T] -> [B, 1, T]
        output_not_usable = result.pop('_not_usable').unsqueeze(1)
        # Forward mask: [T, T] -> [1, T, T]
        forward_mask = mask_forward(output.shape[1], diagonal=0).unsqueeze(0).to(output.device)

        # Number tokens are placed on 1:1+NUM_MAX
        num_begin = 1
        num_count = min(text_num.shape[1], NUM_MAX)
        num_used = num_begin + num_count
        num_end = num_begin + NUM_MAX
        # Memory tokens are placed on 1+NUM_MAX:1+NUM_MAX+MEM_MAX
        mem_count = min(output.shape[1], MEM_MAX)
        mem_used = num_end + mem_count
        mem_end = num_end + MEM_MAX

        evaluating = not self.training
        if evaluating:
            # Keep only the columns that have a slot in the vocabulary. Without truncation, the masks would be
            # wider than the slices below whenever the counts are clamped by NUM_MAX or MEM_MAX.
            num_mask = text_numpad[:, :num_count].unsqueeze(1)
            mem_not_usable = output_not_usable[:, :, :mem_count]
            mem_forward = forward_mask[:, :, :mem_count]

        # Predict arguments
        for i, layer in enumerate(self.argument_out):
            word_output = apply_module_dict(layer, encoded=output)

            if evaluating:
                # Ignore probabilities on not-appeared number tokens
                word_output[:, :, num_begin:num_used].masked_fill_(num_mask, NEG_INF)
                word_output[:, :, num_used:num_end] = NEG_INF

                # Ignore probabilities on non-appeared memory tokens
                word_output[:, :, num_end:mem_used].masked_fill_(mem_not_usable, NEG_INF)
                word_output[:, :, num_end:mem_used].masked_fill_(mem_forward, NEG_INF)
                word_output[:, :, mem_used:mem_end] = NEG_INF

            # Apply softmax after masking
            result['arg%s' % i] = self.softmax(word_output)

        return result

    def _build_target_dict(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Build training targets from the equation given under the key IN_EQN.

        :return: Dictionary with 'func' (column 0) and 'arg<i>' (column 2*i + 2) of the equation tensor.
        """
        equation = kwargs[IN_EQN]
        targets = {'func': equation.select(dim=-1, index=0)}
        for i in range(self.max_arity):
            targets['arg%s' % i] = equation[:, :, (i * 2 + 2)]

        return targets


class OperationPointerAlbertTransformer(OperationDecoderModel):
    """
    Operation-based decoder that predicts every argument by pointing at a key.

    Keys are ordered as: constants, numbers of the text, and previous decoder outputs (memory).
    """

    def __init__(self, config):
        """
        Build the constant embedding and one pointer per argument position, then initialize all weights.

        :param config: Model configuration (see DecoderModel).
        """
        super().__init__(config)

        # ----- Embedding layers -----
        # Function embedding, norms and the embedding-to-hidden projection come from the super-class.
        # Constants own an embedding; numbers and memories reuse existing vectors (see _build_argument_embed).
        self.constant_word_embedding = nn.Embedding(self.constant_word_size, self.hidden_dim)

        # ----- Pointer layer -----
        # self.function_out comes from the super-class. Arguments are predicted by pointing.
        self.argument_out = nn.ModuleList([
            nn.ModuleDict({
                '0_attn': MultiheadAttentionWeights(hidden_dim=self.hidden_dim, num_heads=self.num_pointer_heads),
                '1_mean': Squeeze(dim=-1) if self.num_pointer_heads == 1 else AveragePooling(dim=-1)
            }) for _ in range(self.max_arity)
        ])

        # ----- Initialize weights -----
        with torch.no_grad():
            # Initialize Linear, LayerNorm, Embedding
            self.apply(self._init_weights)

    def _build_argument_embed(self, ids: torch.Tensor, mem_pos: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed argument values according to their types.

        Number arguments look up ``nums`` of the same batch item, constant arguments look up the constant
        embedding and memory arguments look up ``mem_pos``. For each lookup, values of other types are
        replaced with PAD_ID before calling ``get_embedding_without_pad``.

        :param torch.Tensor ids: Operation ids. Shape [B, T, 1 + 2*A].
        :param torch.Tensor mem_pos: Positional encoding, indexed by memory arguments.
        :param torch.Tensor nums: Number representations, indexed per batch item. Used as-is.
        :return: Sum of the three lookups. Shape [B, T, A, E].
        """
        # Argument indices [B, T, A]
        arg_types = ids[:, :, 1::2]
        arg_value = ids[:, :, 2::2]

        # Build number embedding: [B, T, A, E]
        arg_num_index = arg_value.masked_fill(arg_types != ARG_NUM_ID, PAD_ID)
        arg_embed = torch.stack([get_embedding_without_pad(nums[b], arg_num_index[b])
                                 for b in range(ids.shape[0])], dim=0).contiguous()

        # Add constant embedding: [B, T, A, E]
        arg_embed += get_embedding_without_pad(self.constant_word_embedding,
                                               arg_value.masked_fill(arg_types != ARG_CON_ID, PAD_ID))

        # Add memory embedding: [B, T, A, E]
        arg_mem_index = arg_value.masked_fill(arg_types != ARG_MEM_ID, PAD_ID)
        arg_embed += get_embedding_without_pad(mem_pos, arg_mem_index)
        return arg_embed

    def _build_attention_keys(self, num: torch.Tensor, mem: torch.Tensor, num_pad: torch.Tensor = None,
                              mem_pad: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build keys and masks for pointing.

        :param torch.Tensor num: Number representations. Shape [B, N, H].
        :param torch.Tensor mem: Decoder outputs used as memory. Shape [B, M, H].
        :param torch.Tensor num_pad: Optional mask of numbers to ignore. Shape [B, N].
        :param torch.Tensor mem_pad: Optional mask of memories to ignore. Shape [B, M].
        :return:
            Triple of key [B, C+N+M, H], key ignorance mask [B, C+N+M] and attention mask [M, C+N+M],
            where C is the number of constants.
        """
        batch_sz = num.shape[0]
        const_sz = self.constant_word_size
        const_num_sz = const_sz + num.shape[1]

        # Order: Const, Number, Memory
        # Constant keys: [C, H] -> [1, C, H] -> [B, C, H]
        const_key = self.constant_word_embedding.weight.unsqueeze(0).expand(batch_sz, const_sz, self.hidden_dim)

        # Key: [B, C+N+M, H]
        key = torch.cat([const_key, num, mem], dim=1).contiguous()
        # Key ignorance mask: [B, C+N+M]. Constants are never ignored.
        key_ignorance_mask = torch.zeros(key.shape[:2], dtype=torch.bool, device=key.device)
        if num_pad is not None:
            key_ignorance_mask[:, const_sz:const_num_sz] = num_pad
        if mem_pad is not None:
            key_ignorance_mask[:, const_num_sz:] = mem_pad

        # Attention mask: [M, C+N+M]. Only the memory part is restricted, with mask_forward(diagonal=0).
        attention_mask = torch.zeros(mem.shape[1], key.shape[1], dtype=torch.bool, device=key.device)
        attention_mask[:, const_num_sz:] = mask_forward(mem.shape[1], diagonal=0).to(key_ignorance_mask.device)

        return key, key_ignorance_mask, attention_mask

    def _build_attention_target(self, argument_slice: torch.Tensor, max_num: int) -> torch.Tensor:
        """
        Convert a (type, value) pair of one argument position into the index of the key to be pointed.

        The value is clamped at 0 and an offset is added: the number of constants for ARG_NUM_ID, and the
        number of constants plus ``max_num`` for ARG_MEM_ID. For every other type, the type id itself is added
        (the original note states that this yields -1 for PAD_ID and an offset of 0 for ARG_CON).

        :param torch.Tensor argument_slice: Tensor whose last dimension holds (type, value).
        :param int max_num: Number of number keys, i.e. N of the key layout [C, N, M].
        :return: Target indices, with the shape of ``argument_slice`` without its last dimension.
        """
        arg_types = argument_slice.select(dim=-1, index=0)
        # clamp_min creates a new tensor, so the in-place addition below does not modify the input.
        arg_value = argument_slice.select(dim=-1, index=1).clamp_min(0)

        num_offset = self.constant_word_size
        mem_offset = num_offset + max_num

        # Both masks are computed on the original types, so the two replacements cannot interfere.
        arg_value += arg_types.masked_fill(arg_types == ARG_NUM_ID, num_offset) \
            .masked_fill_(arg_types == ARG_MEM_ID, mem_offset)

        return arg_value

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None,
                        text_num: torch.Tensor = None, text_numpad: torch.Tensor = None,
                        equation: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Predict the function word and point at the key of every argument.

        :return: Dictionary with 'func' and 'arg0', 'arg1', ... (log-softmax scores).
        """
        result = super()._forward_single(text, text_pad, text_num, equation)

        # Decoder output: [B, T, H]
        output = result.pop('_out')
        output_not_usable = result.pop('_not_usable')

        # Build attention keys. Decoder outputs serve as memory keys.
        key, key_ign_msk, attn_msk = self._build_attention_keys(num=text_num, mem=output,
                                                                num_pad=text_numpad, mem_pad=output_not_usable)

        # Predict arguments
        for i, layer in enumerate(self.argument_out):
            score = apply_module_dict(layer, encoded=output, key=key, key_ignorance_mask=key_ign_msk,
                                      attention_mask=attn_msk)
            result['arg%s' % i] = self.softmax(score)

        return result

    def _build_target_dict(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Build training targets from the equation (IN_EQN) and the numbers of the text (IN_TNUM).

        :return:
            Dictionary with 'func' (column 0 of the equation) and 'arg<i>', the key index built by
            :meth:`_build_attention_target` from columns 2*i + 1 (type) and 2*i + 2 (value).
        """
        equation = kwargs[IN_EQN]
        max_num = kwargs[IN_TNUM].shape[1]

        targets = {'func': equation.select(dim=-1, index=0)}
        for i in range(self.max_arity):
            begin = i * 2 + 1
            # Reuse the shared conversion instead of repeating the offset arithmetic here.
            targets['arg%s' % i] = self._build_attention_target(equation[:, :, begin:begin + 2], max_num)

        return targets


class TokenDecoderModel(DecoderModel):
    """
    Common part of the token-based decoders, which build equations from given encoded vector.

    Sub-classes provide the vocabulary size, the word embedding of the input and the output layer.
    """

    def __init__(self, config):
        """
        Build the layers shared by all token-based decoders.

        :param config:
            Model configuration (see DecoderModel). ``hidden_dim`` and ``layernorm_eps`` are read here, and the
            configuration is handed to the shared TransformerLayer.
        """
        super().__init__(config)
        hidden = self.hidden_dim

        # ----- Word embedding -----
        self.word_embedding = nn.Embedding(self.token_vocab_size, hidden)
        # Positional encoding
        self.pos_embedding = PositionalEncoding(hidden, fixed_embedding=True)
        # LayerNorm applied on the sum of scaled word embedding and positional encoding.
        self.word_hidden_norm = nn.LayerNorm(hidden, eps=self.layernorm_eps)
        # Learnable factor that scales the word embedding, initialized as sqrt(hidden_dim).
        self.degrade_factor = hidden ** 0.5
        self.pos_factor = nn.Parameter(torch.tensor(self.degrade_factor), requires_grad=True)

        # ----- Transformer layer -----
        # A single layer that is applied repeatedly while decoding.
        self.shared_layer = TransformerLayer(config)

        # Sub-classes initialize the weights.

    @property
    def token_vocab_size(self):
        """
        :return: Number of rows of the word embedding. Must be implemented by sub-classes.
        """
        raise NotImplementedError()

    def _build_word_embed(self, ids: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Build word embeddings of the given ids. Must be implemented by sub-classes.

        :param torch.Tensor ids: The same ids that are passed to :meth:`_build_decoder_input`.
        :param torch.Tensor nums: Number representations passed to :meth:`_build_decoder_input`.
        """
        raise NotImplementedError()

    def _build_decoder_input(self, ids: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed the given token ids into the input of the transformer layer.

        :param torch.Tensor ids: Token ids, [B, T, C]. Sub-classes in this file read columns 0 to 2.
        :param torch.Tensor nums: Number representations, forwarded to :meth:`_build_word_embed`.
        :return: Normalized sum of the scaled word embedding and the positional encoding. Shape [B, T, E].
        """
        # Positions: [T, E]
        positions = self.pos_embedding(ids.shape[1])
        # Word embeddings: [B, T, E]
        words = self._build_word_embed(ids, nums)

        scaled = words * self.pos_factor
        return self.word_hidden_norm(scaled + positions.unsqueeze(0))

    def _build_decoder_context(self, embedding: torch.Tensor, embedding_pad: torch.Tensor = None,
                               text: torch.Tensor = None, text_pad: torch.Tensor = None) -> torch.Tensor:
        """
        Apply the shared transformer layer ``num_hidden_layers`` times.

        :param torch.Tensor embedding: Decoder input. Shape [B, T, H].
        :param torch.Tensor embedding_pad: Passed to the layer as ``target_ignorance_mask``.
        :param torch.Tensor text: Passed to the layer as ``memory``.
        :param torch.Tensor text_pad: Passed to the layer as ``memory_ignorance_mask``.
        :return: Output of the last application of the shared layer.
        """
        step_mask = mask_forward(embedding.shape[1]).to(embedding.device)

        hidden = embedding
        for _ in range(self.num_hidden_layers):
            hidden = self.shared_layer(target=hidden, memory=text,
                                       target_attention_mask=step_mask,
                                       target_ignorance_mask=embedding_pad,
                                       memory_ignorance_mask=text_pad)

        return hidden

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None, text_num: torch.Tensor = None,
                        equation: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Compute decoder hidden states for the given tokens.

        :return: Dictionary with the single entry '_out' (hidden states [B, T, H]) for sub-classes.
        """
        # Number representations pass through ReLU before they are used as input embeddings.
        embedded = self._build_decoder_input(ids=equation, nums=text_num.relu())
        # Steps whose first column equals PAD_ID are padding.
        is_padding = equation.select(dim=2, index=0) == PAD_ID

        # Decoder output: [B, T, H]
        hidden = self._build_decoder_context(embedding=embedded, embedding_pad=is_padding,
                                             text=text, text_pad=text_pad)

        # Sub-classes pop '_out' and finish the prediction.
        return {'_out': hidden}


class BaselineAlbertTransformer(TokenDecoderModel):
    """
    Token-based decoder that generates every token by classification over the token vocabulary.
    """

    def __init__(self, config):
        """
        Build the token classifier and initialize all weights.

        :param config: Model configuration (see DecoderModel).
        """
        super().__init__(config)

        # ----- Template token generator -----
        self.token = nn.Linear(self.hidden_dim, self.token_vocab_size)
        self.softmax = LogSoftmax(dim=-1)

        # ----- Initialize weights -----
        with torch.no_grad():
            # Initialize Linear, LayerNorm, Embedding
            self.apply(self._init_weights)

    @property
    def token_vocab_size(self):
        """
        :return: ``config.token_vocab_size``.
        """
        return self.config.token_vocab_size

    def _build_word_embed(self, ids: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed column 1 of the given ids with the word embedding. ``nums`` is not used: all inputs are tokens.
        """
        token_ids = ids.select(dim=2, index=1)
        return get_embedding_without_pad(self.word_embedding, token_ids)

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None, text_num: torch.Tensor = None,
                        text_numpad: torch.Tensor = None, equation: torch.Tensor = None,
                        **kwargs) -> Dict[str, torch.Tensor]:
        """
        Predict the next token of every step.

        When the model is not in training mode, scores of tokens that cannot be generated are set to NEG_INF
        before the softmax.

        :param torch.Tensor text_numpad: Padding mask of numbers, unsqueezed at dim 1. Required when not training.
        :return: Dictionary with the single entry 'token' (log-softmax scores [B, S, V]).
        """
        result = super()._forward_single(text, text_pad, text_num, equation)

        # Predict tokens: Shape [B, S, V].
        tokens = self.token(result.pop('_out'))

        if not self.training:
            # Number tokens begin at SEQ_GEN_NUM_ID; this is the end of the numbers given in the text.
            number_end = SEQ_GEN_NUM_ID + text_num.shape[1]

            # Ignore NEW_EQN token during prediction.
            tokens[:, :, SEQ_NEW_EQN_ID] = NEG_INF
            # Ignore probabilities on not-appeared number tokens
            tokens[:, :, SEQ_GEN_NUM_ID:number_end].masked_fill_(text_numpad.unsqueeze(1), NEG_INF)
            tokens[:, :, number_end:SEQ_GEN_VAR_ID] = NEG_INF

            # END_EQN is available only where column 2 of the equation is non-zero
            # (i.e. when the input is '=', after an equation is formed).
            can_end = equation.select(dim=-1, index=2).bool()
            tokens[:, :, SEQ_END_EQN_ID].masked_fill_(~can_end, NEG_INF)

        result['token'] = self.softmax(tokens)
        return result

    def _build_target_dict(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Build training targets.

        :return: Dictionary with the single entry 'token', column 1 of the equation given under IN_EQN.
        """
        return {'token': kwargs[IN_EQN].select(dim=-1, index=1)}


class PointerAlbertTransformer(TokenDecoderModel):
    """
    Token-based decoder that generates every token by pointing at a key.

    Keys are the rows of the word embedding followed by the numbers of the text.
    """

    def __init__(self, config):
        """
        Build the input embeddings for numbers/types and the pointer, then initialize all weights.

        :param config: Model configuration (see DecoderModel).
        """
        super().__init__(config)

        # Learnable scale of the token embedding, initialized as the inverse of degrade_factor.
        self.ctx_factor = nn.Parameter(torch.tensor(1/self.degrade_factor), requires_grad=True)

        # ----- Input -----
        self.number_word_embedding = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.token_type_embedding = nn.Embedding(2, self.hidden_dim)

        # ----- Token generator -----
        # Keys are prefixed with numbers; apply_module_dict presumably applies them in this order -- TODO confirm.
        self.token = nn.ModuleDict({
            '0_attn': MultiheadAttentionWeights(hidden_dim=self.hidden_dim, num_heads=self.num_pointer_heads),
            '1_mean': Squeeze(dim=-1) if self.num_pointer_heads == 1 else AveragePooling(dim=-1),
            '2_logsoftmax': LogSoftmax(dim=-1)
        })

        # ----- Initialize weights -----
        with torch.no_grad():
            # Initialize Linear, LayerNorm, Embedding
            self.apply(self._init_weights)

    @property
    def token_vocab_size(self):
        """
        :return: ``config.token_nonum_size``, used as the number of rows of the word embedding.
        """
        return self.config.token_nonum_size

    def _build_word_embed(self, ids: torch.Tensor, nums: torch.Tensor) -> torch.Tensor:
        """
        Embed the given tokens according to their types.

        Tokens of type TOK_NUM_ID look up the linearly transformed ``nums`` of the same batch item, and all
        other tokens look up the word embedding. For each lookup, values of the other kind are replaced with
        PAD_ID before calling ``get_embedding_without_pad``.

        :param torch.Tensor ids: Token ids whose last dimension holds the type on 0 and the value on 1.
        :param torch.Tensor nums: Number representations, indexed per batch item.
        :return: Type embedding plus the token embedding scaled by ``ctx_factor``.
        """
        token_types = ids.select(dim=2, index=0)
        token_values = ids.select(dim=2, index=1)
        is_number = token_types == TOK_NUM_ID

        projected_nums = self.number_word_embedding(nums)
        type_vec = get_embedding_without_pad(self.token_type_embedding, token_types)

        # Non-number tokens: numbers are hidden by PAD_ID.
        token_vec = get_embedding_without_pad(self.word_embedding, token_values.masked_fill(is_number, PAD_ID))

        # Number tokens: everything but numbers is hidden by PAD_ID.
        number_index = token_values.masked_fill(token_types != TOK_NUM_ID, PAD_ID)
        per_item = [get_embedding_without_pad(projected_nums[b], number_index[b])
                    for b in range(token_values.shape[0])]
        token_vec += torch.stack(per_item, dim=0).contiguous()

        return type_vec + token_vec * self.ctx_factor

    def _build_attention_keys(self, num: torch.Tensor, num_pad: torch.Tensor = None) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build keys and the ignorance mask for pointing.

        :param torch.Tensor num: Number representations. Shape [B, N, H].
        :param torch.Tensor num_pad: Optional mask of numbers to ignore. Shape [B, N].
        :return: Pair of key [B, V+N, H] and key ignorance mask [B, V+N], where V = ``token_vocab_size``.
        """
        batch_sz = num.shape[0]
        vocab_sz = self.token_vocab_size

        # Order: Token, Number
        # Token keys: [V, H] -> [1, V, H] -> [B, V, H]
        token_key = self.word_embedding.weight.unsqueeze(0).expand(batch_sz, vocab_sz, self.hidden_dim)

        # Key: [B, V+N, H]
        key = torch.cat([token_key, num], dim=1).contiguous()

        # Key ignorance mask: [B, V+N]. Tokens are never ignored.
        ignore = torch.zeros(key.shape[:2], dtype=torch.bool, device=key.device)
        if num_pad is not None:
            ignore[:, vocab_sz:] = num_pad

        return key, ignore

    def _build_attention_target(self, equation: torch.Tensor) -> torch.Tensor:
        """
        Convert (type, value) pairs of tokens into the index of the key to be pointed.

        The value is clamped at 0 and an offset is added: ``token_vocab_size`` for TOK_NUM_ID, and the type id
        itself for every other type.

        :param torch.Tensor equation: Tensor whose last dimension holds the type on 0 and the value on 1.
        :return: Target indices, with the shape of ``equation`` without its last dimension.
        """
        token_types = equation.select(dim=-1, index=0)
        # clamp_min creates a new tensor, so the in-place addition below does not modify the input.
        target = equation.select(dim=-1, index=1).clamp_min(0)

        offset = token_types.masked_fill(token_types == TOK_NUM_ID, self.token_vocab_size)
        target += offset
        return target

    def _forward_single(self, text: torch.Tensor = None, text_pad: torch.Tensor = None, text_num: torch.Tensor = None,
                        text_numpad: torch.Tensor = None, equation: torch.Tensor = None,
                        **kwargs) -> Dict[str, torch.Tensor]:
        """
        Predict the next token of every step by pointing.

        :return: Dictionary with the single entry 'token', the output of the modules in ``self.token``.
        """
        result = super()._forward_single(text, text_pad, text_num, equation)
        hidden = result.pop('_out')

        # Keys use the given number representations as they are.
        key, key_ign_msk = self._build_attention_keys(text_num, text_numpad)
        result['token'] = apply_module_dict(self.token, encoded=hidden, key=key, key_ignorance_mask=key_ign_msk)

        return result

    def _build_target_dict(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Build training targets.

        :return: Dictionary with the single entry 'token', built from the equation given under IN_EQN.
        """
        equation = kwargs[IN_EQN]
        return {'token': self._build_attention_target(equation)}


# Public API of this module: only the concrete decoder models are exported.
__all__ = [
    'OperationPointerAlbertTransformer',
    'OperationAlbertTransformer',
    'PointerAlbertTransformer',
    'BaselineAlbertTransformer',
]
