# -*- coding: utf-8 -*-
# code warrior: Barid

# from UNIVERSAL.basic_optimizer import learning_rate_op, optimizer_op
import tensorflow as tf
from UNIVERSAL.block import UniversalTransformerBlock
from UNIVERSAL.model import ut
from UNIVERSAL.utils import padding_util, cka
from UNIVERSAL.basic_layer import embedding_layer, layerNormalization_layer
import lazyTransition
import json


class UniversalTransformer(ut.UniversalTransformer):
    """Universal Transformer whose shared blocks are wrapped in ``lazyTransition.LT``.

    One encoder block and one decoder block are created and each is applied
    repeatedly (``dynamic_enc`` / ``dynamic_dec`` times) by :meth:`encoding`
    and :meth:`decoding`.  Outside of training, positions whose halting value
    (``halting_pro`` on the wrapped block) exceeds ``dynamic_halting`` keep the
    state they had before that step for all remaining steps.
    """

    def __init__(
        self,
        vocabulary_size=40000,
        embedding_size=512,
        batch_size=64,
        num_units=512,
        num_heads=8,
        num_encoder_steps=6,
        num_decoder_steps=6,
        dropout=0.1,
        max_seq_len=60,
        src_sos_id=1,
        tgt_sos_id=1,
        src_eos_id=2,
        tgt_eos_id=2,
        pad_id=0,
        mask_id=3,
        unk_id=4,
        label_smoothing=0.1,
        epsilon=1e-9,
        preNorm=False,
        **kwargs
    ):
        """Build the model.

        All arguments except ``num_encoder_steps``, ``num_decoder_steps`` and
        ``**kwargs`` are forwarded unchanged to
        ``ut.UniversalTransformer.__init__``; the parent is always asked for a
        single encoder layer and a single decoder layer.

        Args:
            num_encoder_steps: Default number of times the encoder block is
                applied; also passed to the block as ``max_step``.
            num_decoder_steps: Default number of times the decoder block is
                applied; also passed to the block as ``max_step``.
            **kwargs: Accepted for call-site compatibility; not used here and
                not forwarded.
        """
        super(UniversalTransformer, self).__init__(
            vocabulary_size=vocabulary_size,
            embedding_size=embedding_size,
            batch_size=batch_size,
            num_units=num_units,
            num_heads=num_heads,
            num_encoder_layers=1,
            num_decoder_layers=1,
            dropout=dropout,
            max_seq_len=max_seq_len,
            src_sos_id=src_sos_id,
            tgt_sos_id=tgt_sos_id,
            src_eos_id=src_eos_id,
            tgt_eos_id=tgt_eos_id,
            pad_id=pad_id,
            mask_id=mask_id,
            unk_id=unk_id,
            preNorm=preNorm,
            epsilon=epsilon,
            label_smoothing=label_smoothing,
        )
        # NOTE(review): this runs the initializer of the class that follows
        # ut.UniversalTransformer in the MRO a second time, after the parent
        # initializer above has already completed -- confirm this is intended.
        super(ut.UniversalTransformer, self).__init__(vocabulary_size, trimer=2)
        # Replace the encoder/decoder with LT-wrapped shared blocks.
        self.UT_encoder = lazyTransition.LT(
            UniversalTransformerBlock.UniversalTransformerEncoderBLOCK(
                num_units, num_heads, dropout, preNorm=preNorm, epsilon=epsilon
            ),
            dropout,
        )
        self.UT_decoder = lazyTransition.LT(
            UniversalTransformerBlock.UniversalTransformerDecoderBLOCK(
                num_units, num_heads, dropout, preNorm=preNorm, epsilon=epsilon
            ),
            dropout,
        )
        # Static step counts (used for training); passed to the blocks as ``max_step``.
        self.num_encoder_steps = num_encoder_steps
        self.num_decoder_steps = num_decoder_steps
        # Step counts actually iterated over; kept separate so they can be
        # changed dynamically at inference time.
        self.dynamic_enc = self.num_encoder_steps
        self.dynamic_dec = self.num_decoder_steps

        # Threshold compared against ``halting_pro`` when not training.
        self.dynamic_halting = 1

    @staticmethod
    def _dump_tensor_json(path, tensor):
        """Write ``tensor`` (must support ``.numpy()``) to ``path`` as a JSON nested list."""
        with open(path, "w") as outfile:
            json.dump(tensor.numpy().tolist(), outfile)

    def encoding(self, inputs, attention_bias=0, training=False, enc_position=None, vis=False):
        """Embed ``inputs`` and apply the shared encoder block ``dynamic_enc`` times.

        Args:
            inputs: Passed to ``self.embedding_softmax_layer``.
            attention_bias: Forwarded to the encoder block.
            training: When true, dropout is applied to the embeddings.  When
                false, the halting mask is applied after every step.
            enc_position: Unused by this method.
            vis: When true, per-step CKA similarities, the pre-step states and
                halting values are collected and written to ``./enc_*.json``
                in the current working directory.  This calls ``.numpy()`` on
                the collected tensors, so it needs eagerly evaluated tensors.

        Returns:
            The encoder state after the last step.
        """
        src = self.embedding_softmax_layer(inputs)
        # ``pre`` holds the state before the current step (taken before dropout).
        pre = src
        if vis:
            # Empty accumulators; one slice is prepended per step below.
            orgData = tf.zeros([tf.shape(src)[0], tf.shape(src)[1], tf.shape(src)[2], 0])
            temp = tf.zeros([tf.shape(src)[1], 0])
            sentence = tf.zeros([0])
            halting = tf.zeros([tf.shape(src)[1], 0])
        with tf.name_scope("encoding"):
            if training:
                src = tf.nn.dropout(src, rate=self.dropout)
            # 1 where a position has halted, 0 otherwise.
            mask = tf.zeros([tf.shape(src)[0], tf.shape(src)[1], 1])
            for step in range(self.dynamic_enc):
                src = self.UT_encoder(
                    src,
                    attention_bias=attention_bias,
                    training=training,
                    step=step,
                    max_step=self.num_encoder_steps,
                    max_seq=self.max_seq_len,
                    step_encoding=False,
                )
                if vis:
                    # New entries go first, so the latest step ends up at index 0
                    # of the last axis.
                    temp = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, src), 0), temp], -1)
                    sentence = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, src, True), 0), sentence], -1)
                    orgData = tf.concat([tf.expand_dims(pre, -1), orgData], -1)
                    halting = tf.concat([self.UT_encoder.halting_pro[0], halting], -1)
                if not training:
                    # tf.maximum makes halting sticky: once a position exceeds
                    # the threshold it stays masked for all later steps.
                    mask = tf.maximum(
                        mask, tf.cast(tf.greater(self.UT_encoder.halting_pro, self.dynamic_halting), tf.float32)
                    )
                    # Halted positions keep their pre-step state.
                    src = pre * mask + src * (1 - mask)
                pre = src

        if vis:
            self._dump_tensor_json("./enc_cka_similarity.json", temp)
            self._dump_tensor_json("./enc_cka_similarity_sentence.json", sentence)
            orgData = tf.squeeze(cka.feature_space_linear_cka_3d_self(orgData), 0)
            self._dump_tensor_json("./enc_orgData.json", orgData)
            self._dump_tensor_json("./enc_halting_pro.json", halting)
        return src

    def decoding(
        self,
        inputs,
        enc,
        decoder_self_attention_bias,
        attention_bias,
        training=False,
        cache=None,
        decoder_padding=None,
        dec_position=None,
        vis=False,
    ):
        """Embed ``inputs`` and apply the shared decoder block ``dynamic_dec`` times.

        Args:
            inputs: Passed to ``self.embedding_softmax_layer``.
            enc: Encoder output, forwarded to the decoder block.
            decoder_self_attention_bias: Forwarded to the decoder block.
            attention_bias: Forwarded to the decoder block.
            training: When true, dropout is applied to the embeddings.  When
                false, the halting mask is applied after every step.
            cache: Optional mapping indexed by ``"layer_<step>"``; the entry
                for the current step is forwarded to the decoder block.
            decoder_padding: Forwarded to the decoder block.
            dec_position: Forwarded to the decoder block.
            vis: When true, per-step CKA similarities, the pre-step states and
                halting values are collected and written to ``./dec_*.json``
                in the current working directory.  This calls ``.numpy()`` on
                the collected tensors, so it needs eagerly evaluated tensors.

        Returns:
            The decoder state after the last step.
        """
        tgt = self.embedding_softmax_layer(inputs)
        # ``pre`` holds the state before the current step (taken before dropout).
        pre = tgt
        if training:
            tgt = tf.nn.dropout(tgt, rate=self.dropout)
        if vis:
            # Empty accumulators; one slice is prepended per step below.
            orgData = tf.zeros([tf.shape(tgt)[0], tf.shape(tgt)[1], tf.shape(tgt)[2], 0])
            temp = tf.zeros([tf.shape(tgt)[1], 0])
            halting = tf.zeros([tf.shape(tgt)[1], 0])
            sentence = tf.zeros([0])
        with tf.name_scope("decoding"):
            # 1 where a position has halted, 0 otherwise.
            mask = tf.zeros([tf.shape(tgt)[0], tf.shape(tgt)[1], 1])
            for step in range(self.dynamic_dec):
                layer_name = "layer_%d" % step
                tgt = self.UT_decoder(
                    tgt,
                    enc,
                    decoder_self_attention_bias,
                    attention_bias,
                    training=training,
                    cache=cache[layer_name] if cache is not None else None,
                    decoder_padding=decoder_padding,
                    step=step,
                    dec_position=dec_position,
                    max_step=self.num_decoder_steps,
                    max_seq=self.max_seq_len,
                    step_encoding=False,
                )

                if vis:
                    # New entries go first, so the latest step ends up at index 0
                    # of the last axis.
                    temp = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, tgt), 0), temp], -1)
                    sentence = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, tgt, True), 0), sentence], -1)
                    orgData = tf.concat([tf.expand_dims(pre, -1), orgData], -1)
                    halting = tf.concat([self.UT_decoder.halting_pro[0], halting], -1)
                if not training:
                    # tf.maximum makes halting sticky: once a position exceeds
                    # the threshold it stays masked for all later steps.
                    mask = tf.maximum(
                        mask, tf.cast(tf.greater(self.UT_decoder.halting_pro, self.dynamic_halting), tf.float32)
                    )
                    # Halted positions keep their pre-step state.
                    tgt = pre * mask + tgt * (1 - mask)
                pre = tgt
        if vis:
            self._dump_tensor_json("./dec_cka_similarity.json", temp)
            self._dump_tensor_json("./dec_cka_similarity_sentence.json", sentence)
            # NOTE(review): unlike ``encoding`` (which squeezes axis 0 only),
            # this squeezes every size-1 axis -- confirm the difference is intended.
            orgData = tf.squeeze(cka.feature_space_linear_cka_3d_self(orgData))
            self._dump_tensor_json("./dec_orgData.json", orgData)
            self._dump_tensor_json("./dec_halting_pro.json", halting)
        return tgt
