import os
import torch
import numpy as np
import scipy.special
import logging

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from models.gpt_generation import GPTGenerationFlyModel
from msc_lm_dataloader import DataLoaderHelper

# Module-level logger named after this module; main() uses it to log the
# configuration of the run.
logger = logging.getLogger(__name__)


def main():
    """Run one training job for the GPT generation model.

    Steps, in order:

    1. Load the configuration from ``configs/msc_config/gpt2_128.yaml`` and
       seed the random number generators with ``config.training.random_seed``.
    2. Construct the model, then the train/valid/test dataloaders.
    3. Configure the model's metrics and create the trainer.
    4. Inside the ``FlyLogger`` context: write the configuration to
       ``config.yaml``, log it, and start training.

    NOTE(review): ``config.yaml`` is opened with a relative path, so it lands
    in whatever the working directory is at that point -- presumably one set
    up by ``FlyLogger``; confirm against torchfly.
    """
    config = FlyConfig.load("configs/msc_config/gpt2_128.yaml")
    set_random_seed(config.training.random_seed)

    # The model is constructed right after seeding, before any data loading.
    model = GPTGenerationFlyModel(config)

    print("Loading Data")
    loader_factory = DataLoaderHelper(config)
    # Tuple elements are evaluated left to right, so the loaders are still
    # created in the order train -> valid -> test.
    dataloaders = (
        loader_factory.train_loader_fn(),
        loader_factory.valid_loader_fn(),
        loader_factory.test_loader_fn(),
    )

    print("Loading Model")
    model.configure_metrics()
    trainer = Trainer(config.training, model)

    with FlyLogger(config.flylogger):
        # Keep a copy of the configuration used for this run.
        with open("config.yaml", "w") as config_file:
            OmegaConf.save(config, config_file)

        logger.info(config)

        # Positional order matches the loaders tuple: train, valid, test.
        trainer.train(config.training, *dataloaders)


# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
