# -*- coding: utf-8 -*-

from ...common.dataclass_options import BranchSelect
from .bert_wrapper import BERTPlugin
from .external_embedding import ExternalEmbeddingPlugin

# Public names re-exported by ``from <this module> import *``.
# NOTE(review): BERTPlugin and smartly_remove_weight_decay are defined or
# imported here but not listed; confirm that omission is intentional.
__all__ = [
    'ExternalEmbeddingPlugin',
    'PretainedPluginOptions',
]


def smartly_remove_weight_decay(named_parameters, *,
                                no_decay_keywords=('bias', 'norm')):
    """Split trainable parameters into no-weight-decay / weight-decay groups.

    A parameter is exempted from weight decay when any entry of
    ``no_decay_keywords`` occurs as a substring of its name. The comparison
    is case-insensitive, so module names such as ``LayerNorm`` or
    ``BatchNorm`` match the default ``'norm'`` keyword. Parameters whose
    ``requires_grad`` is false are left out of both groups.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g. the
            result of ``module.named_parameters()``. It is consumed once.
        no_decay_keywords: a string or an iterable of strings; a parameter
            whose name contains any of them gets ``weight_decay`` 0.

    Returns:
        A two-element list of parameter-group dicts. The first holds the
        exempt parameters with ``'weight_decay': 0``. The second holds the
        remaining parameters and has no ``weight_decay`` key. Presumably it
        targets a torch-style optimizer, where that group uses the
        optimizer's default decay.
    """
    # Guard against a bare string being iterated character by character.
    if isinstance(no_decay_keywords, str):
        no_decay_keywords = (no_decay_keywords,)
    # Lower-case once, outside the loop, for case-insensitive matching.
    keywords = tuple(keyword.lower() for keyword in no_decay_keywords)

    decay_parameters = []
    non_decay_parameters = []
    for name, param in named_parameters:
        # Frozen parameters are never updated; keep them out of the
        # optimizer groups altogether.
        if not param.requires_grad:
            continue

        lowered = name.lower()
        if any(keyword in lowered for keyword in keywords):
            non_decay_parameters.append(param)
        else:
            decay_parameters.append(param)

    return [{'params': non_decay_parameters, 'weight_decay': 0},
            {'params': decay_parameters}]


class PretainedPluginOptions(BranchSelect):
    """Branch-selection options for the pretrained-model plugin.

    ``type`` holds the selected branch name and ``branches`` maps branch
    names to plugin classes. Presumably BranchSelect dispatches to
    ``branches[type]``; confirm against BranchSelect.

    NOTE(review): the default ``'none'`` is not a key of ``branches``.
    Presumably BranchSelect treats it as "no plugin selected"; confirm.
    NOTE(review): "Pretained" looks like a misspelling of "Pretrained". The
    name is listed in ``__all__``, so renaming it would break importers.
    """
    # Default branch name; 'none' has no entry in `branches` (see docstring).
    type = 'none'
    # Selectable branches: 'bert' -> BERTPlugin.
    branches = {'bert': BERTPlugin}
