# -*- coding: utf-8 -*-
import tensorflow as tf
from UNIVERSAL.utils import staticEmbedding_util
from UNIVERSAL.block import TransformerBlock
from UNIVERSAL.basic_layer import layerNormalization_layer, attention_layer, residual_layer, ffn_layer
import math

# import numpy as np


def ut_input_preprocess(
    inputs, step, training=False, position_index=None, step_encoding=True, position_encoding=True, **kwargs
):
    """Add the step and position timing signals to ``inputs``.

    Args:
        inputs: Value handed to the ``staticEmbedding_util`` timing-signal
            helpers; returned untouched when both encodings are disabled.
        step: Current recurrence step, forwarded to
            ``add_step_timing_signal``.
        training: Accepted for interface compatibility; not used here.
        position_index: Optional explicit position forwarded as
            ``position=``. When it is given, ``length`` is set to ``max_seq``;
            otherwise ``length`` is ``None``.
        step_encoding: Whether to add the step timing signal.
        position_encoding: Whether to add the position timing signal.
        **kwargs: May carry ``max_step`` (default 50) and ``max_seq``
            (default 1000); any other keys are ignored.

    Returns:
        ``inputs`` after the enabled timing signals have been applied.
    """
    # Observation from the original author: the model could not distinguish
    # the step encoding from the position encoding because they had the same
    # value.
    # NOTE(review): presumably this is why ``max_step + 1`` is passed to the
    # position signal below -- confirm against staticEmbedding_util.
    max_step = kwargs.get("max_step", 50)
    max_seq = kwargs.get("max_seq", 1000)
    length = None if position_index is None else max_seq

    if step_encoding:
        inputs = staticEmbedding_util.add_step_timing_signal(inputs, step, max_step)
    if not position_encoding:
        return inputs
    return staticEmbedding_util.add_position_timing_signal(
        inputs, max_step + 1, position=position_index, length=length
    )


class UniversalTransformerEncoderBLOCK(TransformerBlock.TransformerEncoderBLOCK):
    """Transformer encoder block with Universal Transformer timing signals.

    Before delegating to ``TransformerBlock.TransformerEncoderBLOCK.call`` the
    inputs are passed through ``ut_input_preprocess``, which adds the step
    and/or position timing signals.
    """

    def __init__(
        self,
        num_units=512,
        num_heads=8,
        dropout=0.1,
        preNorm=False,
        epsilon=1e-9,
        name="UniversalTransformerEncoderBLOCK",
    ):
        """Forward every constructor argument unchanged to the base encoder block."""
        super(UniversalTransformerEncoderBLOCK, self).__init__(
            num_units=num_units, num_heads=num_heads, dropout=dropout, preNorm=preNorm, epsilon=epsilon, name=name
        )

    def call(
        self,
        inputs,
        attention_bias,
        enc_position=None,
        training=False,
        step=None,
        position_encoding=True,
        step_encoding=True,
        **kwargs
    ):
        """Apply the timing signals, then run the base encoder block.

        Args:
            inputs: Encoder input passed to ``ut_input_preprocess`` and then
                to the base block.
            attention_bias: Forwarded unchanged to the base block.
            enc_position: Optional explicit position index for the position
                timing signal.
            training: Forwarded to the base block.
            step: Current recurrence step. Required when ``step_encoding`` is
                true; also forwarded to the base block as ``index``.
            position_encoding: Whether to add the position timing signal.
            step_encoding: Whether to add the step timing signal.
            **kwargs: Forwarded to both ``ut_input_preprocess`` (which reads
                ``max_step`` / ``max_seq``) and the base block.

        Returns:
            The output of the base encoder block.
        """
        if step_encoding:
            assert step is not None, "step is required when setting step_encoding = True"
        # ``step`` may be None when step encoding is disabled; "%d" % None
        # would raise TypeError, so fall back to an un-numbered scope name.
        layer_scope = "layer" if step is None else "layer_%d" % step
        with tf.name_scope(self.n):
            with tf.name_scope(layer_scope):
                inputs = ut_input_preprocess(
                    inputs,
                    step,
                    position_index=enc_position,
                    position_encoding=position_encoding,
                    step_encoding=step_encoding,
                    **kwargs
                )
                inputs = super(UniversalTransformerEncoderBLOCK, self).call(
                    inputs, attention_bias, training=training, index=step, scale=self.num_units ** -0.5, **kwargs
                )
        return inputs

    def get_config(self):
        """Return the configuration dictionary of this block.

        ``num_encoder_steps`` is not assigned anywhere in this class
        (presumably it is set by the base class or externally -- TODO
        confirm), so it is reported as ``None`` when the attribute is absent
        instead of raising ``AttributeError``.
        """
        return {
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "num_encoder_steps": getattr(self, "num_encoder_steps", None),
            "dropout": self.dropout,
        }


class UniversalTransformerDecoderBLOCK(TransformerBlock.TransformerDecoderBLOCK):
    """Transformer decoder block with Universal Transformer timing signals.

    Before delegating to ``TransformerBlock.TransformerDecoderBLOCK.call`` the
    inputs are passed through ``ut_input_preprocess``, which adds the step
    and/or position timing signals.
    """

    def __init__(
        self,
        num_units=512,
        num_heads=8,
        dropout=0.1,
        preNorm=False,
        epsilon=1e-9,
        name="UniversalTransformerDecoderBLOCK",
    ):
        """Forward every constructor argument unchanged to the base decoder block."""
        super(UniversalTransformerDecoderBLOCK, self).__init__(
            num_units=num_units, num_heads=num_heads, dropout=dropout, preNorm=preNorm, epsilon=epsilon, name=name
        )

    def call(
        self,
        inputs,
        enc,
        decoder_self_attention_bias,
        attention_bias,
        dec_position=None,
        training=False,
        cache=None,
        decoder_padding=None,
        step=None,
        position_encoding=True,
        step_encoding=True,
        **kwargs
    ):
        """Apply the timing signals, then run the base decoder block.

        Args:
            inputs: Decoder input passed to ``ut_input_preprocess`` and then
                to the base block.
            enc: Encoder output, forwarded unchanged to the base block.
            decoder_self_attention_bias: Forwarded unchanged to the base block.
            attention_bias: Forwarded unchanged to the base block.
            dec_position: Optional explicit position index for the position
                timing signal.
            training: Forwarded to the base block.
            cache: Forwarded to the base block.
            decoder_padding: Forwarded to the base block.
            step: Current recurrence step. Required when ``step_encoding`` is
                true; also forwarded to the base block as ``index``.
            position_encoding: Whether to add the position timing signal.
            step_encoding: Whether to add the step timing signal.
            **kwargs: Forwarded to both ``ut_input_preprocess`` (which reads
                ``max_step`` / ``max_seq``) and the base block.

        Returns:
            The output of the base decoder block.
        """
        if step_encoding:
            assert step is not None, "step is required when setting step_encoding = True"
        # ``step`` may be None when step encoding is disabled; "%d" % None
        # would raise TypeError, so fall back to an un-numbered scope name.
        layer_scope = "layer" if step is None else "layer_%d" % step
        with tf.name_scope(self.n):
            with tf.name_scope(layer_scope):
                inputs = ut_input_preprocess(
                    inputs,
                    step,
                    position_index=dec_position,
                    position_encoding=position_encoding,
                    step_encoding=step_encoding,
                    **kwargs
                )
                inputs = super(UniversalTransformerDecoderBLOCK, self).call(
                    inputs,
                    enc,
                    decoder_self_attention_bias,
                    attention_bias,
                    training=training,
                    cache=cache,
                    # Pass the caller's value through; it was previously
                    # hard-coded to None, silently discarding the argument.
                    decoder_padding=decoder_padding,
                    index=step,
                    scale=self.num_units ** -0.5,
                    **kwargs
                )
        return inputs

    def get_config(self):
        """Return the configuration dictionary of this block.

        ``num_decoder_steps`` is not assigned anywhere in this class
        (presumably it is set by the base class or externally -- TODO
        confirm), so it is reported as ``None`` when the attribute is absent
        instead of raising ``AttributeError``.
        """
        return {
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "num_decoder_steps": getattr(self, "num_decoder_steps", None),
            "dropout": self.dropout,
        }
