from typing import Iterable

from transformers import AutoModelForMaskedLM, pipeline

import hf_utils
from bert_model import BertModel
from train_temporal_bert import ModelArguments


def load_model(model_name_or_path, expect_times_in_model=True):
    """Load a masked-LM model and its tokenizer through ``hf_utils``.

    Args:
        model_name_or_path: Value stored in
            ``ModelArguments(model_name_or_path=...)``, which is then handed to
            ``hf_utils.load_pretrained_model``.
        expect_times_in_model: Forwarded unchanged to
            ``hf_utils.load_pretrained_model``.

    Returns:
        A ``(model, tokenizer)`` tuple built from the two-item result of
        ``hf_utils.load_pretrained_model`` (called with
        ``AutoModelForMaskedLM`` as the model class).
    """
    args = ModelArguments(model_name_or_path=model_name_or_path)
    loaded = hf_utils.load_pretrained_model(
        args,
        AutoModelForMaskedLM,
        expect_times_in_model=expect_times_in_model,
    )
    # Unpack explicitly so the result is always a 2-tuple (and a result
    # that is not a pair fails here, at the load site).
    hf_model, tokenizer = loaded
    return hf_model, tokenizer


class Tester:
    """Build fill-mask pipelines and ``BertModel`` wrappers for one or more models.

    Attributes:
        fill_mask_pipelines: One ``transformers`` fill-mask pipeline per model.
            A list when ``preload`` is true; otherwise a one-shot generator
            that loads each model only as it is consumed.
        bert_models: One ``BertModel`` per pipeline. A list when ``preload``
            is true; otherwise a one-shot generator. When not preloaded, it
            draws from the ``fill_mask_pipelines`` generator, so consuming
            either one advances the other.
    """

    def __init__(self, model, device=-1, preload=False) -> None:
        """Set up fill-mask pipelines and ``BertModel`` wrappers.

        Args:
            model: A single model name/path, or an iterable of them. A plain
                string is treated as one model, not as a sequence of
                characters.
            device: Forwarded unchanged to ``pipeline`` and ``BertModel``.
            preload: If true, load every model eagerly and store lists. If
                false, keep lazy generators that can be iterated only once.
        """
        hf_utils.prepare_tf_classes()
        model_names = self._as_model_iterable(model)

        model_tokenizer_pairs = (
            load_model(name, expect_times_in_model=False) for name in model_names
        )
        if preload:
            model_tokenizer_pairs = list(model_tokenizer_pairs)

        # Separate loop names so the `model` argument is not shadowed.
        self.fill_mask_pipelines = (
            pipeline("fill-mask", model=hf_model, tokenizer=tokenizer, device=device)
            for hf_model, tokenizer in model_tokenizer_pairs
        )
        if preload:
            self.fill_mask_pipelines = list(self.fill_mask_pipelines)

        self.bert_models = (
            BertModel(hf_pipeline=fill_mask_pipeline, device=device)
            for fill_mask_pipeline in self.fill_mask_pipelines
        )
        if preload:
            self.bert_models = list(self.bert_models)

    @staticmethod
    def _as_model_iterable(model):
        """Return *model* as an iterable of model names/paths.

        ``str`` and ``bytes`` are iterable, but iterating them yields single
        characters or ints rather than model names. They are therefore wrapped
        as a single model, as is any non-iterable value. Other iterables are
        returned unchanged, so a generator of names is still consumed lazily.
        """
        if isinstance(model, (str, bytes)) or not isinstance(model, Iterable):
            return [model]
        return model
