# -*- coding: utf-8 -*-
import tensorflow as tf
from UNIVERSAL.basic_layer import attention_layer, ffn_layer, layerNormalization_layer
from UNIVERSAL.utils import staticEmbedding_util

# Integer identifiers for the source / target side (presumably language ids;
# they are not referenced anywhere in this file — confirm against callers).
SRC_LANG, TGT_LANG = 1, 2


class TransformerEncoderBLOCK(tf.keras.layers.Layer):
    """
        Naive TransformerEncoderBLOCK implementation including:
        1. 1-layer self attention
        2. 1-layer FFN

        Each sub-layer is wrapped in ``layerNormalization_layer.NormBlock``;
        ``preNorm`` is forwarded to it as ``pre_mode``.

        Args:
            num_units: hidden size of the block; the FFN inner size is
                ``4 * num_units``.
            num_heads: number of attention heads.
            dropout: dropout rate forwarded to the attention layer, the FFN and
                both NormBlocks.
            preNorm: forwarded to NormBlock as ``pre_mode``.
            epsilon: forwarded to NormBlock as ``epsilon``.
            name: Keras layer name (also kept as ``self.n``).
            ffn_activation: forwarded to the FFN as ``activation_filter``.
    """

    def __init__(
        self,
        num_units=512,
        num_heads=8,
        dropout=0.1,
        preNorm=True,
        epsilon=1e-9,
        name="TransformerEncoderBLOCK",
        ffn_activation="relu",
    ):
        # Initialise the Keras base layer before assigning attributes (as the
        # decoder block does) so Layer attribute tracking is set up first.
        super(TransformerEncoderBLOCK, self).__init__(name=name)
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout = dropout
        self.n = name
        self.preNorm = preNorm
        self.ffn_activation = ffn_activation
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the self-attention and FFN sub-layers, each wrapped in a NormBlock."""
        self_attention = attention_layer.Attention(
            num_heads=self.num_heads, num_units=self.num_units, dropout=self.dropout,
        )

        ffn = ffn_layer.Feed_Forward_Network(
            num_units=4 * self.num_units, activation_filter=self.ffn_activation, dropout=self.dropout
        )
        self.self_att = layerNormalization_layer.NormBlock(
            self_attention, self.dropout, pre_mode=self.preNorm, epsilon=self.epsilon
        )
        self.ffn = layerNormalization_layer.NormBlock(ffn, self.dropout, pre_mode=self.preNorm, epsilon=self.epsilon)
        # NOTE: final_norm is created in pre-norm mode but call() does not
        # apply it. It is kept so the layer's tracked sub-layers stay unchanged.
        if self.preNorm:
            self.final_norm = layerNormalization_layer.LayerNorm()
        super(TransformerEncoderBLOCK, self).build(input_shape)

    def call(self, inputs, attention_bias=0, training=False, index=None, scale=None, **kwargs):
        """Apply self-attention then the FFN, both through their NormBlocks.

        Args:
            inputs: input tensor (assumed (batch, length, num_units) — TODO confirm).
            attention_bias: bias passed to self-attention; the same value is
                also passed to the FFN block as ``padding_position``.
            training: forwarded to both sub-blocks.
            index: unused here; accepted for interface compatibility.
            scale: forwarded to the self-attention block.
            **kwargs: ignored.

        Returns:
            The output of the FFN NormBlock.
        """
        inputs = self.self_att(inputs, inputs, attention_bias, training=training, scale=scale)
        inputs = self.ffn(inputs, training=training, padding_position=attention_bias)
        # NOTE: self.final_norm is intentionally not applied here.
        return inputs

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild this block.

        Only constructor arguments are returned (not the base Layer config)
        because ``__init__`` accepts no other keyword arguments.
        """
        return {
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "dropout": self.dropout,
            "preNorm": self.preNorm,
            "epsilon": self.epsilon,
            "name": self.n,
            "ffn_activation": self.ffn_activation,
        }


class TransformerDecoderBLOCK(tf.keras.layers.Layer):
    """
        Transformer decoder block including:
        1. 1-layer self attention
        2. 1-layer attention over ``enc`` (encoder output)
        3. 1-layer FFN

        Each sub-layer is wrapped in ``layerNormalization_layer.NormBlock``;
        ``preNorm`` is forwarded to it as ``pre_mode``.

        Args:
            num_units: hidden size of the block; the FFN inner size is
                ``4 * num_units``.
            num_heads: number of attention heads.
            dropout: dropout rate forwarded to both attention layers, the FFN
                and every NormBlock.
            preNorm: forwarded to NormBlock as ``pre_mode``.
            epsilon: forwarded to NormBlock as ``epsilon``.
            name: Keras layer name (also kept as ``self.n`` and used as the
                ``tf.name_scope`` in ``call``).
            ffn_activation: forwarded to the FFN as ``activation_filter``.
    """

    def __init__(
        self,
        num_units=512,
        num_heads=8,
        dropout=0.1,
        preNorm=True,
        epsilon=1e-9,
        name="TransformerDecoderBLOCK",
        ffn_activation="relu",
    ):
        super(TransformerDecoderBLOCK, self).__init__(name=name)
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout = dropout
        # NOTE: nothing in this class writes to this dict, so it stays empty
        # unless populated externally.
        self.attention_weights = dict()
        self.n = name
        self.ffn_activation = ffn_activation
        self.preNorm = preNorm
        self.epsilon = epsilon

    def get_attention_weights(self):
        """Return the ``attention_weights`` dict (empty unless filled elsewhere)."""
        return self.attention_weights

    def build(self, input_shape):
        """Create self-attention, encoder attention and FFN sub-layers, each in a NormBlock."""
        self_attention = attention_layer.Attention(
            num_heads=self.num_heads, num_units=self.num_units, dropout=self.dropout,
        )

        attention = attention_layer.Attention(num_heads=self.num_heads, num_units=self.num_units, dropout=self.dropout,)
        ffn = ffn_layer.Feed_Forward_Network(
            num_units=4 * self.num_units, activation_filter=self.ffn_activation, dropout=self.dropout
        )
        self.self_att = layerNormalization_layer.NormBlock(
            self_attention, self.dropout, pre_mode=self.preNorm, epsilon=self.epsilon
        )
        self.att = layerNormalization_layer.NormBlock(
            attention, self.dropout, pre_mode=self.preNorm, epsilon=self.epsilon
        )
        self.ffn = layerNormalization_layer.NormBlock(ffn, self.dropout, pre_mode=self.preNorm, epsilon=self.epsilon)
        # NOTE: unlike the encoder block, no final LayerNorm is created here.
        super(TransformerDecoderBLOCK, self).build(input_shape)

    def call(
        self,
        inputs,
        enc,
        decoder_self_attention_bias,
        attention_bias,
        training=False,
        cache=None,
        decoder_padding=None,
        index=None,
        scale=None,
        **kwargs
    ):
        """Run self-attention, attention over ``enc``, then the FFN.

        Args:
            inputs: decoder input tensor (assumed (batch, length, num_units) —
                TODO confirm).
            enc: tensor attended to by the second attention block (the encoder
                output, presumably).
            decoder_self_attention_bias: bias for the self-attention block.
            attention_bias: bias for the attention over ``enc``; also passed
                to the FFN block as ``padding_position``.
            training: forwarded to all three sub-blocks.
            cache: forwarded to the self-attention block only (presumably for
                incremental decoding — confirm in attention_layer).
            decoder_padding: unused here; accepted for interface compatibility.
            index: unused here; accepted for interface compatibility.
            scale: forwarded to both attention blocks.
            **kwargs: ignored.

        Returns:
            The output of the FFN NormBlock.
        """
        with tf.name_scope(self.n):
            inputs = self.self_att(
                inputs, inputs, decoder_self_attention_bias, training=training, cache=cache, scale=scale
            )
            inputs = self.att(inputs, enc, attention_bias, training=training, scale=scale)
            inputs = self.ffn(inputs, training=training, padding_position=attention_bias)
            # NOTE: per-block attention weights are not collected into
            # self.attention_weights, and no final norm is applied.
            return inputs

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild this block.

        Only constructor arguments are returned (not the base Layer config)
        because ``__init__`` accepts no other keyword arguments.
        """
        return {
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "dropout": self.dropout,
            "preNorm": self.preNorm,
            "epsilon": self.epsilon,
            "name": self.n,
            "ffn_activation": self.ffn_activation,
        }
