from typing import Dict, Any

import transformers
from transformers import PretrainedConfig, AutoConfig, T5Tokenizer, BertTokenizerFast, T5TokenizerFast, BertTokenizer, \
    BertModel, LongT5EncoderModel, T5EncoderModel

from iter.modeling_features import FeaturesMixin


class ITERConfig(PretrainedConfig, FeaturesMixin):
    """Configuration for the ITER model.

    Holds ITER-specific hyper-parameters together with the configuration of a
    wrapped transformer backbone (loaded through ``AutoConfig``). The backbone
    dimensions ``d_ff`` and ``d_model`` are copied from that wrapped config,
    accepting either T5-style or BERT-style attribute names.
    """

    # Feed-forward / intermediate size, copied from the wrapped transformer config.
    d_ff: int
    # Hidden size, copied from the wrapped transformer config.
    d_model: int
    num_types: int
    num_links: int
    features: int
    max_nest_depth: int
    dropout: float
    # Configuration of the wrapped transformer backbone (built via AutoConfig).
    transformer_config: PretrainedConfig
    model_type = "iter"

    def __init__(
            self,
            transformer_name="t5-small",
            transformer_config=None,
            num_types=4,
            num_links=5,
            features=0,
            dataset: str = None,
            max_nest_depth: int = 1,
            dropout: float = 0.3,
            activation_fn: str = "gelu",
            **kwargs
    ):
        """Build the ITER config and load the wrapped transformer's config.

        Args:
            transformer_name: Hub name or local path passed to
                ``AutoConfig.from_pretrained``. Overridden by a
                ``"_name_or_path"`` entry in ``transformer_config`` if present.
            transformer_config: Optional mapping of keyword overrides forwarded
                to ``AutoConfig.from_pretrained`` (presumably the serialized
                form of a previously saved backbone config — confirm).
            num_types: Stored as ``self.num_types``.
            num_links: Stored as ``self.num_links``.
            features: Stored as ``self.features``.
            dataset: Stored as ``self.dataset``; may be ``None``.
            max_nest_depth: Stored as ``self.max_nest_depth``.
            dropout: Stored as ``self.dropout``.
            activation_fn: Stored as ``self.activation_fn``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            AttributeError: If the wrapped config exposes neither ``d_ff`` nor
                ``intermediate_size``, or neither ``d_model`` nor ``hidden_size``.
        """
        transformer_config = transformer_config or {}
        # A recorded "_name_or_path" in the overrides wins over the default name,
        # so the same backbone is re-resolved when the overrides come from a saved config.
        transformer_name = transformer_config.get("_name_or_path", transformer_name)
        self.transformer_config = AutoConfig.from_pretrained(transformer_name, **transformer_config)
        self.num_types = num_types
        self.num_links = num_links
        self.features = features
        self.dataset = dataset
        self.max_nest_depth = max_nest_depth
        self.dropout = dropout
        self.activation_fn = activation_fn
        self.d_ff = self.try_attr_options(
            "d_ff",  # T5 model family
            "intermediate_size",  # BERT model family
        )
        self.d_model = self.try_attr_options(
            "d_model",  # T5
            "hidden_size",  # BERT
        )

        # The config is always non-encoder-decoder; drop any incoming value
        # (e.g. from a fully serialized config) so the keyword is not passed twice.
        kwargs.pop("is_encoder_decoder", None)
        super().__init__(
            is_encoder_decoder=False,
            **kwargs
        )
        # Set after super().__init__ so the backbone's value overrides whatever
        # the base class assigned.
        self.max_length = self.transformer_config.max_length

    def guess_tokenizer_class(self, use_fast=False):
        """Return the tokenizer class matching the wrapped transformer.

        Any ``model_type`` containing ``"t5"`` (e.g. ``"t5"``, ``"longt5"``) maps
        to the T5 tokenizer; every other type falls back to the BERT tokenizer.

        Args:
            use_fast: Return the ``*TokenizerFast`` variant instead of the slow one.
        """
        if "t5" in self.transformer_config.model_type:
            return T5Tokenizer if not use_fast else T5TokenizerFast
        return BertTokenizer if not use_fast else BertTokenizerFast

    def guess_model_class(self):
        """Return the backbone model class for the wrapped transformer.

        ``"bert"`` -> ``BertModel``, ``"longt5"`` -> ``LongT5EncoderModel``;
        any other ``model_type`` falls back to ``T5EncoderModel``.
        """
        if "bert" == self.transformer_config.model_type:
            return BertModel
        elif "longt5" == self.transformer_config.model_type:
            return LongT5EncoderModel
        else:
            return T5EncoderModel

    def try_attr_options(self, *items):
        """Return the first attribute of the wrapped config that exists.

        Different model families name the same quantity differently (e.g. T5's
        ``d_model`` vs BERT's ``hidden_size``), so the names in ``items`` are
        tried in order.

        Args:
            *items: Candidate attribute names, in priority order.

        Returns:
            The value of the first name in ``items`` found on
            ``self.transformer_config``.

        Raises:
            AttributeError: If none of ``items`` is an attribute of the wrapped
                config (chained from the last lookup failure, if any).
        """
        last_error = None
        for item in items:
            try:
                return getattr(self.transformer_config, item)
            except AttributeError as cause:
                last_error = cause
        # `raise ... from` requires an exception (or None) as the cause, so chain
        # from the last failure rather than from a list of them.
        raise AttributeError(
            f"{type(self.transformer_config).__name__} has none of the attributes {items!r}"
        ) from last_error

    @staticmethod
    def _get_generation_defaults() -> Dict[str, Any]:
        """Return no generation defaults for this config.

        Overrides the ``PretrainedConfig`` hook with an empty mapping
        (presumably so no generation parameters are attached to this
        encoder-only config — confirm against the installed transformers version).
        """
        return {}