# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import os
import numpy as np

from pytorch_transformers.tokenization_bert import BertTokenizer
from torch import nn

from pytorch_transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_transformers.optimization import AdamW

# Presets selecting which parameters get optimized. Each value is a list of
# substrings; get_bert_optimizer keeps a parameter when its name contains at
# least one of them. Layer patterns end with a '.' so that a layer index is
# matched exactly (e.g. 'encoder.layer.8.' must not match 'encoder.layer.80.').
patterns_optimizer = {
    # only parameters whose name contains 'additional'
    'additional_layers': ['additional'],
    # 'additional' parameters plus encoder layer 11
    'top_layer': ['additional', 'encoder.layer.11.'],
    # 'additional' parameters plus encoder layers 8 through 11
    'top4_layers': [
        'additional',
        'encoder.layer.11.',
        'encoder.layer.10.',
        'encoder.layer.9.',
        'encoder.layer.8.',
    ],
    # 'additional' parameters plus every encoder layer
    'all_encoder_layers': ['additional', 'encoder.layer'],
    # 'additional' parameters, every encoder layer, and the embeddings
    'all': ['additional', 'encoder.layer', 'embeddings'],
}

def get_bert_optimizer(models, learning_rate, type_optimization='all_encoder_layers'):
    """Build an AdamW optimizer over a name-selected subset of parameters.

    A parameter is selected when its name contains one of the substrings of
    the chosen ``patterns_optimizer`` preset, or one of the extra "tur"
    patterns. Selected parameters are split into two groups: names containing
    'bias', 'gamma' or 'beta' get no weight decay, all others get a weight
    decay of 0.01. The names in each group are printed (truncated).

    :param models: iterable of modules exposing ``named_parameters()``
    :param learning_rate: learning rate handed to AdamW as ``lr``
    :param string type_optimization: key of ``patterns_optimizer`` choosing
        which parameters are optimized
    :return: tuple ``(optimizer, parameters_encoder, parameters_tur)`` where
        ``parameters_encoder`` holds the parameters matching the preset and
        ``parameters_tur`` those matching the "tur" patterns
    :raises KeyError: if ``type_optimization`` is not a key of
        ``patterns_optimizer``
    """
    if type_optimization not in patterns_optimizer:
        # Fail fast with an explicit message instead of printing and then
        # crashing on the dictionary lookup below. KeyError is kept because
        # that is what an unknown preset has always ended up raising.
        raise KeyError(
            'Error. Type optimizer must be one of %s (got %r)'
            % (list(patterns_optimizer), type_optimization)
        )

    parameters_with_decay = []
    parameters_with_decay_names = []
    parameters_without_decay = []
    parameters_without_decay_names = []
    parameters_tur = []
    parameters_encoder = []
    # NOTE(review): 'gamma'/'beta' look like legacy LayerNorm parameter names;
    # confirm they still match anything in the models passed in.
    no_decay = ['bias', 'gamma', 'beta']

    patterns = patterns_optimizer[type_optimization]
    # NOTE(review): matches ANY parameter name containing an uppercase 'A' —
    # presumably targets a dedicated parameter called `A`; confirm.
    patterns_tur = ['A']

    for model in models:
        for n, p in model.named_parameters():
            matches_tur = any(t in n for t in patterns_tur)
            matches_encoder = any(t in n for t in patterns)

            if matches_tur:
                parameters_tur.append(p)
            if matches_encoder:
                parameters_encoder.append(p)

            # Only parameters selected by either pattern list are optimized.
            if not (matches_tur or matches_encoder):
                continue

            if any(t in n for t in no_decay):
                parameters_without_decay.append(p)
                parameters_without_decay_names.append(n)
            else:
                parameters_with_decay.append(p)
                parameters_with_decay_names.append(n)

    print('The following parameters will be optimized WITH decay:')
    print(ellipse(parameters_with_decay_names, 5, ' , '))
    print('The following parameters will be optimized WITHOUT decay:')
    print(ellipse(parameters_without_decay_names, 5, ' , '))

    optimizer_grouped_parameters = [
        {'params': parameters_with_decay, 'weight_decay': 0.01},
        {'params': parameters_without_decay, 'weight_decay': 0.0},
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        correct_bias=False,
    )

    return optimizer, parameters_encoder, parameters_tur


def ellipse(lst, max_display=5, sep='|'):
    """
    Join the items of a collection, truncating long ones with a summary.

    :param lst: The collection (list, set, any iterable) whose items are joined
    :param int max_display: how many leading items to show before the
        remainder is collapsed into '...and N more'. A non-positive value
        (e.g. -1) disables truncation.
    :param string sep: the delimiter placed between items
    """
    # materialize first so sets and other iterables can be measured and sliced
    entries = list(lst)
    hidden = len(entries) - max_display

    # nothing to collapse: truncation disabled, or everything fits
    if max_display <= 0 or hidden <= 0:
        return sep.join(map(str, entries))

    visible = [str(entry) for entry in entries[:max_display]]
    visible.append('...and {} more'.format(hidden))
    return sep.join(visible)
