import logging
import json

logger = logging.getLogger(__name__)


def eval_str_list(x, type=float):
    """Coerce ``x`` into a list of ``type`` values.

    In this file it is used as an argparse ``type=`` converter (``--adam_betas``),
    so ``x`` is usually the raw command-line text such as ``'0.9,0.999'``.

    Args:
        x: ``None``, a string containing a Python expression, an iterable,
            or a single scalar value.
        type: Callable applied to every element (default ``float``).

    Returns:
        ``None`` when ``x`` is ``None``; otherwise a list with ``type``
        applied to each element. A non-iterable value is wrapped into a
        one-element list.

    Raises:
        Whatever ``type`` raises for an element it cannot convert, and
        whatever ``eval`` raises for an unparsable string.
    """
    if x is None:
        return None
    if isinstance(x, str):
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line. ast.literal_eval would accept the same numeric inputs
        # (e.g. '0.9,0.999') without that risk; kept as eval here so the set
        # of accepted inputs does not change.
        x = eval(x)
    try:
        items = iter(x)
    except TypeError:
        # Not iterable: treat x as a single scalar value.
        return [type(x)]
    # Convert outside the try so a TypeError raised by `type` on one bad
    # element surfaces as-is instead of being retried on the whole container.
    return [type(item) for item in items]


def add_optimzation_args(parser):
    """Register optimizer, mixed-precision and fine-tuning schedule options.

    Args:
        parser: An ``argparse.ArgumentParser`` (or anything exposing
            ``add_argument``) that receives the options below.
    """
    # args for AdamW
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    # Parsed by eval_str_list, so '0.9,0.999' becomes [0.9, 0.999].
    parser.add_argument('--adam_betas', '--adam_beta', default='0.9,0.999', type=eval_str_list, metavar='B',
                        help='betas for Adam optimizer')
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--disable_bias_correct", action='store_true',
                        help="Disable the bias correction items. ")
    # A value < 1.0 enables layer-wise lr decay in get_optimizer_grouped_parameters.
    parser.add_argument("--layer_decay", default=1.0, type=float,
                        help="Layer decay rate for the layer-wise learning rate. ")

    # For FP16
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    # The trailing space before the implicit concatenation keeps the two
    # sentences separated in --help output.
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--fp16_init_loss_scale', type=float, default=128.0,
                        help="For fp16: initial value for loss scale.")

    # Finetuning settings:
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--num_training_steps", default=-1, type=int,
                        help="set total number of training steps to perform")
    parser.add_argument("--num_training_epochs", default=10, type=int,
                        help="set total number of training epochs to perform (--num_training_steps has higher priority)")
    # NOTE(review): defaults to None; presumably --warmup_ratio is used in that
    # case — confirm against the training script.
    parser.add_argument("--num_warmup_steps", default=None, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--warmup_ratio", default=0.1, type=float,
                        help="Linear warmup over warmup_ratio.")


def get_optimizer_grouped_parameters(
        model, weight_decay, learning_rate, layer_decay, n_layers):
    """Build optimizer parameter groups with per-group weight decay and optional layer-wise lr decay.

    Parameters whose name contains ``'bias'`` or ``'LayerNorm.weight'`` get
    weight decay 0.0; all others get ``weight_decay``. When
    ``layer_decay < 1.0``, each parameter is also assigned a depth:

    * 0 for ``bert.embedding*`` and ``bert.rel_pos_bias.weight``;
    * ``i + 1`` for ``bert.encoder.layer.<i>.*``;
    * ``n_layers + 2`` for everything else.

    The group lr is then ``learning_rate * layer_decay ** (n_layers + 2 - depth)``.

    Args:
        model: Object exposing ``named_parameters()`` that yields ``(name, param)`` pairs.
        weight_decay: Weight decay for parameters not matched by the no-decay list.
        learning_rate: Base learning rate. Only used when ``layer_decay < 1.0``.
        layer_decay: Per-depth lr multiplier. ``>= 1.0`` disables layer-wise lr
            and groups parameters by weight decay only.
        n_layers: Expected number of ``bert.encoder.layer.*`` blocks.

    Returns:
        A list of group dicts with ``"params"`` and ``"weight_decay"`` keys,
        plus ``"lr"`` when layer decay is enabled.

    Raises:
        ValueError: If the highest encoder layer found is not ``n_layers``.
    """
    groups = {}
    groups_keys = {}
    num_max_layer = 0
    no_decay = ['bias', 'LayerNorm.weight']
    for para_name, para_var in model.named_parameters():
        if any(nd in para_name for nd in no_decay):
            weight_decay_in_this_group = 0.0
        else:
            weight_decay_in_this_group = weight_decay

        if para_name.startswith('bert.embedding') or para_name == 'bert.rel_pos_bias.weight':
            depth = 0
        elif para_name.startswith('bert.encoder.layer'):
            # 'bert.encoder.layer.<i>.…' -> component 3 is the layer index.
            depth = int(para_name.split('.')[3]) + 1
            num_max_layer = max(num_max_layer, depth)
        else:
            # NOTE(review): encoder layers use depths 1..n_layers, so depth
            # n_layers + 1 is never assigned. The top encoder layer is scaled
            # by layer_decay ** 2 relative to the head. Confirm this is intended.
            depth = n_layers + 2

        if layer_decay < 1.0:
            group_name = "layer{}_decay{}".format(depth, weight_decay_in_this_group)
        else:
            group_name = "weight_decay{}".format(weight_decay_in_this_group)

        if group_name not in groups:
            group = {
                "params": [],
                "weight_decay": weight_decay_in_this_group,
            }
            if layer_decay < 1.0:
                group["lr"] = learning_rate * (layer_decay ** (n_layers + 2 - depth))
            groups[group_name] = group
            groups_keys[group_name] = []
        groups[group_name]["params"].append(para_var)
        groups_keys[group_name].append(para_name)

    # Explicit check (not `assert`) so it still runs under `python -O`.
    if num_max_layer != n_layers:
        raise ValueError(
            "expected {} encoder layers (bert.encoder.layer.*), found {}".format(
                n_layers, num_max_layer))

    logger.info("Optimizer groups: = %s", json.dumps(groups_keys, indent=2))

    return list(groups.values())
