# Based on the HuggingFace Transformers GPT-2 configuration (modified version).

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger, created through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Maps each pretrained checkpoint name to the URL of its config.json on
# huggingface.co. All URLs share one pattern, so they are derived from the
# checkpoint names (insertion order follows the tuple below).
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    checkpoint: "https://huggingface.co/" + checkpoint + "/resolve/main/config.json"
    for checkpoint in (
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "distilgpt2",
    )
}

class GPT2EConfig(PretrainedConfig):
    """Modified GPT-2 configuration class, including CoreLM specific parameters.

    On top of the GPT-2 style hyper-parameters, the constructor accepts two
    extra groups of settings, each stored on the instance under the same
    name as the constructor argument:

    * entity settings: ``n_ent``, ``et_mod``, ``n_ent_head``, ``eg_mod``,
      ``gate_control``;
    * freezing settings: ``freeze_percentage`` and the ``freeze_*`` flags,
      also exposed as one dict through :meth:`get_freeze_params`.

    NOTE(review): how each entity/freezing field is interpreted is decided
    by the model/training code that consumes this config, which is not part
    of this file -- confirm there before relying on a particular meaning.
    """

    model_type = "gpt2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        gradient_checkpointing=False,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        n_ent=500000,
        et_mod=13,
        n_ent_head=12,
        eg_mod=True,
        gate_control=0.5,
        freeze_percentage=0,
        freeze_emb=False,
        freeze_pos=False,
        freeze_ln=True,
        freeze_ff=True,
        freeze_attn=True,
        freeze_gate=False,
        freeze_entities=True,
        freeze_lm=False,
        **kwargs
    ):
        """Store every named argument as an attribute of the same name.

        ``bos_token_id``, ``eos_token_id`` and any additional keyword
        arguments are forwarded to ``PretrainedConfig.__init__``; the two
        token ids are then also assigned directly on the instance.
        """
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        # --- sizes of the network ---
        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function

        # --- dropout rates, layer-norm epsilon, weight init range ---
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        # --- summary_* options ---
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels

        # --- runtime switches ---
        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache

        # --- special token ids (already passed to the base class above) ---
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # --- entities ---
        self.n_ent = n_ent
        self.et_mod = et_mod
        self.n_ent_head = n_ent_head
        self.eg_mod = eg_mod
        self.gate_control = gate_control

        # --- freezing ---
        # Original note on freeze_percentage: "if 0 different method".
        # NOTE(review): presumably 0 means the individual freeze_* flags are
        # used instead of a percentage -- confirm against the training code.
        self.freeze_percentage = freeze_percentage
        self.freeze_ln = freeze_ln
        self.freeze_emb = freeze_emb
        self.freeze_pos = freeze_pos
        self.freeze_ff = freeze_ff
        self.freeze_attn = freeze_attn
        self.freeze_gate = freeze_gate
        self.freeze_entities = freeze_entities
        self.freeze_lm = freeze_lm

    @property
    def max_position_embeddings(self):
        """Read-only alias for ``n_positions``."""
        return self.n_positions

    @property
    def hidden_size(self):
        """Read-only alias for ``n_embd``."""
        return self.n_embd

    @property
    def num_attention_heads(self):
        """Read-only alias for ``n_head``."""
        return self.n_head

    @property
    def num_hidden_layers(self):
        """Read-only alias for ``n_layer``."""
        return self.n_layer

    def get_freeze_params(self):
        """Return the freezing settings as a dict keyed by short names.

        The first key, ``'percent'``, holds ``freeze_percentage``. Each
        remaining key ``k`` (``'ln'``, ``'emb'``, ``'pos'``, ``'ff'``,
        ``'attn'``, ``'gate'``, ``'entities'``, ``'lm'``, in that order)
        holds the value of the ``freeze_<k>`` attribute.
        """
        freeze_params = {'percent': self.freeze_percentage}
        # Every remaining key maps onto the attribute named "freeze_" + key.
        for short_name in ('ln', 'emb', 'pos', 'ff', 'attn', 'gate', 'entities', 'lm'):
            freeze_params[short_name] = getattr(self, 'freeze_' + short_name)
        return freeze_params