# -*- coding: utf-8 -*-

import tensorflow as tf
from UNIVERSAL.basic_layer import embedding_layer, layerNormalization_layer
from UNIVERSAL.block import TransformerBlock
from UNIVERSAL.utils import padding_util, maskAndBias_util, staticEmbedding_util
from UNIVERSAL.training_and_learning.NaiveSeq2Seq_learning import NaiveSeq2Seq


class Transformer(NaiveSeq2Seq):
    """Pre-norm encoder-decoder Transformer trained through ``NaiveSeq2Seq``.

    The model consists of:

    * ``num_encoder_layers`` ``TransformerEncoderBLOCK`` layers and
      ``num_decoder_layers`` ``TransformerDecoderBLOCK`` layers, all built with
      ``preNorm=True``;
    * a final ``LayerNorm`` after each stack (``final_enc_norm`` /
      ``final_dec_norm``);
    * one ``EmbeddingSharedWeights`` layer that embeds both source and target
      ids and is reused as the output projection to vocabulary logits.

    Notes:
        * The embedding layer is created with width ``num_units``;
          ``embedding_size`` is only stored and reported by ``get_config``.
        * ``label_smoothing`` and extra ``**kwargs`` are accepted but not used
          by this class.
    """

    def __init__(
        self,
        vocabulary_size=40000,
        embedding_size=512,
        batch_size=64,
        num_units=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
        max_seq_len=99,
        src_sos_id=1,
        tgt_sos_id=1,
        src_eos_id=2,
        tgt_eos_id=2,
        pad_id=0,
        mask_id=3,
        unk_id=4,
        epsilon=1e-9,
        label_smoothing=0.1,
        ffn_activation="relu",
        **kwargs
    ):
        NaiveSeq2Seq.__init__(self, vocabulary_size)

        # Encoder / decoder stacks of pre-norm blocks.
        self.TransformerEncoderBLOCK_stacked = [
            TransformerBlock.TransformerEncoderBLOCK(
                num_units=num_units,
                num_heads=num_heads,
                dropout=dropout,
                preNorm=True,
                epsilon=epsilon,
                ffn_activation=ffn_activation,
            )
            for _ in range(num_encoder_layers)
        ]
        self.TransformerDecoderBLOCK_stacked = [
            TransformerBlock.TransformerDecoderBLOCK(
                num_units=num_units,
                num_heads=num_heads,
                dropout=dropout,
                preNorm=True,
                epsilon=epsilon,
                ffn_activation=ffn_activation,
            )
            for _ in range(num_decoder_layers)
        ]

        # Hyper-parameters and special token ids.
        self.max_seq_len = max_seq_len
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        self.batch_size = batch_size
        self.num_units = num_units
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dropout = dropout
        self.SRC_SOS_ID = src_sos_id
        self.TGT_SOS_ID = tgt_sos_id
        self.SRC_EOS_ID = src_eos_id
        self.TGT_EOS_ID = tgt_eos_id
        self.PAD_ID = pad_id
        self.MASK_ID = mask_id
        self.UNK_ID = unk_id
        self.epsilon = epsilon
        # Shared embedding: embeds source/target ids and, when called with
        # ``linear=True`` (or via ``_linear``), projects hidden states to logits.
        self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
            self.vocabulary_size, self.num_units, pad_id=self.PAD_ID, name="word_embedding", affine=False,
        )
        # Final normalisation applied to the output of each pre-norm stack.
        self.final_enc_norm = layerNormalization_layer.LayerNorm()
        self.final_dec_norm = layerNormalization_layer.LayerNorm()

    def get_config(self):
        """Return a dict of the main constructor hyper-parameters.

        NOTE: special-token ids, ``epsilon`` and ``ffn_activation`` are not
        included.
        """
        c = {
            "max_seq_len": self.max_seq_len,
            "vocabulary_size": self.vocabulary_size,
            "embedding_size": self.embedding_size,
            "batch_size": self.batch_size,
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "num_decoder_layers": self.num_decoder_layers,
            "num_encoder_layers": self.num_encoder_layers,
            "dropout": self.dropout,
        }
        return c

    def encoding(
        self, inputs, attention_bias=0, training=False, enc_position=None,
    ):
        """Embed source token ids and run them through the encoder stack.

        Args:
            inputs: source token ids fed to the shared embedding layer.
            attention_bias: additive attention bias passed to every encoder
                block (e.g. the output of ``padding_util.get_padding_bias``).
            training: when True, dropout is applied to the embedded input and
                the flag is forwarded to each encoder block.
            enc_position: optional explicit position(s) for the timing signal;
                when None, ``add_position_timing_signal`` uses its own default.

        Returns:
            The encoder output after the final layer normalisation.
        """
        src = self.embedding_softmax_layer(inputs)
        with tf.name_scope("encoding"):
            if enc_position is not None:
                src = staticEmbedding_util.add_position_timing_signal(src, 0, 1, position=enc_position)
            else:
                src = staticEmbedding_util.add_position_timing_signal(src, 0, 1)
            if training:
                src = tf.nn.dropout(src, rate=self.dropout)
            for index, layer in enumerate(self.TransformerEncoderBLOCK_stacked):
                src = layer(src, attention_bias=attention_bias, training=training, index=index)
        return self.final_enc_norm(src)

    def decoding(
        self,
        inputs,
        enc,
        decoder_self_attention_bias,
        attention_bias,
        training=False,
        cache=None,
        decoder_padding=None,
        dec_position=None,
    ):
        """Embed target token ids and run them through the decoder stack.

        Args:
            inputs: decoder input token ids.
            enc: encoder output to attend to.
            decoder_self_attention_bias: additive bias for decoder self-attention.
            attention_bias: additive bias for encoder-decoder attention.
            training: enables input dropout and is forwarded to every block.
            cache: optional dict holding one entry per decoder block under the
                key ``"layer_<index>"`` (as built by ``prepare_cache``); each
                entry is handed to its block.
            decoder_padding: forwarded unchanged to every decoder block.
            dec_position: optional explicit position(s) for the timing signal.

        Returns:
            The decoder output after the final layer normalisation.
        """
        tgt = self.embedding_softmax_layer(inputs)
        with tf.name_scope("decoding"):
            if dec_position is not None:
                tgt = staticEmbedding_util.add_position_timing_signal(tgt, 0, 1, position=dec_position)
            else:
                tgt = staticEmbedding_util.add_position_timing_signal(tgt, 0, 1)
            if training:
                tgt = tf.nn.dropout(tgt, rate=self.dropout)
            for index, layer in enumerate(self.TransformerDecoderBLOCK_stacked):
                layer_cache = None if cache is None else cache["layer_%d" % index]
                tgt = layer(
                    tgt,
                    enc,
                    decoder_self_attention_bias,
                    attention_bias,
                    training=training,
                    cache=layer_cache,
                    decoder_padding=decoder_padding,
                    index=index,
                )
        return self.final_dec_norm(tgt)

    def forward(
        self,
        src,
        tgt,
        training=True,
        attention_bias=0,
        decoder_self_attention_bias=0,
        cache=None,
        decoder_padding=None,
        enc_position=None,
        dec_position=None,
        vis=False,
    ):
        """Encode ``src``, decode ``tgt`` against it, and return vocabulary logits.

        Args:
            src: source token ids.
            tgt: decoder input token ids.
            training: forwarded to ``encoding`` and ``decoding``.
            attention_bias: bias for encoder self-attention and encoder-decoder
                attention.
            decoder_self_attention_bias: bias for decoder self-attention.
            cache: optional per-layer decoder cache (see ``decoding``).
            decoder_padding: forwarded to every decoder block. It was previously
                dropped here, so the padding bias built in ``call`` never reached
                the decoder during training. The default ``None`` matches
                ``decoding``'s own default.
            enc_position / dec_position: optional explicit timing-signal
                positions.
            vis: accepted so that ``call`` (which forwards its ``vis`` kwarg)
                does not raise ``TypeError``; unused by this implementation.

        Returns:
            Logits from the shared embedding layer's linear projection.
        """
        enc = self.encoding(src, attention_bias, training=training, enc_position=enc_position)
        dec = self.decoding(
            tgt,
            enc,
            decoder_self_attention_bias,
            attention_bias,
            training=training,
            cache=cache,
            decoder_padding=decoder_padding,
            dec_position=dec_position,
        )
        return self.embedding_softmax_layer(dec, linear=True)

    def compile(self, optimizer):
        """Compile the underlying Keras model and attach ``optimizer``."""
        super(Transformer, self).compile()
        self.optimizer = optimizer

    def train_step(self, data):
        """Run one training step and return the current metric values.

        ``data`` is unpacked as ``((x, y),)``, where ``x`` is the source ids and
        ``y`` the target ids. The decoder input is ``y`` shifted right by one
        position: ``TGT_SOS_ID`` is prepended and the last token is dropped.
        The step itself is delegated to ``seq2seq_training``, which is not
        defined in this class (presumably inherited from ``NaiveSeq2Seq``).

        Returns:
            dict mapping each metric's name to its ``result()``.
        """
        ((x, y),) = data
        de_real_y = tf.pad(y, [[0, 0], [1, 0]], constant_values=self.TGT_SOS_ID)[:, :-1]
        _ = self.seq2seq_training(self.call, x, de_real_y, y, training=True)
        return {m.name: m.result() for m in self.metrics}

    def pre_processing(self, src, tgt):
        """Build the attention biases needed for a teacher-forced forward pass.

        Returns:
            ``(attention_bias, decoder_self_attention_bias, decoder_padding)``:
            the source padding bias, the decoder self-attention bias for a
            sequence of length ``tf.shape(tgt)[1]``, and the target padding bias.
        """
        attention_bias = padding_util.get_padding_bias(src)
        decoder_padding = padding_util.get_padding_bias(tgt)
        decoder_self_attention_bias = maskAndBias_util.get_decoder_self_attention_bias(tf.shape(tgt)[1], tgt)
        return attention_bias, decoder_self_attention_bias, decoder_padding

    def call(self, inputs, training=False, **kwargs):
        """Return logits when training, decoded ids otherwise.

        Training: ``inputs`` is ``(src, tgt)``, where ``tgt`` is the decoder
        input. Biases come from ``pre_processing``, and the logits from
        ``forward`` are returned. ``kwargs["vis"]`` (default False) is passed to
        ``forward``.

        Inference: ``inputs`` is the source ids. Decoding runs for at most
        ``max_seq_len`` steps through ``predict`` (not defined in this class),
        with ``kwargs["beam_size"]`` (default 0). Index 0 of axis 1 of the
        result (presumably the best hypothesis) is returned without its first
        position (presumably the start token).
        """
        vis = kwargs.get("vis", False)
        if training:
            src, tgt = inputs[0], inputs[1]
            attention_bias, decoder_self_attention_bias, decoder_padding = self.pre_processing(src, tgt)
            return self.forward(
                src,
                tgt,
                training=training,
                attention_bias=attention_bias,
                decoder_self_attention_bias=decoder_self_attention_bias,
                decoder_padding=decoder_padding,
                vis=vis,
            )

        beam_size = kwargs.get("beam_size", 0)
        src = inputs
        max_decode_length = self.max_seq_len
        autoregressive_fn = self.autoregressive_fn(max_decode_length)
        cache, batch_size = self.prepare_cache(src, self.TGT_SOS_ID, self.TGT_SOS_ID)
        decoded, _score = self.predict(
            batch_size, autoregressive_fn, eos_id=self.TGT_EOS_ID, cache=cache, beam_size=beam_size
        )
        top_decoded_ids = decoded[:, 0, 1:]
        return top_decoded_ids

    def prepare_cache(self, src, src_id, sos_id):
        """Encode ``src`` once and build the initial incremental-decoding cache.

        Args:
            src: source token ids; axis 0 is the batch dimension.
            src_id: unused; kept for interface compatibility.
            sos_id: id used to fill ``initial_ids``.

        Returns:
            ``(cache, batch_size)``. ``cache`` holds one ``{"k", "v"}`` entry
            per decoder block, keyed ``"layer_<index>"``, each with a
            zero-length time axis. It also holds ``"decoder_padding"``,
            ``"enc"``, ``"initial_ids"`` and ``"attention_bias"``.
        """
        batch_size = tf.shape(src)[0]
        initial_ids = tf.ones([batch_size], dtype=tf.int32) * sos_id
        attention_bias = padding_util.get_padding_bias(src)
        enc = self.encoding(src, attention_bias=attention_bias, training=False)

        init_decode_length = 0
        dim_per_head = self.num_units // self.num_heads
        # One slot per decoder block: ``decoding`` looks up "layer_<index>" for
        # every block, so the count must be ``num_decoder_layers``. The previous
        # ``self.num_decoder_steps`` is never defined and raised AttributeError.
        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, self.num_heads, init_decode_length, dim_per_head]),
                "v": tf.zeros([batch_size, self.num_heads, init_decode_length, dim_per_head]),
            }
            for layer in range(self.num_decoder_layers)
        }
        # Padding bias of an all-ones (single-step) input; presumably a no-op
        # bias since no position equals the pad id -- verify in padding_util.
        cache["decoder_padding"] = padding_util.get_padding_bias(tf.ones([batch_size, 1], dtype=tf.int32))
        cache["enc"] = enc
        cache["initial_ids"] = initial_ids
        cache["attention_bias"] = attention_bias
        return cache, batch_size

    def autoregressive_fn(self, max_decode_length, lang_embedding=0, tgt_domain_id=0):
        """Return ``symbols_to_logits_fn(ids, i, cache)`` for step-wise decoding.

        ``lang_embedding`` is currently unused. When ``tgt_domain_id`` is the int
        0, the output projection is called without a domain id. Otherwise the
        value is passed on to ``embedding_softmax_layer._linear``.
        """
        decoder_self_attention_bias = maskAndBias_util.get_decoder_self_attention_bias(max_decode_length)
        # Previously ``tgt_domain_id is 0``: an identity test against a literal
        # (SyntaxWarning since Python 3.8). This keeps the "plain int 0" meaning
        # without relying on int caching, and never evaluates a tensor as bool.
        use_default_domain = isinstance(tgt_domain_id, int) and tgt_domain_id == 0

        def symbols_to_logits_fn(ids, i, cache):
            """Return ``(logits, cache)`` for the last id of each sequence at step ``i``."""
            decoder_input_id = tf.cast(ids[:, -1:], tf.int32)
            # Row i of the decoder self-attention bias, limited to the first
            # i + 1 key positions.
            self_attention_bias = decoder_self_attention_bias[:, :, i : i + 1, : i + 1]
            dec = self.decoding(
                decoder_input_id,
                cache.get("enc"),
                decoder_self_attention_bias=self_attention_bias,
                attention_bias=cache.get("attention_bias"),
                dec_position=i,
                decoder_padding=cache.get("decoder_padding"),
                training=False,
                cache=cache,
            )
            if use_default_domain:
                logits = self.embedding_softmax_layer._linear(dec)
            else:
                logits = self.embedding_softmax_layer._linear(dec, tgt_domain_id)
            # The decoder input has a length-1 time axis (ids[:, -1:]); drop it.
            logits = tf.squeeze(logits, axis=[1])
            return logits, cache

        return symbols_to_logits_fn
