from model.gpt2_unlearn import GPT2
from model.gpt2_dp import GPT2_DP
from model.gpt2_valid import GPT2_Valid
from pytorch_lightning.loggers import WandbLogger
from argparse import ArgumentParser
from utils import MetricTracker
import pytorch_lightning as pl
import logging
import argparse
import yaml

def main():
    """Parse ``--config``, build a Lightning trainer, then validate or train.

    Config keys read here: ``common`` (``seed``, ``check_validation_only``,
    ``do_init_eval``), ``wandb`` (``wandb_project``, ``wandb_run_name``; a
    falsy section disables the W&B logger), ``trainer`` (Trainer options)
    and ``model`` (``privacy_type``).

    In validation-only mode, a ``GPT2_DP`` model is validated when
    ``'dp' in model.privacy_type``, otherwise a ``GPT2_Valid`` model.
    Otherwise a ``GPT2`` model is optionally validated first, then fit.
    """
    # Parsing Arguments
    parser = ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    arg_ = parser.parse_args()
    if arg_.config is None:
        # parser.error prints usage and exits (status 2): the standard way to
        # report a missing CLI argument, instead of a misleading NameError.
        parser.error("Include a config file in the argument please.")

    # Getting configurations. `with` guarantees the file is closed.
    # NOTE(review): full_load can construct some Python objects; the config
    # is assumed to be a trusted local file (safe_load would suffice for
    # plain mappings/scalars).
    with open(arg_.config, 'r', encoding='utf-8') as config_file:
        config = yaml.full_load(config_file)

    # set random seed
    pl.seed_everything(config['common']['seed'], workers=True)

    # Set console logger (attached to the root logger)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '[%(levelname)s] %(asctime)s (%(filename)s:%(lineno)d) : %(message)s'
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # Set wandb logger; a falsy `wandb` section disables it.
    wandb_param = config['wandb'] or {}
    if wandb_param:
        wandb_logger = WandbLogger(
            project=wandb_param['wandb_project'],
            name=wandb_param['wandb_run_name'],
            entity='fengxiaohualala'
        )
    else:
        wandb_logger = None

    # set callbacks. `.get` keeps the wandb-disabled path from crashing here.
    # NOTE(review): confirm MetricTracker tolerates None project/run names.
    callbacks = [
        MetricTracker(
            run_project=wandb_param.get('wandb_project'),
            run_name=wandb_param.get('wandb_run_name'),
        )
    ]

    # Setting for pytorch lightning trainer
    trainer_param = config['trainer']
    train_params = dict(
        accumulate_grad_batches=trainer_param['gradient_accumulation_steps'],
        accelerator='gpu',
        devices=trainer_param['ngpu'],
        max_epochs=int(trainer_param['num_train_epochs']),
        precision=16 if trainer_param['fp16'] else 32,
        check_val_every_n_epoch=trainer_param['check_val_every_n_epoch'],
        enable_checkpointing=False,
        callbacks=callbacks,
        logger=wandb_logger,
        strategy=trainer_param['strategy'],
        # NOTE(review): in Lightning an *int* val_check_interval counts
        # training batches (1 => validate after every batch), whereas 1.0
        # means once per epoch -- confirm 1 is intended.
        val_check_interval=1,
        num_sanity_val_steps=trainer_param['num_sanity_val_steps'],
        limit_val_batches=trainer_param['limit_val_batches'],
        log_every_n_steps=1,
    )

    # init trainer (shared by both modes)
    trainer = pl.Trainer(**train_params)

    # start training
    if config['common']['check_validation_only']:
        # Membership test: if privacy_type is a string, any value
        # containing 'dp' selects the DP model.
        if 'dp' in config['model']['privacy_type']:
            model = GPT2_DP(config)
        else:
            model = GPT2_Valid(config)
        trainer.validate(model)
    else:
        model = GPT2(config)
        if config['common']['do_init_eval']:
            trainer.validate(model)
        trainer.fit(model)


if __name__ == '__main__':
    main()
