from transformers import BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, ElectraTokenizer
from unilm.tokenization_unilm import UnilmTokenizer
from unilm.tokenization_minilm import MinilmTokenizer


# Registry mapping a ``model_type`` string to the tokenizer class used to load
# tokenizers for that model family.
TOKENIZER_CLASSES = {
    'bert': BertTokenizer,
    'minilm': MinilmTokenizer,
    'roberta': RobertaTokenizer,
    'xlm-roberta': XLMRobertaTokenizer,
    'unilm': UnilmTokenizer,
    'electra': ElectraTokenizer,
}

# Backward-compatible alias: the registry was originally published under this
# misspelled name, and existing code still looks it up that way. Both names
# refer to the same dict object, so updates through either are shared.
TOEKNIZER_CLASSES = TOKENIZER_CLASSES


def get_tokenizer(tokenizer_name, do_lower_case=None, cache_dir=None, model_type=None):
    """Load a pretrained tokenizer of the class registered for ``model_type``.

    Args:
        tokenizer_name: Passed to the tokenizer class's ``from_pretrained``
            as its first positional argument.
        do_lower_case: Forwarded to ``from_pretrained`` only when it is not
            ``None``; ``None`` leaves the setting to the pretrained tokenizer.
        cache_dir: Forwarded to ``from_pretrained`` as ``cache_dir``.
        model_type: Key into ``TOEKNIZER_CLASSES`` selecting the tokenizer
            class.

    Returns:
        Whatever the selected class's ``from_pretrained`` returns.

    Raises:
        KeyError: If ``model_type`` is not a key of ``TOEKNIZER_CLASSES``
            (this includes the default ``None``).
    """
    if model_type not in TOEKNIZER_CLASSES:
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller what went wrong and which values are accepted.
        raise KeyError(
            "Unsupported model_type {!r}; expected one of {}".format(
                model_type, sorted(TOEKNIZER_CLASSES)))
    tokenizer_class = TOEKNIZER_CLASSES[model_type]

    load_kwargs = {'cache_dir': cache_dir}
    if do_lower_case is not None:
        # Only pass the flag when the caller actually chose a value. An
        # explicit do_lower_case=None would otherwise be handed to the
        # tokenizer as a falsy setting instead of meaning "unspecified".
        load_kwargs['do_lower_case'] = do_lower_case

    return tokenizer_class.from_pretrained(tokenizer_name, **load_kwargs)
