from typing import Union

import torch
import math

from torch import nn
from pytorch_transformers.modeling_bert import BertLayerNorm

class prefix_shuffle(nn.Module):
    """Randomly reorder a subset of prefix vectors among themselves (training only).

    Every vector of the input (one per index over all but the last dimension) is
    picked independently with probability ``shuffle_prob``. The picked vectors
    are then written back in a random order. The input tensor is modified in
    place and also returned.

    Args:
        shuffle_prob: probability that a given vector takes part in the shuffle.
    """

    def __init__(self, shuffle_prob: float = 0.0):
        super().__init__()
        self.shuffle_prob = shuffle_prob
        self.keep_prob = 1 - shuffle_prob

    def forward(self, prefixs):
        # Identity in eval mode or when shuffling is disabled.
        if not self.training or self.shuffle_prob == 0.0:
            return prefixs

        # One Bernoulli draw per vector: True means "take part in the shuffle".
        selected = torch.zeros(*prefixs.shape[:-1], dtype=torch.bool, device=prefixs.device)
        selected.bernoulli_(self.shuffle_prob)

        # Boolean indexing flattens the selection into (num_selected, last_dim).
        chosen = prefixs[selected]
        order = torch.randperm(chosen.shape[-2])
        prefixs[selected] = chosen[..., order, :]
        return prefixs

class prefix_drop(nn.Module):
    """Randomly overwrite individual prefix vectors with a fill value (training only).

    Every vector of the input (one per index over all but the last dimension) is
    replaced independently with probability ``drop_prob``. The input tensor is
    modified in place and also returned.

    Args:
        drop_prob: probability that a given vector is replaced.
        fill_values: the replacement. Either a Python scalar (e.g. ``0.0``) or a
            tensor such as a ``[PAD]`` embedding. A tensor with at least one
            dimension must have the same last-dimension size as the input.
            A 0-dim tensor is treated like a scalar.
    """

    def __init__(
            self,
            drop_prob: float = 0.0,
            fill_values: Union[torch.Tensor, float] = 0.0
    ):
        super(prefix_drop, self).__init__()
        self.drop_prob = drop_prob
        self.keep_prob = 1 - self.drop_prob
        self.fill_values = fill_values

    def forward(self, prefixs):
        """Apply the random drop to ``prefixs``.

        Returns:
            ``prefixs`` itself, with the dropped vectors overwritten in place.

        Raises:
            TypeError: ``fill_values`` is neither a tensor nor an int/float.
            ValueError: a tensor ``fill_values`` has a last dimension that
                differs from ``prefixs.shape[-1]``.
        """
        if not self.training or self.drop_prob == 0.0:
            return prefixs

        fill = self._resolve_fill(prefixs)

        # One Bernoulli draw per vector: True means "replace this vector".
        mask = torch.zeros(*prefixs.shape[:-1], dtype=torch.bool, device=prefixs.device).bernoulli_(self.drop_prob)
        prefixs[mask] = fill

        return prefixs

    def _resolve_fill(self, prefixs):
        """Validate ``fill_values`` and put a tensor fill on the input's device/dtype."""
        fill = self.fill_values
        if isinstance(fill, torch.Tensor):
            # Compare sizes, not values: the previous check compared an int with
            # the tensor itself, which raises for any multi-element tensor.
            if fill.dim() > 0 and fill.shape[-1] != prefixs.shape[-1]:
                raise ValueError(
                    f"fill_values has last dimension {fill.shape[-1]}, "
                    f"expected {prefixs.shape[-1]} to match the prefix vectors"
                )
            return fill.to(device=prefixs.device, dtype=prefixs.dtype)
        if isinstance(fill, (int, float)):
            return fill
        raise TypeError(
            f"fill_values must be a torch.Tensor, int or float, got {type(fill).__name__}"
        )

class prefix_Embeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    When ``config.add_prefix`` is truthy, positions whose input id equals
    ``config.special_tokens["[prefix]"]`` are treated as prefix slots:
      * their word embedding may be replaced by learned prefix vectors
        (optionally passed through an MLP + LayerNorm, then randomly dropped /
        shuffled during training);
      * their token-type embedding is replaced by a learned prefix type vector;
      * their position embedding may be zeroed (``prefix_no_pos_emb='no_pos'``)
        or set to an average position embedding (``'avg_pos'``).
    """

    def __init__(self, config):
        super(prefix_Embeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.prefix_ID = None
        self.freeze_prefix = False
        self.prefix_embedding = None

        # Position-embedding treatment for prefix slots: 'no_pos', 'avg_pos' or None.
        self.prefix_no_pos_emb = None

        # Optional MLP over the prefix vectors; stays None unless enabled, so
        # other methods can test for it safely.
        self.mlp_for_prefix = False
        self.prefix_mlp = None

        if getattr(config, "add_prefix", False):
            self.freeze_prefix = getattr(config, "freeze_prefix", False)
            self.prefix_ID = config.special_tokens["[prefix]"]
            self.prefix_no_pos_emb = getattr(config, "prefix_no_pos_emb", None)

            # .squeeze() drops the leading dim when num_prefix == 1, leaving (hidden_size,).
            self.prefix_embedding = nn.Parameter(torch.FloatTensor(config.num_prefix, config.hidden_size).squeeze())
            nn.init.normal_(self.prefix_embedding, mean=0, std=math.sqrt(1 / config.hidden_size))

            self.prefix_type_embedding = nn.Parameter(torch.FloatTensor(1, config.hidden_size).squeeze())
            nn.init.normal_(self.prefix_type_embedding, mean=0, std=math.sqrt(1 / config.hidden_size))

            if getattr(config, "mlp_for_prefix", False):
                self.mlp_for_prefix = config.mlp_for_prefix

                self.prefix_mlp = nn.Sequential(
                    nn.Linear(config.hidden_size, int(config.hidden_size / 2)),
                    nn.GELU(),
                    nn.Linear(int(config.hidden_size / 2), config.hidden_size),
                )

            self.prefix_drop_prob = getattr(config, "prefix_drop_prob", 0.0)
            self.prefix_shuffle_prob = getattr(config, "prefix_shuffle_prob", 0.0)

            # Use the defaulted values above so configs lacking these fields still work.
            self.prefix_drop = prefix_drop(self.prefix_drop_prob)
            self.prefix_shuffle = prefix_shuffle(self.prefix_shuffle_prob)
            self.prefix_layernorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def freezing_prefix(self):
        """Disable gradient updates for the prefix parameters when ``freeze_prefix`` is set.

        Does nothing when ``freeze_prefix`` is falsy (always the case when
        prefixes are disabled).
        """
        if not self.freeze_prefix:
            return

        # nn.Parameter has no .eval(); turning off requires_grad is the freeze.
        self.prefix_embedding.requires_grad_(False)

        if self.prefix_mlp is not None:
            self.prefix_mlp.requires_grad_(False)
            # A later .train() on this module or a parent switches this back to training mode.
            self.prefix_mlp.eval()

        # `prefix_dropout` is not created in __init__; freeze it only if it was attached elsewhere.
        prefix_dropout = getattr(self, "prefix_dropout", None)
        if prefix_dropout is not None:
            prefix_dropout.requires_grad_(False)
            prefix_dropout.eval()

        # NOTE(review): prefix_type_embedding and prefix_layernorm stay trainable
        # here -- confirm whether they should be frozen too.

    def _prefix_vectors(self, batch_size, dtype):
        """Tile the learned prefix vectors over the batch and post-process them.

        The vectors pass through the optional MLP + LayerNorm, then random drop
        and shuffle. ``repeat`` copies the parameter, so the in-place drop and
        shuffle never touch ``prefix_embedding`` itself.
        """
        # Follow the embedding dtype (previously hard-coded float32) so a
        # half-precision model does not feed float32 into a half MLP.
        prefix = self.prefix_embedding.repeat(batch_size, 1).to(dtype)
        if self.mlp_for_prefix:
            prefix = self.prefix_mlp(prefix)
            prefix = self.prefix_layernorm(prefix)

        prefix = self.prefix_drop(prefix)
        prefix = self.prefix_shuffle(prefix)
        return prefix

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Return dropout(LayerNorm(word + position + token_type embeddings)).

        Args:
            input_ids: token ids; presumably (batch, seq_len), given the
                ``input_ids.size(1)`` use in ``normal_embeddings`` -- confirm.
            token_type_ids: segment ids shaped like ``input_ids``; None means all zeros.
            position_ids: position ids; None means 0..seq_len-1 for every row.
        """
        words_embeddings, position_embeddings, token_type_embeddings = \
            self.normal_embeddings(input_ids, token_type_ids, position_ids)

        if self.prefix_embedding is not None:
            prefix_indices = torch.eq(input_ids, self.prefix_ID)

            # normal_embeddings treats missing token_type_ids as all-zero
            # segments; mirror that instead of passing None to logical_and.
            segment_ids = token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)

            num_prefixs = prefix_indices.int().sum()
            num_prefixs_from_initial_input = torch.logical_and(prefix_indices, segment_ids).int().sum()

            if num_prefixs != 0:
                prefix = self._prefix_vectors(words_embeddings.size(0), words_embeddings.dtype)

                # Overwrite only when every prefix token lies in a non-zero segment.
                # Per the original comment, other prefix tokens are presumed to come
                # from a previous generation step -- confirm with callers.
                if num_prefixs == num_prefixs_from_initial_input:
                    # NOTE(review): assumes every row holds exactly num_prefix prefix
                    # tokens, so `prefix` rows line up with the masked positions.
                    words_embeddings[prefix_indices] = prefix

            # Position embedding of prefix slots.
            if self.prefix_no_pos_emb == 'no_pos':
                position_embeddings[prefix_indices] = 0
            elif self.prefix_no_pos_emb == 'avg_pos':
                # Mean over the sequence dimension, taken from the first batch row.
                position_embeddings[prefix_indices] = position_embeddings.mean(dim=-2)[0, :]

            token_type_embeddings[prefix_indices] = self.prefix_type_embedding

        # Sum of all embeddings
        embeddings = words_embeddings + position_embeddings + token_type_embeddings

        # Layer Norm and dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

    def normal_embeddings(self, input_ids, token_type_ids=None, position_ids=None):
        """Look up word, position and token-type embeddings without prefix handling.

        Missing ``position_ids`` default to ``arange(seq_len)`` broadcast over the
        batch. Missing ``token_type_ids`` default to zeros.

        Returns:
            (words_embeddings, position_embeddings, token_type_embeddings)
        """
        # Word embedding
        words_embeddings = self.word_embeddings(input_ids)

        # Position embedding
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        # Token type embedding
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        return words_embeddings, position_embeddings, token_type_embeddings