from transformers import ElectraForSequenceClassification
from adapters import ElectraAdapterModel

import logging

# Module-level logger named after this module (not referenced by the functions in this section).
log = logging.getLogger(__name__)


def build_model(model_class, model_path_or_name, **kwargs):
    """Instantiate ``model_class`` via its ``from_pretrained`` constructor.

    Args:
        model_class: Any class exposing a ``from_pretrained`` classmethod.
        model_path_or_name: Checkpoint identifier or local path, passed through
            positionally as the first argument.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``from_pretrained``.

    Returns:
        Whatever ``model_class.from_pretrained`` returns.
    """
    loader = model_class.from_pretrained
    return loader(model_path_or_name, **kwargs)


def create_electra(model_config, tokenizer, use_sngp, use_duq, use_spectralnorm,
                   use_mixup, use_selective, ue_args, model_path_or_name, config):
    """Load an ELECTRA model, choosing the adapter variant when configured.

    Args:
        model_config: Model configuration object forwarded as ``config=`` to
            ``from_pretrained``.
        tokenizer: Not used by this function.
        use_sngp: Not used by this function.
        use_duq: Not used by this function.
        use_spectralnorm: Not used by this function.
        use_mixup: Not used by this function.
        use_selective: Not used by this function.
        ue_args: Not used by this function.
        model_path_or_name: Checkpoint identifier or local path to load from.
        config: Run configuration. It must expose ``cache_dir`` and support the
            ``in`` membership test. If it contains an ``"adapters"`` key,
            ``ElectraAdapterModel`` is loaded; otherwise
            ``ElectraForSequenceClassification`` is loaded.

    Returns:
        The model instance returned by ``build_model``.

    Note:
        NOTE(review): the unused parameters are presumably kept so that this
        factory shares a signature with sibling ``create_*`` builders;
        confirm this against the callers.
    """
    # Read cache_dir before the adapters check, matching the original
    # evaluation order.
    load_kwargs = {
        "from_tf": False,
        "config": model_config,
        "cache_dir": config.cache_dir,
    }

    wants_adapters = "adapters" in config
    model_cls = (
        ElectraAdapterModel if wants_adapters else ElectraForSequenceClassification
    )

    return build_model(model_cls, model_path_or_name, **load_kwargs)
