# -*- coding: utf-8 -*-
import tensorflow as tf


class LayerNorm(tf.keras.layers.Layer):
    """Normalization over the last axis, with an optional tanh-estimator mode.

    mode:
        "tanh": computes
                0.5 * (tanh(0.01 * (x - μ) / (σ**2 + ϵ)) + 1)
            over the last axis, multiplies the result by a mask that is 0
            wherever the input is exactly 0, and then applies the learned
            scale/shift: ``gamma * normalized + beta``.
        any other value (default "linear"): delegates to
            ``tf.keras.layers.LayerNormalization``, i.e.
                ln(x) = α * (x - μ) / (σ**2 + ϵ)**0.5 + β

    Args:
        epsilon: Small constant added to the variance for numerical
            stability.
        gamma_initializer: Initializer (name, config or instance) for the
            scale weight.
        beta_initializer: Initializer (name, config or instance) for the
            shift weight.
        mode: "tanh" selects the tanh estimator; anything else selects
            standard layer normalization.
        name: Name of the layer.
    """
    def __init__(
        self,
        epsilon=1e-9,
        gamma_initializer="ones",
        beta_initializer="zeros",
        mode="linear",
        name="norm",
    ):
        super(LayerNorm, self).__init__(name=name)
        self.epsilon = epsilon
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.mode = mode

    def build(self, input_shape):
        """Create the weights (tanh mode) or the wrapped Keras layer."""
        input_dim = input_shape[-1]
        if self.mode == "tanh":
            # NOTE: the shape must be a 1-tuple; `(input_dim)` is just an int.
            self.gamma_kernel = self.add_weight(
                shape=(input_dim,),
                name="gamma",
                initializer=self.gamma_initializer)
            self.beta_kernel = self.add_weight(
                shape=(input_dim,),
                name="beta",
                initializer=self.beta_initializer)
        else:
            # Forward the configured initializers so they are not silently
            # ignored; the defaults ("ones"/"zeros") match Keras' own.
            self.norm_layer = tf.keras.layers.LayerNormalization(
                epsilon=self.epsilon,
                gamma_initializer=self.gamma_initializer,
                beta_initializer=self.beta_initializer)
        super(LayerNorm, self).build(input_shape)

    def call(self, inputs):
        """Normalize `inputs` along the last axis according to `self.mode`."""
        if self.mode != "tanh":
            return self.norm_layer(inputs)

        # 1.0 where the input is non-zero, 0.0 where it is exactly zero.
        nonzero_mask = tf.cast(tf.not_equal(inputs, 0.0), tf.float32)
        mean, variance = tf.nn.moments(inputs, [-1], keepdims=True)
        # NOTE(review): this divides by the variance, not the standard
        # deviation used by the usual tanh-estimator formula. Kept as-is so
        # existing trained weights keep their behavior -- confirm intent.
        tanh_estimator = tf.cast(
            0.01 * ((inputs - mean) / (variance + self.epsilon)),
            tf.float32)
        normalized = 0.5 * (tf.nn.tanh(tanh_estimator) + 1.0) * nonzero_mask
        return self.gamma_kernel * normalized + self.beta_kernel

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        Only keys accepted by `__init__` are returned, so
        `LayerNorm(**layer.get_config())` recreates an equivalent layer.
        """
        return {
            "epsilon": self.epsilon,
            "gamma_initializer": tf.keras.initializers.serialize(
                self.gamma_initializer),
            "beta_initializer": tf.keras.initializers.serialize(
                self.beta_initializer),
            "mode": self.mode,
            "name": self.name,
        }


class NormBlock(tf.keras.layers.Layer):
    """Wrapper that adds layer normalization, dropout and a residual
    connection around another layer.

        pre-processing  (pre_mode=True):  x + dropout(layer(ln(x)))
        post-processing (pre_mode=False): ln(x + dropout(layer(x)))

    Dropout is only applied when a truthy `training` keyword argument is
    passed to `call`.

    NOTE (original author, unverified here): post-processing has a better
    performance on deep models.

    Args:
        layer: The wrapped callable/layer; it is invoked with the same
            positional and keyword arguments that `call` receives.
        dropout: Dropout rate applied to the wrapped layer's output.
        pre_mode: If True use pre-processing, otherwise post-processing.
        epsilon: Epsilon forwarded to the internal `LayerNorm`.
    """
    def __init__(self, layer, dropout, pre_mode=True, epsilon=1e-9):
        super(NormBlock, self).__init__()
        self.layer = layer
        self.dropout = dropout
        self.pre_mode = pre_mode
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the normalization layer."""
        self.layer_norm = LayerNorm(self.epsilon)
        super(NormBlock, self).build(input_shape)

    def get_config(self):
        """Return the scalar hyper-parameters of this block.

        The wrapped `layer` is not included in the returned dict.
        """
        # The original read `self.add_mode`, which is never set and raised
        # AttributeError; the attribute assigned in __init__ is `pre_mode`.
        return {
            "dropout": self.dropout,
            "pre_mode": self.pre_mode,
            "epsilon": self.epsilon,
        }

    def call(self, x, *args, **kwargs):
        """Calls wrapped layer with same parameters.

        Args:
            x: Input tensor. In pre-processing mode the wrapped layer's
                output must be rank 3, because the dropout noise shape is
                built from its first and third dimensions.
            *args: Extra positional arguments forwarded to the wrapped layer.
            **kwargs: Extra keyword arguments forwarded to the wrapped layer.
                `training` (optional, defaults to False) enables dropout.

        Returns:
            The output tensor after normalization, the wrapped layer,
            optional dropout and the residual connection.
        """
        # Treat a missing `training` flag as inference instead of raising
        # KeyError.
        training = kwargs.get("training", False)

        if self.pre_mode:
            # Preprocessing: normalize first, then run the wrapped layer.
            y = self.layer(self.layer_norm(x), *args, **kwargs)
            if training:
                # noise_shape has size 1 on axis 1, so one dropout mask is
                # shared along that axis.
                y_shape = tf.shape(y)
                y = tf.nn.dropout(y,
                                  self.dropout,
                                  noise_shape=[y_shape[0], 1, y_shape[2]])
            # Residual connection.
            return y + x

        # Postprocessing: run the wrapped layer, then dropout, residual
        # connection and finally normalization.
        y = self.layer(x, *args, **kwargs)
        if training:
            y = tf.nn.dropout(y, rate=self.dropout)
        return self.layer_norm(y + x)
