import math
from typing import Union
from pathlib import Path

import torch
from torch import nn

from page.const import PAD_ID
from .util import get_embedding_without_pad


class PositionalEncoding(nn.Module):
    """
    Positional encoding that extends trigonometric embedding proposed in 'Attention is all you need'

    Every output entry is computed as

    .. math::
        P_{t, p} = c_p * \\cos(a_p * t + b_p) + d_p * \\sin(a_p * t + b_p),

    where t is a position index and p runs over the embedding dimensions.
    In the learnable version a, b, c and d are trainable; in the fixed version they are constant buffers.
    """

    def __init__(self, embedding_dim, fixed_embedding: bool = False):
        """
        Instantiate positional encoding instance.

        :param int embedding_dim:
            Dimension of embedding vector
        :param bool=False fixed_embedding:
            Set `True` if this should use fixed embedding version of 'Attention is all you need' paper.
        """

        super().__init__()
        #: Dimension of embedding vector
        self.embedding_dim = embedding_dim
        #: Flag whether this positional encoding is fixed or not.
        self.fixed_embedding = fixed_embedding

        # The output will be c_p * cos(a_p * t + b_p) + d_p * sin(a_p * t + b_p), where t=index and p = 1...embed_dim
        if not fixed_embedding:
            # Here, a, b, c, d are learnable parameters.
            # a_p is the weight and b_p is the bias of a Linear layer that maps a scalar index to E phases.
            self.before_trigonometric_linear = nn.Linear(1, embedding_dim)
            # Row 0 holds c_p (cosine multipliers) and row 1 holds d_p (sine multipliers).
            self.multiplier = nn.Parameter(torch.normal(0, 0.2, size=(2, embedding_dim)), requires_grad=True)
        else:
            # From "Attention is all you need" paper.
            # Here, b_p = 0 and a_2p = a_{2p+1} = 1 / 10000^{2p/embed_dim}.
            # Thus, we need to define a_p only.
            # Exponent 2 * floor(p / 2): each (even, odd) pair of dimensions shares one frequency.
            div_term = (torch.arange(0, embedding_dim) // 2) * 2
            div_term = torch.exp(div_term.float() * (-math.log(10000.0) / embedding_dim))
            # Note: c_p = 1 if p is odd, 0 otherwise and d_p = 1 if p is even, 0 otherwise
            multiplier = torch.zeros(2, embedding_dim, dtype=torch.float)
            multiplier[0, 1::2] = 1.0  # Only use cosine for odd indices
            multiplier[1, 0::2] = 1.0  # Only use sine for even indices

            # Fix a_p, c_p, d_p values. Buffers follow the module across devices but are not trained.
            self.register_buffer('_div_term', div_term)
            self.register_buffer('multiplier', multiplier)

    @property
    def device(self) -> torch.device:
        """
        Get the device where weights are currently put.

        :rtype: torch.device
        :return: Device of the frequency buffer (fixed version) or of the linear layer weight (learnable version).
        """
        if self.fixed_embedding:
            return self._div_term.device
        return self.before_trigonometric_linear.weight.device

    def before_trigonometric(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Compute a_p * t + b_p for each index t.

        :param torch.Tensor indices: Tensor of position indices whose last dimension has size 1. Shape [*, 1].
        :rtype: torch.Tensor
        :return: Float tensor whose values are a_p * t + b_p for each (t, p) entry. Shape [*, E].
        """
        indices = indices.float()

        if self.fixed_embedding:
            # If embedding is fixed, only compute a_p * t; [*, 1] broadcasts against [E].
            return indices * self._div_term

        # Otherwise, use linear layer to compute a_p * t + b_p.
        return self.before_trigonometric_linear(indices)

    def forward(self, index_or_range: Union[torch.Tensor, int, range], ignored_index=PAD_ID) -> torch.Tensor:
        """
        Compute positional encoding. If this encoding is not learnable, the result cannot have any gradient vector.

        .. math::
            P_{t, p} = c_p * \\cos(a_p * t + b_p) + d_p * \\sin(a_p * t + b_p).

        :param Union[torch.Tensor,int,range] index_or_range:
            Value that represents positional encodings to be built.
            - A Tensor value indicates indices itself.
            - An integer value n indicates indices 0, 1, ..., n - 1.
            - A range value indicates indices within the range.
        :param int ignored_index: The index to be ignored. `PAD_ID` by default.
        :rtype: torch.Tensor
        :return:
            Positional encoding of given value. Positions equal to `ignored_index` are filled with zeros.
            - If torch.Tensor of shape [*] is given, this will have shape [*, E].
            - If integer or range is given, this will have shape [T, E], where T is the length of range.
        """
        if self.fixed_embedding:
            # If this is a fixed embedding, we don't need to compute gradients.
            with torch.no_grad():
                return self._forward(index_or_range, ignored_index)

        return self._forward(index_or_range, ignored_index)

    def _forward(self, index_or_range: Union[torch.Tensor, int, range], ignored_index=PAD_ID) -> torch.Tensor:
        """
        Compute positional encoding

        .. math::
            P_{t, p} = c_p * \\cos(a_p * t + b_p) + d_p * \\sin(a_p * t + b_p).

        :param Union[torch.Tensor,int,range] index_or_range:
            Value that represents positional encodings to be built.
            - A Tensor value indicates indices itself.
            - An integer value n indicates indices 0, 1, ..., n - 1.
            - A range value indicates indices within the range.
        :param int ignored_index: The index to be ignored. `PAD_ID` by default.
        :rtype: torch.Tensor
        :return:
            Positional encoding of given value. Positions equal to `ignored_index` are filled with zeros.
            - If torch.Tensor of shape [*] is given, this will have shape [*, E].
            - If integer or range is given, this will have shape [T, E], where T is the length of range.
        """
        device = self.device

        # Build the indices directly on the device of this module to avoid an extra host-to-device copy.
        if isinstance(index_or_range, int):
            # Build Long Tensor of [0, ..., index-1]
            indices = torch.arange(0, index_or_range, device=device)
        elif isinstance(index_or_range, range):
            # Build a Tensor holding every value of the range
            indices = torch.as_tensor(list(index_or_range), device=device)
        else:
            # Already a tensor of indices: only make sure that it lives on the right device.
            indices = index_or_range.to(device)

        # Unsqueeze the last dimension to pass the linear layer. Shape [*, 1].
        indices = indices.unsqueeze(-1)

        # Apply a_p * t + b_p. Phase has shape [*, E].
        phase = self.before_trigonometric(indices)

        # Compute c_p * cos(phase) + d_p * sin(phase).
        # c_p and d_p have shape [E], which broadcasts over the leading dimensions of [*, E].
        result = phase.cos() * self.multiplier[0] + phase.sin() * self.multiplier[1]

        # Fill ignored indices as zero. The [*, 1] mask broadcasts over the embedding dimension.
        # This is applied unconditionally: an all-False mask leaves the values untouched, and skipping
        # the `any()` test avoids a device-to-host synchronization on every call.
        result.masked_fill_(indices == ignored_index, 0.0)

        # Return value. Shape [*, E]
        return result.contiguous()


class EquationEmbedding(nn.Module):
    """
    Embedding for encoding current state of given equation.

    The output is LayerNorm(word embedding + positional encoding / sqrt(embedding_dim)).
    """

    def __init__(self, fixed_embedding=True, **config):
        """
        Instantiate embedding instance.

        :param bool fixed_embedding:
            Passed to :class:`PositionalEncoding`; `True` selects the fixed (non-learnable) encoding.
            `True` by default.
        :keyword int embedding_dim:
            Dimension of embedding vector. 128 by default.
        :keyword int token_vocab_size:
            Size of vocabulary for representing an equation. 100 by default.
        :keyword float layernorm_eps:
            Epsilon to avoid zero-division in LayerNorm. 1E-12 by default.
        :keyword float init_factor:
            Standard deviation of normal distribution that will be used for initializing weights. 0.02 by default
        """

        super().__init__()
        self.config = config

        # Word embedding
        self.word_embedding = nn.Embedding(self.equation_vocab, self.embedding_dim)
        # Positional encoding
        self.pos_embedding = PositionalEncoding(self.embedding_dim, fixed_embedding=fixed_embedding)
        # LayerNorm applied on the sum of word embedding and scaled positional encoding.
        self.word_norm = nn.LayerNorm(self.embedding_dim, eps=self.layernorm_eps)
        # Positional encodings are divided by this factor (sqrt(E)) before being added to word embeddings.
        self.position_upweight_factor = math.sqrt(self.embedding_dim)

        # Initialize weights
        with torch.no_grad():
            self.apply(self._init_weights)

    def forward(self, token: torch.Tensor, ignored_index=PAD_ID) -> torch.Tensor:
        """
        Convert tokens to embedding vector.

        :param torch.Tensor token:
            A Long Tensor that represents token indices.
            Shape [B, T] or [B, M, T], where B = batch size, M = beam size, T = token sequence length.
        :param int ignored_index: The index to be ignored. `PAD_ID` by default.
        :rtype: torch.Tensor
        :return:
            A Float Tensor of embedding with positional encoding. Shape [B, T, E] or [B, M, T, E].
        """
        seq_len = token.shape[-1]
        # Target shape of the positional term: [1, T, E] or [1, 1, T, E], broadcastable over batch (and beam).
        shape = (1,) * (token.dim() - 1) + (seq_len, self.embedding_dim)

        # equation_tokens: [B, T] or [B, M, T] --> token_embedding: [B, T, E] or [B, M, T, E]
        # NOTE(review): `get_embedding_without_pad` lives in `.util`; presumably it skips PAD entries - confirm.
        token_embedding = get_embedding_without_pad(self.word_embedding, token)
        # token_positions: [T, E] --> [1, T, E] or [1, 1, T, E]
        token_positions = self.pos_embedding(seq_len, ignored_index=ignored_index).view(*shape)
        # Build token embedding with positions
        token_embedding = token_embedding + token_positions / self.position_upweight_factor

        # Return [B, T, E] or [B, M, T, E].
        return self.word_norm(token_embedding)

    @property
    def embedding_dim(self) -> int:
        """
        :rtype: int
        :return: Dimension of embedding vector
        """
        return self.config.get('embedding_dim', 128)

    @property
    def equation_vocab(self) -> int:
        """
        :rtype: int
        :return: Size of vocabulary for representing an equation.
        """
        return self.config.get('token_vocab_size', 100)

    @property
    def layernorm_eps(self) -> float:
        """
        :rtype: float
        :return: Epsilon to avoid zero-division in LayerNorm.
        """
        return self.config.get('layernorm_eps', 1E-12)

    @property
    def init_factor(self) -> float:
        """
        :rtype: float
        :return: Standard deviation of normal distribution that will be used for initializing weights.
        """
        return self.config.get('init_factor', 0.02)

    def _init_weights(self, module: nn.Module):
        """
        Initialize weights

        :param nn.Module module: Module to be initialized.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # nn.Linear has 'weight' and 'bias', nn.Embedding has 'weight',
            for name, param in module.named_parameters():
                if 'weight' in name:
                    param.data.normal_(mean=0.0, std=self.init_factor)
                else:
                    param.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Initialize layer normalization as an identity function.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def save_pretrained(self, save_directory: str):
        """
        Save current state of Equation Embedding.

        The checkpoint is written to `<save_directory>/embedding.pt` and contains the configuration
        (including the `fixed_embedding` flag) and the state dictionary.

        :param str save_directory: String that represents path to the directory where this will be saved.
        """
        # `fixed_embedding` is a named constructor argument, so it is not part of `self.config`.
        # Store it in the checkpoint; otherwise a learnable positional encoding would be rebuilt as a
        # fixed one on load and its state dictionary would no longer match.
        config = dict(self.config)
        config['fixed_embedding'] = self.pos_embedding.fixed_embedding

        # Write state dictionary
        torch.save({
            'config': config,
            'state': self.state_dict()
        }, Path(save_directory, 'embedding.pt'))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Build an Equation Embedding instance from saved checkpoint.

        If `<pretrained_model_name_or_path>/embedding.pt` exists, the configuration stored in it overrides
        the given keyword arguments and the stored weights are loaded. Otherwise, a fresh instance is
        created from the keyword arguments only.

        :param str pretrained_model_name_or_path: Path to the directory that may contain `embedding.pt`.
        :keyword int embedding_dim:
            Dimension of embedding vector. 128 by default.
        :keyword int token_vocab_size:
            Size of vocabulary for representing an equation. 100 by default.
        :keyword float layernorm_eps:
            Epsilon to avoid zero-division in LayerNorm. 1E-12 by default.
        :keyword float init_factor:
            Standard deviation of normal distribution that will be used for initializing weights. 0.02 by default
        :return: The constructed instance.
        """

        # Import the model if available, otherwise create it using keyword argument
        path = Path(pretrained_model_name_or_path, 'embedding.pt')
        if not path.exists():
            return cls(**kwargs)

        # Map tensors to CPU so that a checkpoint saved on GPU can be restored on a CPU-only machine.
        # The model is constructed on CPU anyway, and `load_state_dict` copies values into its tensors.
        checkpoint = torch.load(path, map_location='cpu')

        # Reconstruct model. Stored configuration takes precedence over keyword arguments.
        kwargs.update(checkpoint['config'])
        model = cls(**kwargs)
        model.load_state_dict(checkpoint['state'])

        return model


# Public names of this module: the only ones exported by a star-import.
__all__ = ['EquationEmbedding', 'PositionalEncoding']
