import os
import copy
import pickle
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BartTokenizer

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from bart_generation import BartGenerationFlyModel
from persuasion_dataloader import DataLoaderHelper

# Relative path of the YAML training config that main() loads via FlyConfig.load().
config_path = "/".join(("configs", "persuasion_config", "large.yaml"))


def main():
    """Train ``BartGenerationFlyModel`` with the config found at ``config_path``.

    Steps:
      1. If ``RANK`` is set in the environment (i.e. the script was started by
         a distributed launcher), join an NCCL process group using the
         ``env://`` init method.
      2. Load the config with ``FlyConfig.load`` and seed the RNGs from
         ``config.training.random_seed``.
      3. Build the train/valid dataloaders, the model and the ``Trainer``.
      4. Inside the ``FlyLogger`` context, save the config to ``config.yaml``
         and run ``trainer.train``.

    If a process group is initialized when training finishes (or fails), it
    is destroyed so the communicator's resources are released explicitly.
    """
    # We recommend initializing the process group before everything else, so
    # that dataloaders, model and trainer are built with distributed state ready.
    if "RANK" in os.environ:
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://')

    try:
        config = FlyConfig.load(config_path)
        set_random_seed(config.training.random_seed)

        data_helper = DataLoaderHelper(config)
        train_dataloader = data_helper.train_loader_fn()
        valid_dataloader = data_helper.valid_loader_fn()

        model = BartGenerationFlyModel(config)
        trainer = Trainer(config.training, model)

        # NOTE(review): configure_metrics() runs after the Trainer is built;
        # confirm whether Trainer construction depends on this ordering.
        model.configure_metrics()

        with FlyLogger(config.flylogger):
            # Snapshot the config for reproducibility. The path is relative to
            # the current working directory inside the FlyLogger context
            # (presumably the run directory; confirm). Explicit UTF-8 keeps
            # the file independent of the platform's locale encoding.
            with open("config.yaml", "w", encoding="utf-8") as f:
                OmegaConf.save(config, f)

            trainer.train(config.training, train_dataloader, valid_dataloader)
    finally:
        # Tear down the process group (on success or error) instead of leaving
        # it to interpreter shutdown.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
