import os
import torch
import numpy as np
import scipy.special
import logging

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from personachat_dataloader import DataLoaderHelper


logger = logging.getLogger(__name__)


def main(
    config_path="configs/personachat_config/bart_base_64.yaml",
    checkpoint_path=(
        "../outputs/bart-base/personachat/64_1.0/Trainer1_Stage1/"
        "evaluation/model_weights/epoch_2_step_8529.pth"
    ),
):
    """Train a generation model, starting from previously saved weights.

    Loads the FlyConfig at ``config_path``, seeds the RNGs, builds the model
    selected by ``config.task.method``, builds the train/valid data loaders,
    restores weights from ``checkpoint_path`` and then runs the torchfly
    ``Trainer`` inside a ``FlyLogger`` context.

    Relative paths resolve against the working directory at call time. Both
    files are read before ``FlyLogger`` is entered.

    Args:
        config_path: Path to the FlyConfig YAML file. The default is the
            script's original hard-coded config.
        checkpoint_path: Path to a file produced by ``torch.save`` holding a
            dict with a ``"model_weights"`` state dict. The default is the
            script's original hard-coded checkpoint.

    Raises:
        ValueError: If ``config.task.method`` does not select a supported
            model family (currently only names containing "bart").
    """
    config = FlyConfig.load(config_path)
    # Seed first so that model construction and data-loader setup below run
    # under a fixed seed.
    set_random_seed(config.training.random_seed)

    if "bart" in config.task.method:
        # Imported lazily so only the selected model family is imported.
        from models.bart_generation import BartGenerationFlyModel

        model = BartGenerationFlyModel(config)
    else:
        # Without this branch `model` would be unbound and the script would
        # only fail later, after the data had already been loaded.
        raise ValueError(
            f"Unsupported task method {config.task.method!r}; "
            "expected a name containing 'bart'."
        )

    print("Loading Data")
    data_helper = DataLoaderHelper(config)
    train_dataloader = data_helper.train_loader_fn()
    valid_dataloader = data_helper.valid_loader_fn()

    print("Loading Model")

    model.configure_metrics()

    # map_location="cpu" lets a checkpoint saved from GPU load on any host;
    # load_state_dict then copies the tensors into the model's own parameters.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model_weights"]
    model.load_state_dict(state_dict)
    # Release the extra copy of the weights so it is not held for the whole run.
    del state_dict

    trainer = Trainer(config.training, model)

    with FlyLogger(config.flylogger):
        # Save the config into the current working directory. It is written
        # inside the FlyLogger context, which may have changed the cwd
        # (NOTE(review): confirm FlyLogger's directory behavior).
        with open("config.yaml", "w") as f:
            OmegaConf.save(config, f)

        logger.info(config)

        trainer.train(config.training, train_dataloader, valid_dataloader)


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()
