# -*- coding: utf-8 -*-
# code warrior: Barid
##########
from UNIVERSAL.data_and_corpus import offline_corpus, data_manager, dataset_preprocessing
import numpy as np
import tensorflow as tf
import configuration
from UNIVERSAL.basic_optimizer import learning_rate_op
from UNIVERSAL.training_and_learning import callback_training
import lt
import os

# Working directory captured once at import time; the functions below build
# paths relative to it.
cwd = os.getcwd()

# Offline training corpus files, laid out as [source files, target files]
# (src, tgt).
offline = [["./v2_train.en"], ["./v2_train.de"]]


def EOS_entailment(src, tgt):
    """Append the configured EOS id to both sides of a sentence pair.

    The concatenation runs in NumPy inside ``tf.py_function``. It is passed
    as ``tf_encode`` to ``prepare_training_input`` in this file.

    Args:
        src: source-side tensor (assumed to be a 1-D sequence of token ids;
            TODO confirm against the dataset manager).
        tgt: target-side tensor, same assumption as ``src``.

    Returns:
        Tuple ``(src_with_eos, tgt_with_eos)`` of ``tf.int32`` tensors whose
        static shape is declared as ``[None]``.
    """

    def _append_eos(ids):
        # Read the EOS id at call time and concatenate it along axis 0.
        eos = [configuration.parameters["EOS_ID"]]
        return np.concatenate((ids.numpy(), eos), 0)

    def _encode(lang1, lang2):
        return _append_eos(lang1), _append_eos(lang2)

    src_eos, tgt_eos = tf.py_function(_encode, [src, tgt], [tf.int32, tf.int32])
    # Declare the outputs as rank-1 with unknown length.
    for tensor in (src_eos, tgt_eos):
        tensor.set_shape([None])
    return (src_eos, tgt_eos)


def preprocessed_dataset(shuffle=40):
    """Build the training dataset from the offline corpus.

    Wraps the module-level ``offline`` file lists in an offline corpus,
    creates a ``DatasetManager`` on the DeEn_32000 vocabulary directory
    (resolved relative to ``cwd``), and feeds its raw training dataset
    through ``dataset_preprocessing.prepare_training_input``.

    Args:
        shuffle: forwarded unchanged as ``shuffle`` to
            ``prepare_training_input`` (presumably a shuffle setting such as
            a buffer size; see that helper for the exact meaning).

    Returns:
        Tuple ``(dataset, manager)``: the prepared training dataset and the
        ``DatasetManager`` that produced the raw data.
    """
    params = configuration.parameters
    vocabulary_dir = cwd + "/../UNIVERSAL/vocabulary/DeEn_32000/"

    corpus = offline_corpus.offline(offline)
    manager = data_manager.DatasetManager(vocabulary_dir, corpus)
    raw_train = manager.get_raw_train_dataset()

    prepared = dataset_preprocessing.prepare_training_input(
        raw_train,
        params["batch_size"],
        params["max_sequence_length"],
        min_boundary=8,
        filter_min=1,
        # The upper filter bound reuses the configured maximum sequence length.
        filter_max=params["max_sequence_length"],
        tf_encode=EOS_entailment,
        shuffle=shuffle,
    )
    return prepared, manager


def optimizer():
    """Create the Adam optimizer used for training.

    No learning rate is passed here, so the Keras default applies at
    construction time.

    Returns:
        A ``tf.keras.optimizers.Adam`` instance with beta_1=0.9, beta_2=0.98
        and epsilon=1e-9.
    """
    # Disabled alternative kept for reference:
    # return optimizer_op.AdamWeightDecay(
    #     learning_rate=10e-4,
    #     # weight_decay_rate=0.001,
    #     exclude_from_weight_decay=["layer_norm", "bias"], )
    adam_settings = {"beta_1": 0.9, "beta_2": 0.98, "epsilon": 1e-9}
    return tf.keras.optimizers.Adam(**adam_settings)


def callbacks():
    """Assemble the training callbacks, including the learning-rate scheduler.

    The learning-rate function is ``LearningRateFn_WL``, parameterised with
    the configured ``num_units`` (as ``hidden_size``) and ``learning_warmup``
    (as ``warmup_steps``). It is wrapped in a ``LearningRateScheduler`` and
    handed to ``callback_training.get_callbacks`` together with ``cwd``.

    Returns:
        Whatever ``callback_training.get_callbacks`` returns for this
        directory and scheduler.
    """
    params = configuration.parameters
    lr_fn = learning_rate_op.LearningRateFn_WL(
        hidden_size=params["num_units"],
        warmup_steps=params["learning_warmup"],
    )
    # The second positional argument is 0; its meaning is defined by
    # LearningRateScheduler.
    lr_scheduler = learning_rate_op.LearningRateScheduler(lr_fn, 0)
    return callback_training.get_callbacks(cwd, lr_scheduler)


def trainer():
    """Instantiate the ``lt.UniversalTransformer`` model from the configuration.

    Every constructor argument is read from ``configuration.parameters``.
    All but one use a key equal to the keyword name; ``max_seq_len`` is
    taken from the ``"max_sequence_length"`` entry.

    Returns:
        The constructed ``lt.UniversalTransformer`` instance.
    """
    params = configuration.parameters

    # Keyword arguments whose name matches the configuration key exactly.
    same_named_keys = (
        "vocabulary_size",
        "embedding_size",
        "batch_size",
        "num_units",
        "num_heads",
        "num_encoder_steps",
        "num_decoder_steps",
        "dropout",
    )
    model_kwargs = {key: params[key] for key in same_named_keys}
    # The model's keyword differs from the configuration key for this one.
    model_kwargs["max_seq_len"] = params["max_sequence_length"]

    return lt.UniversalTransformer(**model_kwargs)
