import os
import inspect
import torch
from transformers import AutoModel, AutoTokenizer

from transformers import logging as tf_logging
tf_logging.set_verbosity_error()

# Models surveyed by this script. Both lists are concatenated into
# ``multiEmbeds`` below, so an entry is loaded and run once per occurrence;
# keep each model id unique.
multiRegressive = [
    'facebook/nllb-200-distilled-600M',
    'google/mt5-base',
    't5-base',
    'google/byt5-base',
    'sberbank-ai/mGPT',
    'Helsinki-NLP/opus-mt-mul-en',
    'google/canine-s',
    'google/canine-c',
    'facebook/xglm-564M',
    'facebook/mgenre-wiki',
    'setu4993/LaBSE',
    'bigscience/bloom-560m',
    'facebook/mbart-large-50',
]
multiAutoencoder = [
    'microsoft/mdeberta-v3-base',
    'studio-ousia/mluke-large',
    'google/rembert',
    'cardiffnlp/twitter-xlm-roberta-base',
    'xlm-roberta-large',
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'distilbert-base-multilingual-cased',
    'microsoft/infoxlm-large',
    'bert-base-multilingual-uncased',
    'Peltarion/xlm-roberta-longformer-base-4096',
    'studio-ousia/mluke-base',
    'xlm-mlm-100-1280',
]
# Not included (presumably too large to load for a quick check — confirm):
# 'facebook/xlm-roberta-xl', 'facebook/xlm-roberta-xxl'
multiEmbeds = multiRegressive + multiAutoencoder

text = 'this is a test'


def get_decoder_start_token_id(config):
    """Return the token id used to seed the decoder of an encoder-decoder model.

    ``config.decoder_start_token_id`` is preferred because it is the id the
    decoder is actually started from; ``config.bos_token_id`` is only a
    fallback for configs that do not set it.

    Args:
        config: A model config object exposing (possibly ``None``)
            ``decoder_start_token_id`` / ``bos_token_id`` attributes.

    Returns:
        The integer start token id.

    Raises:
        ValueError: If neither attribute is set, instead of letting a later
            tensor operation fail with an unrelated ``TypeError``.
    """
    start_id = getattr(config, 'decoder_start_token_id', None)
    if start_id is None:
        start_id = getattr(config, 'bos_token_id', None)
    if start_id is None:
        raise ValueError(
            'config defines neither decoder_start_token_id nor bos_token_id'
        )
    return start_id


def build_forward_kwargs(model, input_ids):
    """Build keyword arguments for one forward pass that returns hidden states.

    If ``model.forward`` accepts ``decoder_input_ids`` (encoder-decoder
    models), a one-step decoder prompt holding the start token is added for
    every row of ``input_ids``.

    Args:
        model: Model whose ``forward`` signature is inspected and whose
            ``config`` supplies the decoder start token.
        input_ids: Token id tensor; its first dimension is used as the batch
            size (assumes ``(batch, seq)`` — as produced by
            ``tokenizer.encode(..., return_tensors='pt')``).

    Returns:
        A dict suitable for ``model(**kwargs)``.
    """
    kwargs = {'input_ids': input_ids, 'output_hidden_states': True}
    # inspect.signature follows __wrapped__, so decorated forward methods
    # still expose their real parameters (getfullargspec would only see
    # *args/**kwargs and silently skip the decoder inputs).
    params = inspect.signature(model.forward).parameters
    if 'decoder_input_ids' in params:
        kwargs['decoder_input_ids'] = torch.full(
            (input_ids.shape[0], 1),
            get_decoder_start_token_id(model.config),
            dtype=torch.long,
            device=input_ids.device,
        )
    return kwargs


def main():
    """Load every model in ``multiEmbeds``, run ``text`` through it, print output keys."""
    for embed in multiEmbeds:
        tokenizer = AutoTokenizer.from_pretrained(embed, use_fast=False)
        model = AutoModel.from_pretrained(embed)
        input_ids = tokenizer.encode(text, return_tensors='pt')

        # Inference only: skip autograd bookkeeping. Calling the module
        # (not .forward) keeps any registered hooks in effect.
        with torch.no_grad():
            output = model(**build_forward_kwargs(model, input_ids))

        # Iterating a model output yields its field names.
        print(embed, list(output))


if __name__ == '__main__':
    main()

