import os
import torch
import numpy as np
import scipy.special
import logging

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from gpt_generation import GPTGenerationFlyModel
from memformers.recurrent_training.recurrent_trainer import RecurrentTrainer
from personachat_dataloader import DataLoaderHelper

logger = logging.getLogger(__name__)

def main(config_path: str = "configs/personachat_config/base_10%.yaml") -> None:
    """Train a GPT generation model on PersonaChat with torchfly.

    Loads the experiment configuration, sets the random seed from
    ``config.training.random_seed``, builds the train and validation
    dataloaders, constructs the model and trainer, and then, inside a
    ``FlyLogger`` context, writes the configuration to ``config.yaml`` and
    runs training.

    Args:
        config_path: Path to the FlyConfig YAML file. Defaults to the
            PersonaChat ``base_10%`` config that was previously hard-coded,
            so calling ``main()`` with no arguments behaves as before.
    """
    config = FlyConfig.load(config_path)
    set_random_seed(config.training.random_seed)

    data_helper = DataLoaderHelper(config)
    train_dataloader = data_helper.train_loader_fn()
    valid_dataloader = data_helper.valid_loader_fn()

    model = GPTGenerationFlyModel(config)
    trainer = Trainer(config.training, model)
    # Called after the Trainer is built, matching the original call order.
    model.configure_metrics()

    # The FlyLogger object itself is never used; the context is entered only
    # for whatever run-level setup/teardown it performs.
    with FlyLogger(config.flylogger):
        # Snapshot the configuration for this run. The path is relative to the
        # current working directory at this point (FlyLogger may change it —
        # NOTE(review): confirm). An explicit encoding keeps the file contents
        # independent of the system locale.
        with open("config.yaml", "w", encoding="utf-8") as f:
            OmegaConf.save(config, f)

        trainer.train(config.training, train_dataloader, valid_dataloader)



# Launch training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
