import torch
from sentence_transformers import SentenceTransformer

from .data_processor import CmdCaliperDataProcessor
from .models import CmdCaliper

class HuggingFaceInferencer:
    """Sentence embedder backed by a ``SentenceTransformer`` model.

    The model named ``model_name`` is loaded once and moved to ``device``.
    Calling the instance encodes a list of sentences and returns the result
    as a ``torch.Tensor`` placed on ``device``.
    """

    def __init__(self, model_name, device):
        """Load ``model_name`` with ``SentenceTransformer`` and move it to ``device``."""
        encoder = SentenceTransformer(model_name)
        encoder.to(device)
        self.model = encoder
        self.device = device

    def __call__(self, sentence_list: list):
        """Encode ``sentence_list`` and return the embeddings as a tensor.

        ``torch.tensor`` copies the output of ``encode`` onto ``self.device``.
        """
        embeddings = self.model.encode(sentence_list)
        return torch.tensor(embeddings, device=self.device)

class CmdCaliperInferencer:
    """Embed command lines with a fine-tuned ``CmdCaliper`` checkpoint.

    The tokenizer and base model are built from ``model_name``. The weights
    stored at ``path_to_checkpoint`` are then loaded into the model, which is
    moved to ``device`` and switched to eval mode.
    """

    def __init__(self, model_name, path_to_checkpoint, device):
        """Build the tokenizer/data processor and load the checkpointed model.

        Args:
            model_name: Identifier passed to ``CmdCaliper.get_tokenizer`` and
                used as ``path_to_model_weight`` when constructing the model.
            path_to_checkpoint: Path to a saved ``state_dict`` for the model.
            device: Target device (e.g. ``"cpu"`` or ``"cuda"``) for the model
                and the data processor.
        """
        model_class = CmdCaliper
        self.device = device

        tokenizer = model_class.get_tokenizer(model_name)
        # Some tokenizers define no pad token. Fall back to the unk token so
        # padding (presumably done by the data processor) has a token to use.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.unk_token
        self.data_processor = CmdCaliperDataProcessor(tokenizer, device)

        self.model = model_class(
            device,
            load_from_pretrained=False,
            path_to_model_weight=model_name,
        )
        # map_location remaps stored tensors onto ``device``. Without it, a
        # checkpoint saved on a GPU fails to load on a CPU-only host (or on a
        # different GPU index).
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path_to_checkpoint, map_location=device)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

    def __call__(self, sentence_list: list):
        """Tokenize ``sentence_list`` and return the model's forward output.

        The forward pass runs under ``torch.no_grad()``. This is inference
        only, so building an autograd graph would just waste memory and time.
        """
        sentence_tokens_info = self.data_processor(sentence_list)
        with torch.no_grad():
            return self.model(**sentence_tokens_info)
