# -*- coding: utf-8 -*-
"""
    This is a naive seq2seq model that wraps common functions.
"""
import tensorflow as tf
from UNIVERSAL.block import BeamSearchBlock
from UNIVERSAL.basic_metric import seq2seq_metric, mean_metric


class NaiveSeq2Seq(tf.keras.models.Model):
    """Keras model base bundling a seq2seq loss, metric layers and beam search.

    The class defines no encoder or decoder of its own: the forward function
    is handed to :meth:`seq2seq_training` and the decoding step function is
    handed to :meth:`predict`.
    """

    def __init__(self, vocabulary_size, sos=1, eos=2, trimer=2, label_smothing=0.1):
        """Create the loss layer, the metric layers and the beam-search helper.

        Args:
            vocabulary_size: Forwarded to the cross-entropy layer and, in
                :meth:`predict`, to the beam search.
            sos: Stored as ``self.sos``.
            eos: Stored as ``self.eos``.
            trimer: First positional argument of ``seq2seq_metric.MetricLayer``.
            label_smothing: Second positional argument of the cross-entropy
                layer; also stored as ``self.label_smothing``.
        """
        super(NaiveSeq2Seq, self).__init__(name="NaiveSeq2Seq")

        # Plain configuration values.
        self.vocabulary_size = vocabulary_size
        self.label_smothing = label_smothing
        self.sos = sos
        self.eos = eos
        self.finetune = False

        # Decoding helper.
        self.beam_search = BeamSearchBlock.BeamSearch()

        # Running-mean trackers updated on every training step.
        self.total_loss = mean_metric.Mean_MetricLayer("loss")
        self.grad_norm_ratio = mean_metric.Mean_MetricLayer("grad_norm_ratio")
        self.tokenPerS = mean_metric.Mean_MetricLayer("tokens/batch")

        # Loss and sequence-level metric layers.
        self.seq2seq_loss_FN = seq2seq_metric.CrossEntropy_layer(
            self.vocabulary_size,
            self.label_smothing,
            name="seq2seq_loss",
        )
        self.seq2seq_metric = seq2seq_metric.MetricLayer(trimer, prefix="seq2seq")

    def seq2seq_training(self, call_fn, x, de_input_y, y, training=True, **kwargs):
        """Run one optimisation step and update the tracked metrics.

        Args:
            call_fn: Callable invoked as
                ``call_fn((x, de_input_y), training=training, **kwargs)``;
                its result is used as the logits.
            x: Source batch; its first two dimensions feed the token count.
            de_input_y: Decoder input, passed to ``call_fn`` together with ``x``.
            y: Target used for the loss.
            training: Forwarded to ``call_fn``.
            **kwargs: Forwarded to ``call_fn``. When ``"tgt_label"`` is
                present, it replaces ``y`` for the metric layer and for the
                token count (the loss still uses ``y``).

        Returns:
            The loss produced by ``self.seq2seq_loss_FN``.

        Note:
            Reads ``self.optimizer``, so an optimizer must already be attached
            to the model when this is called.
        """
        with tf.GradientTape() as tape:
            logits = call_fn((x, de_input_y), training=training, **kwargs)
            loss = self.seq2seq_loss_FN([y, logits], auto_loss=False)

        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        # Clip to a global norm of 1.0; the second return value is the global
        # norm reported by tf.clip_by_global_norm.
        clipped_grads, global_norm = tf.clip_by_global_norm(grads, 1.0)

        # NOTE: an optional ``gradient_tower.tower_update(...)`` path was once
        # sketched here; it is disabled, so updates always go through the
        # optimizer.
        self.optimizer.apply_gradients(zip(clipped_grads, variables))

        self.grad_norm_ratio(global_norm)
        self.total_loss(loss)

        # Metrics may be computed against an alternative label tensor.
        metric_target = kwargs.get("tgt_label", y)
        self.seq2seq_metric([metric_target, logits])

        # batch_size * (second dim of x + second dim of the metric target).
        source_shape = tf.shape(x)
        batch_size = source_shape[0]
        combined_length = source_shape[1] + tf.shape(metric_target)[1]
        token_count = tf.math.multiply(batch_size, combined_length)
        self.tokenPerS(tf.cast(token_count, tf.float32))
        return loss

    def predict(self,
                batch_size,
                autoregressive_fn,
                sos_id=1,
                eos_id=2,
                cache=None, beam_size=0, max_decode_length=99):
        """Decode sequences by delegating to ``self.beam_search.predict``.

        Note that this method shadows ``tf.keras.Model.predict`` with a
        different signature.

        Args:
            batch_size: Forwarded to the beam search.
            autoregressive_fn: Forwarded to the beam search.
            sos_id: Forwarded as ``sos_id``.
            eos_id: Forwarded as ``eos_id``.
            cache: Forwarded as ``cache``.
            beam_size: Forwarded as ``beam_size``.
            max_decode_length: Forwarded as ``max_decode_length``.

        Returns:
            A ``(decoded_ids, scores)`` tuple as produced by the beam search.
        """
        search_options = dict(
            sos_id=sos_id,
            eos_id=eos_id,
            cache=cache,
            max_decode_length=max_decode_length,
            beam_size=beam_size,
        )
        ids, beam_scores = self.beam_search.predict(
            batch_size,
            autoregressive_fn,
            self.vocabulary_size,
            **search_options
        )
        return ids, beam_scores
