from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from transformers import BertConfig, RobertaConfig
from unilm.configuration_unilm import UnilmConfig

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)


class UniLMFinetuningConfig(BertConfig):
    """``BertConfig`` extended with the extra settings used for UniLM fine-tuning.

    Every standard ``BertConfig`` field is forwarded through ``**kwargs``; the
    arguments below are stored as attributes of the same name on top of them.
    How each one is consumed is decided by the model code, which is not part
    of this module.

    Args:
        label_smoothing: Stored as ``self.label_smoothing``.
        source_type_id: Token-type id stored as ``self.source_type_id``.
        target_type_id: Token-type id stored as ``self.target_type_id``.
        rel_pos_bins: Stored as ``self.rel_pos_bins``.
        max_rel_pos: Stored as ``self.max_rel_pos``.
        fix_word_embedding: Stored as ``self.fix_word_embedding``.
        embedding_size: Stored as ``self.embedding_size``; when None it falls
            back to ``self.hidden_size`` (set by ``BertConfig.__init__``).
        expand_qk_head_dim: Stored as ``self.expand_qk_head_dim``.
        num_ffn_layers: Stored as ``self.num_ffn_layers``.
        need_pooler: Stored as ``self.need_pooler``.
        **kwargs: Forwarded unchanged to ``BertConfig.__init__``.
    """

    def __init__(self, label_smoothing=0.1, source_type_id=0, target_type_id=1,
                 rel_pos_bins=0, max_rel_pos=0, fix_word_embedding=False, embedding_size=None,
                 expand_qk_head_dim=None, num_ffn_layers=1, need_pooler=False, **kwargs):
        super(UniLMFinetuningConfig, self).__init__(**kwargs)
        self.label_smoothing = label_smoothing
        self.source_type_id = source_type_id
        self.target_type_id = target_type_id
        self.max_rel_pos = max_rel_pos
        self.rel_pos_bins = rel_pos_bins
        self.fix_word_embedding = fix_word_embedding
        self.need_pooler = need_pooler
        self.expand_qk_head_dim = expand_qk_head_dim
        self.num_ffn_layers = num_ffn_layers
        # hidden_size is only available after the parent constructor has run.
        self.embedding_size = self.hidden_size if embedding_size is None else embedding_size

    @classmethod
    def from_exist_config(
            cls, config, label_smoothing=0.1, max_position_embeddings=None,
            need_pooler=False, fix_word_embedding=False,
    ):
        """Create a fine-tuning config by copying fields from an existing config.

        Args:
            config: Source configuration object. It must expose every attribute
                listed in ``required_keys``. Optional UniLM-specific attributes
                (``additional_keys``) are copied when present.
            label_smoothing: Passed through to the new config.
            max_position_embeddings: When given and larger than the copied
                position-embedding count (after any RoBERTa adjustment), it
                replaces that count.
            need_pooler: Passed through to the new config.
            fix_word_embedding: Passed through to the new config.

        Returns:
            A new ``cls`` instance.

        Raises:
            ValueError: If ``config`` lacks any of the required attributes.
        """
        required_keys = [
            "vocab_size", "hidden_size", "num_hidden_layers", "num_attention_heads",
            "hidden_act", "intermediate_size", "hidden_dropout_prob", "attention_probs_dropout_prob",
            "max_position_embeddings", "type_vocab_size", "initializer_range", "layer_norm_eps",
        ]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        missing = [key for key in required_keys if not hasattr(config, key)]
        if missing:
            raise ValueError(
                "config is missing required attribute(s): %s" % ", ".join(missing))
        kwargs = {key: getattr(config, key) for key in required_keys}

        # NOTE(review): presumably kept for transformers versions whose BertConfig
        # takes the vocab size as `vocab_size_or_config_json_file` — confirm
        # against the pinned version before removing.
        kwargs["vocab_size_or_config_json_file"] = kwargs["vocab_size"]
        if isinstance(config, RobertaConfig):
            # RoBERTa checkpoints are treated as having no token-type table, and
            # 2 position slots are dropped (presumably RoBERTa's padding-offset
            # positions — verify against the weight-conversion code).
            kwargs["type_vocab_size"] = 0
            kwargs["max_position_embeddings"] = kwargs["max_position_embeddings"] - 2

        additional_keys = [
            "source_type_id", "target_type_id", "rel_pos_bins", "max_rel_pos",
            "expand_qk_head_dim", "num_ffn_layers", "embedding_size",
        ]
        kwargs.update(
            {key: getattr(config, key) for key in additional_keys if hasattr(config, key)})

        # Compare against the effective (possibly RoBERTa-adjusted) size, not the
        # raw `config.max_position_embeddings`. Otherwise a RoBERTa request of
        # raw-1 or raw positions would be silently ignored.
        current_max_pos = kwargs["max_position_embeddings"]
        if max_position_embeddings is not None and max_position_embeddings > current_max_pos:
            kwargs["max_position_embeddings"] = max_position_embeddings
            logger.info("  **  Change max position embeddings to %d  ** ", max_position_embeddings)

        return cls(
            label_smoothing=label_smoothing, need_pooler=need_pooler,
            fix_word_embedding=fix_word_embedding, **kwargs)
