import os
import torch

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from memformers.models.bart_base.flymodel import BartBaseFlyModel
from dataloader import DataLoaderHelper

def main(config_path: str = "config/bart_base.yaml") -> None:
    """Train a ``BartBaseFlyModel`` using settings from a YAML config file.

    Loads the config, seeds the RNGs, builds the training dataloader, model
    and trainer, then runs training inside a ``FlyLogger`` context after
    writing a copy of the config to ``config.yaml``.

    Args:
        config_path: Path handed to ``FlyConfig.load``. Defaults to
            ``"config/bart_base.yaml"`` (the value previously hard-coded),
            so calling ``main()`` with no arguments behaves as before.
    """
    config = FlyConfig.load(config_path)
    # Seed before building the dataloader and model so anything below that
    # draws random numbers sees a fixed seed (assumes set_random_seed seeds
    # the RNGs those components use — confirm).
    set_random_seed(config.training.random_seed)

    data_helper = DataLoaderHelper(
        config.data_processing,
        training_batch_size=config.training.batch_size,
    )
    train_dataloader = data_helper.train_dataloader_fn()

    model = BartBaseFlyModel(config)
    model.configure_metrics()

    trainer = Trainer(config.training, model)

    with FlyLogger(config.flylogger):
        # Written inside the FlyLogger context on purpose: presumably FlyLogger
        # switches the working directory to the run's output dir, so this
        # relative path lands next to the logs — NOTE(review): confirm.
        # Explicit UTF-8 keeps the file identical regardless of system locale.
        with open("config.yaml", "w", encoding="utf-8") as f:
            OmegaConf.save(config, f)

        trainer.train(config.training, train_dataloader)


if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()
