# -*- coding: utf-8 -*-
# code warrior: Barid
import tensorflow as tf
from UNIVERSAL.basic_layer import layerNormalization_layer
from UNIVERSAL.utils import cka


class LT(tf.keras.layers.Layer):
    """Residual wrapper that gates a sub-layer's output with a CKA score.

    For an input ``x`` the wrapped ``layer`` produces ``y``. The value
    returned by ``cka.feature_space_linear_cka(x, y)`` (with gradients
    stopped) is stored in ``self.halting_pro`` and the layer returns
    ``LayerNorm(x + (1 - halting_pro) * y)``.
    """

    def __init__(self, layer, dropout):
        """
        Args:
          layer: callable sub-layer that is wrapped; it is invoked as
            ``layer(x, *args, **kwargs)`` inside ``call``.
          dropout: float, forwarded as the second positional argument of
            ``tf.nn.dropout`` on the sub-layer output while training
            (the drop rate in TF 2.x -- NOTE(review): confirm TF version).
        """
        super(LT, self).__init__()
        self.layer = layer
        self.dropout = dropout

    def build(self, input_shape):
        """Builds the layer.

        Args:
          input_shape: shape of the input; its last dimension is recorded
            in ``self.num_units``.
        """
        self.num_units = input_shape[-1]
        self.output_norm = layerNormalization_layer.LayerNorm()
        # Placeholder until the first ``call`` computes the real score.
        self.halting_pro = 0
        super(LT, self).build(input_shape)

    def call(self, x, *args, **kwargs):
        """Applies the wrapped layer with a CKA-gated residual connection.

        Args:
          x: input tensor handed to the wrapped layer.
          *args: extra positional arguments forwarded to the wrapped layer.
          **kwargs: extra keyword arguments forwarded to the wrapped layer.
            ``training`` is read from here to decide whether dropout is
            applied; when it is absent the layer behaves as in inference.

        Returns:
          The layer-normalized tensor ``x + (1 - halting_pro) * y``.
        """
        # A missing ``training`` flag used to raise KeyError; treat it as
        # inference so the layer can also be called without the keyword.
        training = kwargs.get("training", False)
        #### self.layer is a UT block###
        y = self.layer(x, *args, **kwargs)
        # The score is computed on the sub-layer output *before* dropout and
        # is excluded from back-propagation via ``tf.stop_gradient``.
        # Presumably a similarity between x and y -- verify against
        # UNIVERSAL.utils.cka.
        self.halting_pro = tf.stop_gradient(cka.feature_space_linear_cka(x, y))
        if training:
            y = tf.nn.dropout(y, self.dropout)
        # The larger the score, the less of the sub-layer output is added
        # back onto the residual stream.
        y = self.output_norm(x + (1 - self.halting_pro) * y)
        return y
