import argparse
from typing import List, Dict

import torch
import os.path as op

import yaml
from omegaconf import OmegaConf, DictConfig

# Module-level mapping, initially empty.
# NOTE(review): not referenced by any code visible in this file — confirm it
# is used elsewhere before relying on it (or remove it).
CONFIG_HIERARCHY = dict()

def _str2bool(value: str) -> bool:
    """Convert a command-line string into a real ``bool``.

    Accepts, case-insensitively, ``true/false``, ``1/0``, ``yes/no``, ``y/n``
    and ``t/f``.

    This replaces the former ``type=eval``, which had three problems:

    * it executed arbitrary Python expressions typed on the command line;
    * an input such as ``true`` raised ``NameError``, which argparse does not
      turn into a usage error, so the script crashed with a traceback;
    * ``1``/``0`` were accepted as ``int`` (``1 in [True, False]`` is true),
      after which identity checks like ``args.scst is False`` silently took
      the wrong branch.

    Raises:
        argparse.ArgumentTypeError: for any other input; argparse reports it
            as a normal "invalid value" usage error.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y", "t"):
        return True
    if lowered in ("false", "0", "no", "n", "f"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (True/False), got {value!r}")


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for all config-overridable options.

    Every ``--option`` uses ``default=argparse.SUPPRESS``. An option the user
    did not pass is therefore absent from the parsed namespace entirely, so it
    cannot overwrite the value coming from the YAML config when the two are
    merged.
    """
    parser = argparse.ArgumentParser()

    def add(name, arg_type, help_text):
        parser.add_argument(name, default=argparse.SUPPRESS, type=arg_type,
                            help=help_text)

    def add_flag(name, help_text):
        # Boolean options take an explicit value, e.g. ``--do_train True``.
        parser.add_argument(name, default=argparse.SUPPRESS, type=_str2bool,
                            choices=[True, False], help=help_text)

    parser.add_argument("--cfg", required=True,
                        default=argparse.SUPPRESS, type=str,
                        help="config file")
    add("--project", str, "project name for weight and bias")
    add("--name", str, "run name for weight and bias")
    add("--data_dir", str, "The input data dir with all required files.")
    add("--train_yaml", str, "yaml file for training.")
    add("--test_yaml", str, "yaml file for testing.")
    add("--val_yaml", str, "yaml file used for validation during training.")
    add("--model_name_or_path", str, "Path to pre-trained model or model type.")
    add("--output_dir", str,
        "The output directory to save checkpoint and test results.")
    add("--loss_type", str, "Loss function types: support kl, x2, sfmx")
    add("--config_name", str,
        "Pretrained config name or path if not the same as model_name.")
    add("--tokenizer_name", str,
        "Pretrained tokenizer name or path if not the same as model_name.")
    add("--max_seq_length", int,
        "The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, "
        "sequences shorter will be padded.")
    add("--max_seq_a_length", int, "The maximum sequence length for caption.")
    add_flag("--do_train", "Whether to run training.")
    add_flag("--do_test", "Whether to run inference.")
    add_flag("--do_eval", "Whether to run evaluation.")
    add_flag("--do_lower_case",
             "Set this flag if you are using an uncased model.")
    add("--mask_prob", float,
        "Probability to mask input sentence during training.")
    add("--max_masked_tokens", int,
        "The max number of masked tokens per sentence.")
    add_flag("--add_od_labels",
             "Whether to add object detection labels or not")
    add("--drop_out", float, "Drop out in BERT.")
    add("--max_img_seq_length", int,
        "The maximum total input image sequence length.")
    add("--img_feature_dim", int, "The Image Feature Dimension.")
    add("--img_feature_type", str, "Image feature type. (e.g., frcnn)")
    add_flag("--tie_weights",
             "Whether to tie decoding weights to that of encoding")
    add_flag("--freeze_embedding", "Whether to freeze word embeddings in Bert")
    add_flag("--freeze_backbone", "Whether to freeze Bert backbone")
    add("--label_smoothing", float, ".")
    add("--drop_worst_ratio", float, ".")
    add("--drop_worst_after", int, ".")
    add("--per_gpu_train_batch_size", int,
        "Batch size per GPU/CPU for training.")
    add("--per_gpu_eval_batch_size", int,
        "Batch size per GPU/CPU for evaluation.")
    add("--output_mode", str,
        "output mode, support classification or regression.")
    add("--num_labels", int,
        "num_labels is 2 for classification and 1 for regression.")

    # Optimizer Settings
    add("--gradient_accumulation_steps", int,
        "Number of updates steps to accumulate before backward.")
    add("--learning_rate", float, "The initial lr.")
    add("--weight_decay", float, "Weight decay.")
    add("--adam_epsilon", float, "Epsilon for Adam.")
    add("--max_grad_norm", float, "Max gradient norm.")
    add("--warmup_step_ratio", float, "Linear warmup.")
    add("--plateau_step_ratio", float,
        "Plateau step ratio for the plateau_linear scheduler.")
    add("--scheduler", str, "constant or linear or plateau_linear")
    add("--num_workers", int, "Workers in dataloader.")
    add("--num_train_epochs", int,
        "Total number of training epochs to perform.")
    add("--max_steps", int,
        "Total number of training steps. Override num_train_epochs.")
    add("--logging_steps", int, "Log every X steps.")
    add("--save_steps", int,
        "Save checkpoint every X steps. Will also perform evaluation.")
    add_flag("--evaluate_during_training",
             "Run evaluation during training at each save_steps.")
    add_flag("--no_cuda", "Avoid using CUDA.")
    add("--seed", int, "random seed for initialization.")

    # for self-critical sequence training
    add_flag("--scst", "Self-critical sequence training")
    add("--sc_train_sample_n", int,
        "number of sampled captions for sc training")
    add("--sc_baseline_type", str,
        "baseline type of REINFORCE algorithm (e.g., greedy)")
    add("--sc_beam_size", int, "beam size for scst training")
    add("--cider_cached_tokens", str,
        "path to cached cPickle file used to calculate CIDEr scores")

    # for generation
    add("--eval_model_dir", str, "Model directory for evaluation.")
    add("--max_gen_length", int, "max length of generated sentences")
    add_flag("--output_hidden_states", "Turn on for fast decoding")
    add_flag("--output_attentions",
             "Turn on for visualizing the attention map.")
    add("--num_return_sequences", int, "repeating times per image")
    add("--num_beams", int, "beam search width")
    add("--num_keep_best", int, "number of hypotheses to keep in beam search")
    add("--temperature", float, "temperature in softmax for sampling")
    add("--top_k", int, "filter distribution for sampling")
    add("--top_p", float, "filter distribution for sampling")
    # Penalties are real-valued (e.g. 1.2); ``int`` rejected such values.
    # Integer inputs still parse (as floats).
    add("--repetition_penalty", float,
        "repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)")
    add("--length_penalty", float, "beam search length penalty")

    # for Constrained Beam Search
    add_flag("--use_cbs", "Use constrained beam search for decoding")
    add("--min_constraints_to_satisfy", int,
        "minimum number of constraints to satisfy")

    # Add [prefix] Token into task.
    add_flag("--add_prefix",
             "add [prefix] token with [SEP] in training and inference.")
    add("--prefix_no_pos_emb", str, "can be 'no_pos' or 'avg_pos'.")
    add("--num_prefix", int,
        "the number of prefix token. it can be int number.")
    add_flag("--mask_inter_prefix", "disabling the attention across prefixs.")
    add_flag("--freeze_prefix", "Freezing the prefix vectors during training")
    add_flag("--mlp_for_prefix", "put linear layer for prefix")
    add("--prefix_drop_prob", float, "Prefix drop probability.")
    add("--prefix_shuffle_prob", float, "Prefix shuffle probability.")

    # For eval with ema model:
    add("--eval_ema_num", str, "Which EMA model to use for evaluation.")

    # for num tags to engage in calculation
    add("--num_tags", int, "number of tags to engage in calculation")
    add_flag("--force_seq_len", "force gen_caption with desired tag length")

    # Delete Inter-domain attention
    add_flag("--mask_c_t", "mask caption / tag attention")
    add_flag("--mask_c_i", "mask caption / image attention")
    add_flag("--mask_t_t", "mask tag / tag attention")
    add_flag("--mask_t_i", "mask tag / image attention")
    add_flag("--mask_i_t", "mask image / tag attention")
    add_flag("--mask_i_i", "mask image / image attention")

    # Activate Inter-domain attention
    add_flag("--un_mask_t_c", "unmask tag / caption attention")
    add_flag("--un_mask_i_c", "unmask image / caption attention")

    # to see tag dependency
    add_flag("--tag_sep_token", "Turn on/off tag sep token")
    add("--tag_entire_set", str,
        "Get tags from all the coco train and validation")

    parser.add_argument("script_args", nargs=argparse.REMAINDER,
                        help="Override config by CLI")
    return parser


def get_arguments():
    """Parse the command line and merge it over the YAML config named by ``--cfg``.

    Options given on the command line take precedence over the YAML file.
    Options that were not given keep their YAML value, because unpassed
    options are absent from the parsed namespace (see ``_build_parser``).
    Interpolations in the merged config are resolved before conversion.

    Returns:
        argparse.Namespace: one attribute per top-level key of the merged config.
    """
    args = _build_parser().parse_args()
    default_config = OmegaConf.load(args.cfg)
    # NOTE(review): vars(args) also contains ``cfg`` and the positional
    # ``script_args`` list, so both become keys of the merged config. The
    # ``script_args`` tokens are stored verbatim; they are not applied as
    # overrides, despite the parser's "Override config by CLI" help text.
    config = override_config_by_cli(default_config, vars(args))
    config = OmegaConf.to_container(config, resolve=True)
    return argparse.Namespace(**config)

def override_config_by_cli(base_cfg: DictConfig, cli_args: Dict) -> DictConfig:
    """Merge command-line values on top of a base OmegaConf config.

    Args:
        base_cfg: config loaded from the YAML file.
        cli_args: plain mapping of CLI option names to parsed values.

    Returns:
        The config returned by ``OmegaConf.merge``. Keys present in
        ``cli_args`` override those in ``base_cfg``.
    """
    return OmegaConf.merge(base_cfg, OmegaConf.create(cli_args))

def restore_training_settings(args, logger):
    """Set sequence lengths and restore selected arguments from a checkpoint.

    The run mode selects the checkpoint directory, the caption length and the
    initial list of arguments to restore:

    * CE training (``do_train`` and not ``scst``): ``model_name_or_path``,
      ``max_seq_a_length``, nothing listed.
    * SCST training: ``model_name_or_path``, ``max_gen_length``, and
      ``do_lower_case`` / ``add_od_labels`` / ``max_img_seq_length``.
    * Inference (``do_test`` or ``do_eval``): ``eval_model_dir``,
      ``max_gen_length``, and the same three arguments.

    With ``add_prefix``, ``max_seq_length`` becomes caption length +
    ``num_prefix``. Otherwise, unless ``force_seq_len`` is set, it is
    recomputed from the checkpoint's ``training_args.bin`` (when that file
    carries ``max_seq_a_length``), and ``max_seq_a_length`` is added to the
    restore list. Each listed argument found in the checkpoint args replaces
    the current value when the two differ.

    Args:
        args: namespace of current run arguments; mutated in place.
        logger: logger used for info/warning messages.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: when not training and neither ``do_test`` nor
            ``do_eval`` is set.
    """
    restorable = ['do_lower_case', 'add_od_labels', 'max_img_seq_length']

    # Train(CE or SCST) / Inference Setup.
    # Truthiness (not ``is True`` / ``is False``) so that values like 0/1 or
    # None coming from YAML/CLI select the intended branch.
    if args.do_train:
        checkpoint = args.model_name_or_path
        if not args.scst:
            # During Training CE, Arguments are NOT Restored.
            override_params = []
            caption_length = args.max_seq_a_length
            logger.warning("Train With CE: Does Not Restore Any of Previous Arguments.")
        else:
            # During Train SCST, Arguments Restored
            caption_length = args.max_gen_length
            override_params = list(restorable)
            logger.warning("Train With SCST: Restore Previous Arguments.")
    else:
        # During Inference, Arguments are Restored.
        assert args.do_test or args.do_eval, \
            "One of do_train, do_test or do_eval must be set."
        checkpoint = args.eval_model_dir
        caption_length = args.max_gen_length
        override_params = list(restorable)
        logger.warning("Inference: Restore Previous Arguments.")

    # Load saved training settings. A checkpoint path containing 'bert' is
    # treated as a stock pre-trained model with no training_args.bin.
    # NOTE(review): this is a substring test, so paths such as
    # '/home/robert/ckpt' also skip loading — confirm this is acceptable.
    if 'bert' not in checkpoint:
        # torch.load unpickles the file; only load checkpoints you trust.
        train_args = torch.load(op.join(checkpoint, 'training_args.bin'))
        logger.info(f"Loading Arguments from: {checkpoint}")
    else:
        # Empty namespace: every hasattr() probe below is False.
        train_args = argparse.Namespace()

    if args.add_prefix:
        # prefix Sequence Length Setup
        max_od_labels_len = args.num_prefix
        args.max_seq_length = caption_length + max_od_labels_len
        args.prefix_type_now = "prefix_Emb"

        logger.info("With [prefix] token")
        logger.warning(f'Due to [prefix] token, max_seq_length ({args.max_seq_length}) = '
                       f'caption length ({caption_length}) + '
                       f'[prefix] Token ({args.num_prefix})')

        if hasattr(train_args, "prefix_type_now"):
            # Restore Previous prefix Type
            args.prefix_type_last = train_args.prefix_type_now
            logger.warning("prefix from checkpoint.")
        else:
            # Initialize prefix
            args.prefix_type_last = None
            logger.warning("First initialize prefix.")

    elif not args.force_seq_len:
        # TAG model: restore sequence-length properties from the checkpoint.
        if hasattr(train_args, 'max_seq_a_length'):
            if getattr(train_args, 'scst', False):
                # Previous training was SCST
                max_od_labels_len = train_args.max_seq_length - train_args.max_gen_length
            else:
                # Previous training was CE
                max_od_labels_len = train_args.max_seq_length - train_args.max_seq_a_length

            max_seq_length = args.max_gen_length + max_od_labels_len
            args.max_seq_length = max_seq_length
            logger.warning('Override max_seq_length to {} = max_gen_length:{} + od_labels_len:{}'.format(
                max_seq_length, args.max_gen_length, max_od_labels_len))

        # NOTE(review): in CE training this branch still recomputes
        # max_seq_length from max_gen_length and restores max_seq_a_length,
        # which contradicts the "Does Not Restore" warning above — confirm.
        override_params = override_params + ['max_seq_a_length']

    else:
        # TAG model with force_seq_len: keep the current sequence lengths.
        max_od_labels_len = args.max_seq_length - args.max_gen_length
        logger.warning("Force Max Gen Length ")
        logger.info('Max_seq_length {} = max_gen_length:{} + od_labels_len:{}'.format(
            args.max_seq_length, args.max_gen_length, max_od_labels_len))

    # Update current arguments with restored arguments
    for param in override_params:
        if hasattr(train_args, param):
            train_v = getattr(train_args, param)
            test_v = getattr(args, param)
            if train_v != test_v:
                logger.warning(f'Override {param} with train args: {test_v} -> {train_v}')
                setattr(args, param, train_v)
    return args

def save_arg(args=None, output_dir="."):
    """Write a run configuration to ``<output_dir>/config.yaml``.

    Args:
        args: mapping, or ``argparse.Namespace``, to serialise. When ``None``,
            the configuration is built from the command line via
            ``get_arguments()``.
        output_dir: directory to write into; created if it does not exist.
            An empty string means the current directory.

    Returns:
        str: path of the written YAML file.
    """
    import os

    if args is None:
        args = get_arguments()
    # yaml.dump would emit a ``!!python/object`` tag for a Namespace; dump its
    # attribute dict instead so the file is a plain config.
    if isinstance(args, argparse.Namespace):
        args = vars(args)

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # op.join avoids '//' and, for output_dir='', writing to '/config.yaml'.
    new_cfg_file = op.join(output_dir, "config.yaml")

    with open(new_cfg_file, 'w', encoding='utf-8') as outfile:
        yaml.dump(args, outfile, sort_keys=False)

    return new_cfg_file

if __name__ == "__main__":
    # Script entry point: build the config from the CLI and write it as
    # config.yaml under ./config/.
    save_arg(output_dir="./config/")