# -*- coding: utf-8 -*-
import tensorflow as tf
import os

# Working directory captured once, at import time; get_callbacks builds its
# TensorBoard log and checkpoint paths relative to this value, so a later
# os.chdir() does not affect where outputs are written.
cwd: str = os.getcwd()


def get_callbacks(model_path, LRschedule=None, save_freq=5000):
    """Build the list of Keras callbacks used for training.

    Args:
        model_path: Currently unused; checkpoints are always written under
            ``<cwd>/model_checkpoint`` regardless of this value.
            NOTE(review): the name suggests it was meant to be the checkpoint
            location -- confirm with callers before wiring it in, since doing
            so would move where checkpoints are saved.
        LRschedule: Optional extra callback (presumably a learning-rate
            scheduler); appended after the built-in callbacks when not None.
        save_freq: Forwarded to ``ModelCheckpoint``. Per the Keras docs, an
            integer saves every that many batches and ``"epoch"`` saves once
            per epoch.

    Returns:
        A list containing, in order: a TensorBoard logger writing to
        ``<cwd>/model_summary``, a weights-only ModelCheckpoint, a
        TerminateOnNaN guard, and ``LRschedule`` if it was provided.
    """
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(cwd, "model_summary"),
        write_images=True,
        histogram_freq=1000,  # epochs between weight-histogram dumps
        embeddings_freq=1000,  # epochs between embedding visualisations
        update_freq=500,  # batches between scalar summary writes
    )

    # Keras substitutes {epoch} and {loss} from the training logs at save time.
    checkpoint_path = os.path.join(
        cwd, "model_checkpoint", "model.{epoch:02d}-{loss:.2f}.ckpt"
    )
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor="loss",
        save_weights_only=True,
        # Honour the caller's argument instead of a hard-coded 5000.
        save_freq=save_freq,
        verbose=1,
    )

    # Abort training as soon as the loss becomes NaN.
    nan_checker = tf.keras.callbacks.TerminateOnNaN()

    callbacks = [tensorboard, checkpoint, nan_checker]
    if LRschedule is not None:
        callbacks.append(LRschedule)
    return callbacks
