import yaml
import json
import argparse
from os.path import join as pjoin
from types import SimpleNamespace
import logging
from pprint import pprint

import pytorch_lightning as pl

class ManualArgs:
    """Builders for the hand-maintained ("manual") command-line options.

    Each ``add_*`` static method takes an ``argparse.ArgumentParser`` and
    returns a NEW parser that inherits all of its options (via ``parents=``)
    plus one more group of options. ``make_manual_config`` chains all groups
    and can dump the parsed values to a YAML file.
    """

    @staticmethod
    def _extend(parent_parser):
        # New parser that inherits every option of ``parent_parser``.
        # add_help=False because the parent already carries -h/--help (argparse
        # raises on a duplicated option).
        return argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    @staticmethod
    def add_file_loading(parent_parser):
        """Add data/config file paths, checkpoint and wandb logging options."""
        parser = ManualArgs._extend(parent_parser)
        parser.add_argument('--save_filename', type=str, default=None, help='filename')
        parser.add_argument('--train_file', type=str, default="QAgen/train.csv")
        parser.add_argument('--valid_file', type=str, default="QAgen/validation.csv")
        parser.add_argument('--config_manual', type=str, default='config/config_manual.yaml')
        parser.add_argument('--config_trainer', type=str, default='config/config_trainer.yaml')
        parser.add_argument('--prev_model', type=str, default=None, help='')

        parser.add_argument('--wb_project', type=str, default=None, help='wandb project')
        parser.add_argument('--wb_name', type=str, default=None, help='wandb name')
        parser.add_argument('--proctitle', type=str, default=None)
        return parser

    @staticmethod
    def add_specific(parent_parser):
        """Add model-selection options."""
        parser = ManualArgs._extend(parent_parser)
        parser.add_argument('--model_type', type=str, default='bart')
        parser.add_argument('--training_type', type=int, default=1, help='')
        return parser

    @staticmethod
    def add_basic(parent_parser):
        """Add general training hyperparameters."""
        parser = ManualArgs._extend(parent_parser)

        # Training hyperparameters
        parser.add_argument('--seed', type=int, default=None, help='seed')
        parser.add_argument('--ckpt_save_num', type=int, default=1, help='ckpt_save_num')
        parser.add_argument('--num_workers', type=int, default=5, help='num of worker for dataloader')
        parser.add_argument('--max_steps', type=int, default=-1)
        parser.add_argument('--max_epochs', type=int, default=50)
        parser.add_argument('--beam_size', type=int, default=1, help='beam_size')
        parser.add_argument('--batch_size', type=int, default=8)
        parser.add_argument('--max_len', type=int, default=512, help='max seq len')
        parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'FusedAdam', 'DeepSpeedCPUAdam'])
        parser.add_argument('--gpus', type=int, default=1, help='gpu 개수')
        return parser

    @staticmethod
    def add_test(parent_parser):
        """Add evaluation-time paths (checkpoint and input file)."""
        parser = ManualArgs._extend(parent_parser)
        parser.add_argument('--ckptpath', type=str, default='', help='')
        parser.add_argument('--filepath', type=str, default='', help='')
        return parser

    @staticmethod
    def add_training(parent_parser):
        """Add optimizer schedule options."""
        parser = ManualArgs._extend(parent_parser)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--warmup_ratio', type=float, default=0.1)
        parser.add_argument('--accumulate_grad', type=int, default=4, help='gradient accumulation')
        return parser

    def make_manual_config(self, save='config/config_manual.yaml', parser=None):
        """Build the combined manual-options parser and optionally dump it to YAML.

        Args:
            save: YAML file to write. When not None, the command line
                (``sys.argv``) is parsed and the resulting values are written
                there. When None, nothing is parsed or written.
            parser: Optional parser whose options are inherited as well.

        Returns:
            The combined ``argparse.ArgumentParser``.
        """
        if parser is not None:
            parser = argparse.ArgumentParser(parents=[parser], add_help=False)
        else:
            parser = argparse.ArgumentParser(description='manual args')
        for add_group in (self.add_file_loading, self.add_specific,
                          self.add_basic, self.add_test, self.add_training):
            parser = add_group(parser)

        if save is not None:
            # The same command line may also carry options that only the
            # Trainer parser defines; ignore them instead of exiting with
            # "unrecognized arguments".
            args, unknown = parser.parse_known_args()
            if unknown:
                logging.warning('Ignoring arguments not defined for the manual config: %s', unknown)
            with open(save, 'w', encoding='utf-8') as f:
                yaml.dump(vars(args), f)
        return parser


def make_trainer_args(save='config/config_trainer.yaml'):
    """Parse PyTorch Lightning ``Trainer`` options from the command line and dump them to YAML.

    Args:
        save: YAML file to write. Its parent directory must already exist
            (the file is opened directly for writing).
    """
    # NOTE(review): Trainer.add_argparse_args is a Lightning 1.x API — confirm
    # the pinned pytorch_lightning version still provides it.
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # When run from this module's __main__, sys.argv also carries
    # --config_manual/--config_trainer (and possibly manual options) that the
    # Trainer parser does not define; parse_args() would exit on them.
    args, unknown = parser.parse_known_args()
    if unknown:
        logging.warning('Ignoring arguments not defined for the Trainer config: %s', unknown)
    with open(save, 'w', encoding='utf-8') as f:
        yaml.dump(vars(args), f)


def object_hook(d):
    """JSON ``object_hook`` that exposes each decoded object as a ``SimpleNamespace``.

    Usage: ``json.loads(text, object_hook=object_hook)``; nested objects are
    converted too, because the hook runs for every JSON object.
    """
    namespace = SimpleNamespace(**d)
    return namespace


def _construct_mapping(loader, node):
    """PyYAML constructor that turns a mapping node into a ``SimpleNamespace``.

    ``flatten_mapping`` runs first so YAML merge keys (``<<: *anchor``) are
    folded into the node before its key/value pairs are built. A repeated
    key keeps its last value; keys must be strings (``SimpleNamespace``
    raises TypeError otherwise).
    """
    loader.flatten_mapping(node)
    fields = {key: value for key, value in loader.construct_pairs(node)}
    return SimpleNamespace(**fields)


class Loader(yaml.Loader):
    """``yaml.Loader`` subclass whose mappings load as ``SimpleNamespace`` objects.

    NOTE(review): ``yaml.Loader`` is PyYAML's full loader, which can build
    arbitrary Python objects from ``!!python/...`` tags — only use it on
    trusted config files. Basing it on ``yaml.SafeLoader`` would be safer,
    but that would reject any ``!!python/*`` tags ``yaml.dump`` may have
    written into the generated configs; confirm before changing.
    """
    pass


# Register the SimpleNamespace constructor for the default mapping tag on the
# subclass (presumably so the global yaml.Loader keeps producing plain dicts).
Loader.add_constructor(
    yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _construct_mapping
)


def load_yaml(filename):
    """Load a YAML file, turning every mapping (at any depth) into a ``SimpleNamespace``.

    Keys therefore become attributes (``cfg.batch_size`` rather than
    ``cfg['batch_size']``).

    Args:
        filename: Path of the YAML file, read as UTF-8.

    Returns:
        The parsed document.
    """
    # NOTE(review): the module's Loader derives from the unsafe yaml.Loader;
    # only load trusted files.
    with open(filename, 'r', encoding='utf-8') as f:
        return yaml.load(f, Loader=Loader)


def pick_newargs(parser, args=None):
    """Return only the options whose parsed value differs from their declared default.

    This lets command-line options override values loaded from YAML without
    clobbering them with untouched defaults.

    Args:
        parser: ``argparse.ArgumentParser`` to parse with.
        args: Argument list to parse; None (the default) parses ``sys.argv``.

    Returns:
        ``SimpleNamespace`` holding just the non-default values, keyed by each
        option's ``dest`` (e.g. ``--num-workers`` -> ``num_workers``).

    Note:
        An option passed explicitly with a value equal to its default cannot
        be told apart from an omitted one, so it is not returned.
    """
    parsed = vars(parser.parse_args(args))
    newargs = SimpleNamespace()
    for action in parser._actions:
        # Actions like -h/--help (default=SUPPRESS) never reach the parsed
        # namespace; skip anything without a parsed value.
        if action.dest not in parsed:
            continue
        value = parsed[action.dest]
        if value != action.default:
            setattr(newargs, action.dest, value)
    return newargs


def merge_args(pre_args: SimpleNamespace, post_args: SimpleNamespace):
    """Overlay ``post_args`` onto ``pre_args`` in place.

    Every attribute of ``post_args`` is copied onto ``pre_args``; same-named
    attributes are overwritten and attributes only on ``pre_args`` are kept.

    Returns:
        ``pre_args`` itself (mutated), so calls can be chained.
    """
    for name, value in post_args.__dict__.items():
        setattr(pre_args, name, value)
    return pre_args


def config_loading(parser, config_manual, config_trainer):
    """Build the effective run configuration from three layers.

    Precedence, lowest to highest:
      1. the Trainer YAML file (``config_trainer``),
      2. the manual YAML file (``config_manual``),
      3. command-line options whose value differs from ``parser``'s defaults.

    Returns:
        The object loaded from ``config_trainer``, updated in place with the
        higher-precedence layers.
    """
    merged = load_yaml(config_trainer)
    manual_layer = load_yaml(config_manual)
    merged = merge_args(pre_args=merged, post_args=manual_layer)
    cli_layer = pick_newargs(parser)
    return merge_args(pre_args=merged, post_args=cli_layer)


def args_to_config(config, args):
    """Report attributes of ``args`` that ``config`` does not define.

    Prints ``':)'`` once for every attribute name of ``args`` missing on
    ``config``. Nothing is copied or modified.

    NOTE(review): this looks like an unfinished stub (the name suggests it
    should copy ``args`` into ``config``) — confirm the intended behavior.

    Returns:
        Always 0.
    """
    for key in args.__dict__:
        # hasattr only swallows AttributeError, unlike the former bare
        # ``except:`` that also hid KeyboardInterrupt and unrelated errors.
        if not hasattr(config, key):
            print(':)')
    return 0


def print_args(args: 'dict | argparse.Namespace | SimpleNamespace', print_ft='logging'):
    """Emit one ``"key: value"`` line per entry of ``args``.

    Args:
        args: A mapping, or a namespace (``argparse.Namespace`` /
            ``types.SimpleNamespace``, e.g. the object returned by
            ``config_loading``); namespaces are read via ``vars()``.
        print_ft: Output callable, or one of ``'logging'`` (``logging.info``),
            ``'print'`` or ``'pprint'``.

    Raises:
        KeyError: If ``print_ft`` is a string other than the names above.
    """
    ft = {'logging': logging.info,
          'print': print,
          'pprint': pprint}
    if isinstance(print_ft, str):
        print_ft = ft[print_ft]
    # Namespaces are neither iterable nor subscriptable; use their attribute dict.
    if isinstance(args, (argparse.Namespace, SimpleNamespace)):
        args = vars(args)

    for key in args:
        print_ft(str(key) + ': ' + str(args[key]))


if __name__ == '__main__':
    # Regenerate both YAML config files from the defaults plus any options
    # given on the command line. parse_known_args lets Trainer/manual options
    # pass through to the generators below instead of being rejected here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_manual', type=str, default='../config/config_manual.yaml')
    parser.add_argument('--config_trainer', type=str, default='../config/config_trainer.yaml')
    args, _ = parser.parse_known_args()
    make_trainer_args(save=args.config_trainer)
    ManualArgs().make_manual_config(save=args.config_manual)
