# -*- coding: utf-8 -*-

# from UNIVERSAL.basic_optimizer import learning_rate_op, optimizer_op
import tensorflow as tf
from UNIVERSAL.block import UniversalTransformerBlock
from UNIVERSAL.model import transformer
from UNIVERSAL.utils import padding_util, maskAndBias_util, cka
from UNIVERSAL.basic_layer import embedding_layer, layerNormalization_layer
import json
import sys


class UniversalTransformer(transformer.Transformer):
    """Universal Transformer encoder-decoder.

    Instead of a stack of distinct layers, one encoder block (``self.UT_encoder``)
    and one decoder block (``self.UT_decoder``) are applied repeatedly, so their
    weights are shared across steps. The number of steps used during training is
    ``num_encoder_steps`` / ``num_decoder_steps``; at inference time it can be
    overridden per call through the ``enc`` / ``dec`` keyword arguments of
    ``call``.
    """

    def __init__(
        self,
        vocabulary_size=40000,
        embedding_size=512,
        batch_size=64,
        num_units=512,
        num_heads=8,
        num_encoder_steps=6,
        num_decoder_steps=6,
        dropout=0.1,
        max_seq_len=60,
        src_sos_id=1,
        tgt_sos_id=1,
        src_eos_id=2,
        tgt_eos_id=2,
        pad_id=0,
        mask_id=3,
        unk_id=4,
        label_smoothing=0.1,
        epsilon=1e-9,
        preNorm=False,
        **kwargs
    ):
        """Build the shared UT blocks, embedding and output projection.

        Args:
            num_encoder_steps / num_decoder_steps: default number of times the
                shared encoder / decoder block is applied.
            epsilon: forwarded to the UT encoder/decoder blocks.
            preNorm: if True, the blocks are built in pre-norm mode and an extra
                ``LayerNorm`` is applied after every encoder/decoder step.
            **kwargs: accepted for config compatibility and otherwise ignored.
            The remaining arguments are forwarded to ``transformer.Transformer``.
        """
        # The parent Transformer is built with zero layers; depth is provided by
        # repeatedly applying the shared UT blocks below instead.
        super(UniversalTransformer, self).__init__(
            vocabulary_size=vocabulary_size,
            embedding_size=embedding_size,
            batch_size=batch_size,
            num_units=num_units,
            num_heads=num_heads,
            num_encoder_layers=0,
            num_decoder_layers=0,
            dropout=dropout,
            max_seq_len=max_seq_len,
            src_sos_id=src_sos_id,
            tgt_sos_id=tgt_sos_id,
            src_eos_id=src_eos_id,
            tgt_eos_id=tgt_eos_id,
            pad_id=pad_id,
            mask_id=mask_id,
            unk_id=unk_id,
            label_smoothing=label_smoothing,
        )
        # NOTE(review): this additionally runs the initializer of Transformer's
        # base class (the next class in the MRO after transformer.Transformer)
        # with ``trimer=2``; kept unchanged, confirm it is intentional.
        super(transformer.Transformer, self).__init__(vocabulary_size, trimer=2)

        # A single encoder block and a single decoder block, reused at every step.
        self.UT_encoder = UniversalTransformerBlock.UniversalTransformerEncoderBLOCK(
            num_units, num_heads, dropout, preNorm=preNorm, epsilon=epsilon
        )
        self.UT_decoder = UniversalTransformerBlock.UniversalTransformerDecoderBLOCK(
            num_units, num_heads, dropout, preNorm=preNorm, epsilon=epsilon
        )
        self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
            self.vocabulary_size, self.num_units, pad_id=self.PAD_ID, name="word_embedding", affine=False, scale_we=True
        )

        # Static step counts used for training.
        self.num_encoder_steps = num_encoder_steps
        self.num_decoder_steps = num_decoder_steps
        # Step counts actually read by encoding()/decoding(); `call` may override
        # them temporarily during inference and restores them afterwards.
        self.dynamic_enc = self.num_encoder_steps
        self.dynamic_dec = self.num_decoder_steps

        # Dedicated output projection (not tied to the embedding matrix).
        self.probability_generator = tf.keras.layers.Dense(self.vocabulary_size, use_bias=False)
        self.preNorm = preNorm
        if preNorm:
            self.encoding_norm = layerNormalization_layer.LayerNorm()
            self.decoding_norm = layerNormalization_layer.LayerNorm()
        # NOTE(review): not read anywhere in this class; possibly used elsewhere.
        self.temp = {}

    @staticmethod
    def _dump_tensor_json(path, tensor):
        """Write ``tensor.numpy().tolist()`` to ``path`` as JSON (overwrites the file)."""
        with open(path, "w") as outfile:
            json.dump(tensor.numpy().tolist(), outfile)

    def encoding(self, inputs, attention_bias=0, training=False, enc_position=None, vis=False):
        """Embed ``inputs`` and apply the shared encoder block ``self.dynamic_enc`` times.

        Args:
            inputs: source token ids.
            attention_bias: attention bias forwarded to every encoder step.
            training: enables dropout on the embeddings and inside the block.
            enc_position: accepted for interface compatibility; not used here.
            vis: if True, compute CKA similarities between consecutive step
                outputs and write them to ``./enc_cka_similarity.json``,
                ``./enc_cka_similarity_sentence.json`` and ``./enc_orgData.json``.
                This calls ``.numpy()``, so it presumably needs eager execution.

        Returns:
            The encoder output after the last step.
        """
        src = self.embedding_softmax_layer(inputs)
        if vis:
            orgData = tf.zeros([tf.shape(src)[0], tf.shape(src)[1], tf.shape(src)[2], 0])
            temp = tf.zeros([tf.shape(src)[1], 0])
            sentence = tf.zeros([0])
            pre = src
        with tf.name_scope("encoding"):
            if training:
                src = tf.nn.dropout(src, rate=self.dropout)
            for step in range(self.dynamic_enc):
                src = self.UT_encoder(
                    src,
                    attention_bias=attention_bias,
                    training=training,
                    step=step,
                    max_step=self.dynamic_enc,
                    max_seq=self.max_seq_len,
                )
                if self.preNorm:
                    src = self.encoding_norm(src)
                if vis:
                    # New statistics are prepended along the last axis, so
                    # index 0 corresponds to the most recent step.
                    temp = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, src), 0), temp], -1)
                    sentence = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, src, True), 0), sentence], -1)
                    orgData = tf.concat([tf.expand_dims(pre, -1), orgData], -1)
                    pre = src

        if vis:
            self._dump_tensor_json("./enc_cka_similarity.json", temp)
            self._dump_tensor_json("./enc_cka_similarity_sentence.json", sentence)
            orgData = tf.squeeze(cka.feature_space_linear_cka_3d_self(orgData), 0)
            self._dump_tensor_json("./enc_orgData.json", orgData)
        return src

    def decoding(
        self,
        inputs,
        enc,
        decoder_self_attention_bias,
        attention_bias,
        training=False,
        cache=None,
        decoder_padding=None,
        dec_position=None,
        vis=False,
    ):
        """Embed ``inputs`` and apply the shared decoder block ``self.dynamic_dec`` times.

        Args:
            inputs: target token ids.
            enc: encoder output used for cross-attention.
            decoder_self_attention_bias: bias for decoder self-attention.
            attention_bias: bias for encoder-decoder attention.
            training: enables dropout on the embeddings and inside the block.
            cache: optional dict keyed ``"layer_<step>"`` (see ``prepare_cache``);
                the entry for each step is passed to the decoder block.
            decoder_padding: forwarded to the decoder block unchanged.
            dec_position: forwarded to the decoder block unchanged.
            vis: if True, write CKA statistics to ``./dec_*.json`` (see
                ``encoding``).

        Returns:
            The decoder output after the last step.
        """
        tgt = self.embedding_softmax_layer(inputs)
        if training:
            tgt = tf.nn.dropout(tgt, rate=self.dropout)
        if vis:
            orgData = tf.zeros([tf.shape(tgt)[0], tf.shape(tgt)[1], tf.shape(tgt)[2], 0])
            temp = tf.zeros([tf.shape(tgt)[1], 0])
            sentence = tf.zeros([0])
            pre = tgt
        with tf.name_scope("decoding"):
            for step in range(self.dynamic_dec):
                layer_name = "layer_%d" % step
                tgt = self.UT_decoder(
                    tgt,
                    enc,
                    decoder_self_attention_bias,
                    attention_bias,
                    training=training,
                    cache=cache[layer_name] if cache is not None else None,
                    decoder_padding=decoder_padding,
                    step=step,
                    dec_position=dec_position,
                    max_step=self.dynamic_dec,
                    max_seq=self.max_seq_len,
                )
                if self.preNorm:
                    tgt = self.decoding_norm(tgt)
                if vis:
                    # New statistics are prepended along the last axis.
                    temp = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, tgt), 0), temp], -1)
                    sentence = tf.concat([tf.squeeze(cka.feature_space_linear_cka(pre, tgt, True), 0), sentence], -1)
                    orgData = tf.concat([tf.expand_dims(pre, -1), orgData], -1)
                    pre = tgt
        if vis:
            self._dump_tensor_json("./dec_cka_similarity.json", temp)
            self._dump_tensor_json("./dec_cka_similarity_sentence.json", sentence)
            # NOTE(review): unlike `encoding`, this squeezes every size-1 axis
            # rather than only axis 0 — confirm which is intended.
            orgData = tf.squeeze(cka.feature_space_linear_cka_3d_self(orgData))
            self._dump_tensor_json("./dec_orgData.json", orgData)
        return tgt

    def forward(
        self,
        src,
        tgt,
        training=True,
        attention_bias=0,
        decoder_self_attention_bias=0,
        cache=None,
        decoder_padding=0,
        enc_position=None,
        dec_position=None,
        vis=False,
    ):
        """Run encoder and decoder over full sequences and return output logits.

        Returns:
            ``self.probability_generator`` applied to the decoder output.
        """
        enc = self.encoding(src, attention_bias, training=training, enc_position=enc_position, vis=vis)
        # NOTE(review): `decoder_padding` is accepted (and supplied by `call`)
        # but is not forwarded to `decoding`, which therefore receives its
        # default of None here, whereas the inference path passes a value.
        # Left unchanged pending confirmation of the decoder block's semantics.
        dec = self.decoding(
            tgt,
            enc,
            decoder_self_attention_bias,
            attention_bias,
            training=training,
            cache=cache,
            dec_position=dec_position,
            vis=vis,
        )
        logits = self.probability_generator(dec)
        return logits

    def train_step(self, data):
        """Run one teacher-forced training step.

        ``data`` is unpacked as ``((x, y),)``. The decoder input is ``y``
        shifted right by one position: the target SOS id is prepended and the
        last token dropped.

        Returns:
            A dict mapping each metric name to its current result.
        """
        ((x, y),) = data
        # Use the configured target SOS id (previously hard-coded to 1), so the
        # training-time start token matches the one used at inference.
        de_real_y = tf.pad(y, [[0, 0], [1, 0]], constant_values=self.TGT_SOS_ID)[:, :-1]
        _ = self.seq2seq_training(super(UniversalTransformer, self).call, x, de_real_y, y, training=True)
        return {m.name: m.result() for m in self.metrics}

    def prepare_cache(self, src, src_id, sos_id):
        """Encode ``src`` and build the per-step key/value cache for decoding.

        Args:
            src: source token ids.
            src_id: unused; kept for interface compatibility.
            sos_id: id used to fill ``initial_ids``.

        Returns:
            ``(cache, batch_size)``. ``cache`` has one ``"layer_<step>"`` entry
            per decoder step (``self.dynamic_dec`` of them), each holding empty
            ``"k"``/``"v"`` tensors, plus ``"decoder_padding"``, ``"enc"``,
            ``"initial_ids"`` and ``"attention_bias"``.
        """
        batch_size = tf.shape(src)[0]
        initial_ids = tf.zeros([batch_size], dtype=tf.int32) + sos_id
        attention_bias = padding_util.get_padding_bias(src)
        enc = self.encoding(src, attention_bias=attention_bias, training=False,)

        init_decode_length = 0
        dim_per_head = self.num_units // self.num_heads
        # Sized by the (possibly overridden) dynamic decoder step count, since
        # `decoding` looks up one entry per step it runs.
        cache = {
            "layer_%d"
            % layer: {
                "k": tf.zeros([batch_size, self.num_heads, init_decode_length, dim_per_head]),
                "v": tf.zeros([batch_size, self.num_heads, init_decode_length, dim_per_head]),
            }
            for layer in range(self.dynamic_dec)
        }
        cache["decoder_padding"] = padding_util.get_padding_bias(tf.ones([batch_size, 1], dtype=tf.int32))
        cache["enc"] = enc
        cache["initial_ids"] = initial_ids
        cache["attention_bias"] = attention_bias
        return cache, batch_size

    def autoregressive_fn(self, max_decode_length, lang_embedding=0, tgt_domain_id=0):
        """Return a function computing logits for the next token.

        Args:
            max_decode_length: size of the causal self-attention bias.
            lang_embedding, tgt_domain_id: unused; kept for interface
                compatibility.

        Returns:
            ``symbols_to_logits_fn(ids, i, cache) -> (logits, cache)`` that
            decodes only the last id in ``ids`` at position ``i``.
        """
        decoder_self_attention_bias = maskAndBias_util.get_decoder_self_attention_bias(max_decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            decoder_input_id = tf.cast(ids[:, -1:], tf.int32)
            # Row i of the causal bias, restricted to positions 0..i.
            self_attention_bias = decoder_self_attention_bias[:, :, i : i + 1, : i + 1]
            dec = self.decoding(
                decoder_input_id,
                cache.get("enc"),
                decoder_self_attention_bias=self_attention_bias,
                attention_bias=cache.get("attention_bias"),
                dec_position=i,
                decoder_padding=cache.get("decoder_padding"),
                training=False,
                cache=cache,
            )
            logits = self.probability_generator(dec)
            logits = tf.squeeze(logits, axis=[1])
            return logits, cache

        return symbols_to_logits_fn

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        c = {
            "max_seq_len": self.max_seq_len,
            "vocabulary_size": self.vocabulary_size,
            "embedding_size": self.embedding_size,
            "batch_size": self.batch_size,
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "num_decoder_steps": self.num_decoder_steps,
            "num_encoder_steps": self.num_encoder_steps,
            "dropout": self.dropout,
            # Without this a config round-trip would rebuild a post-norm model.
            "preNorm": self.preNorm,
        }
        return c

    def call(self, inputs, training=False, **kwargs):
        """Training forward pass or beam-search inference.

        Training (``training`` truthy): ``inputs`` is ``(src, tgt)`` and the
        output logits are returned.

        Inference: ``inputs`` is ``src``. Optional kwargs:
            enc / dec: override the number of encoder / decoder steps for this
                call only (the previous values are always restored).
            beam_size: forwarded to ``self.predict`` (default 0).
            vis: forwarded to ``forward`` in training mode only.

        Returns:
            Logits (training) or the top decoded ids with the leading start
            token removed (inference).
        """
        vis = kwargs.get("vis", False)
        if training:
            src, tgt = inputs[0], inputs[1]
            (attention_bias, decoder_self_attention_bias, decoder_padding,) = self.pre_processing(src, tgt)
            logits = self.forward(
                src,
                tgt,
                training=training,
                attention_bias=attention_bias,
                decoder_self_attention_bias=decoder_self_attention_bias,
                decoder_padding=decoder_padding,
                vis=vis,
            )
            return logits

        org_enc = self.dynamic_enc
        org_dec = self.dynamic_dec
        try:
            self.dynamic_enc = kwargs.get("enc", org_enc)
            self.dynamic_dec = kwargs.get("dec", org_dec)
            beam_size = kwargs.get("beam_size", 0)
            src = inputs
            max_decode_length = self.max_seq_len
            autoregressive_fn = self.autoregressive_fn(max_decode_length,)
            cache, batch_size = self.prepare_cache(src, self.TGT_SOS_ID, self.TGT_SOS_ID)
            re, score = self.predict(
                batch_size,
                autoregressive_fn,
                eos_id=self.TGT_EOS_ID,
                max_decode_length=self.max_seq_len,
                cache=cache,
                beam_size=beam_size,
            )
            tf.print("enc", self.dynamic_enc, "dec", self.dynamic_dec, output_stream=sys.stdout)
            # Best beam, dropping the initial start token.
            top_decoded_ids = re[:, 0, 1:]
            del cache, re, score
            return top_decoded_ids
        finally:
            # Always restore the step counts, even if decoding raised.
            self.dynamic_dec = org_dec
            self.dynamic_enc = org_enc
