import argparse
import os

from setproctitle import setproctitle
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
import torch

from module.TrainingModule import PlModelModule
from module.ConfigModule import ManualArgs, config_loading, print_args
from module.ModelingModule import return_model
from module.DataModule import DataModule
from utils.training_utils import logger


def train(parser):
    """Parse the configuration, build model/data/trainer, and run training.

    Args:
        parser: An ``argparse.ArgumentParser``. It is extended with the
            project's options by ``ManualArgs().make_manual_config`` before
            being parsed, then the result is refined by ``config_loading``
            using the ``config_manual`` / ``config_trainer`` arguments.

    Side effects:
        Creates ``<save_filename>/checkpoints`` and ``<save_filename>/gen_files``,
        starts a Weights & Biases run (``wb_project`` / ``wb_name``), optionally
        loads weights from ``args.prev_model``, and runs ``trainer.fit``.
    """
    parser = ManualArgs().make_manual_config(save=None, parser=parser)
    args = parser.parse_args()
    args = config_loading(parser, config_manual=args.config_manual, config_trainer=args.config_trainer)

    # NOTE(review): these assignments unconditionally override whatever the
    # CLI/config supplied for the same Trainer options.
    args.gradient_clip_val = 1.0
    args.default_root_dir = 'logs'
    # args.gpus = 1

    print_args(args.__dict__, logger.info)

    # Output layout under save_filename.
    for sub_dir in ('checkpoints', 'gen_files'):
        os.makedirs(os.path.join(args.save_filename, sub_dir), exist_ok=True)

    core_model, tokenizer = return_model(args)

    dm = DataModule(tokenizer, args)
    model = PlModelModule(tokenizer, core_model, args)

    wb_logger = pl_loggers.WandbLogger(project=args.wb_project, name=args.wb_name)
    lr_logger = pl.callbacks.LearningRateMonitor()

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=wb_logger, callbacks=[lr_logger],
        accumulate_grad_batches=args.accumulate_grad
    )
    if args.prev_model is not None:
        logger.info('Prev model loading')
        # Load onto CPU so a checkpoint saved from GPU tensors can be restored on
        # any host; load_state_dict copies the values into the model's existing
        # parameters, so their device is unaffected.
        checkpoint = torch.load(args.prev_model, map_location='cpu')
        load_result = model.load_state_dict(checkpoint['model_state_dict'])
        logger.info(str(load_result))
        # Drop the checkpoint dict so it is not kept alive during training.
        del checkpoint

    trainer.fit(model, dm)


if __name__ == '__main__':
    # Script entry point: train() registers its own options on this empty parser.
    train(argparse.ArgumentParser())

