from typing import Union, Optional
from vllm import LLM, SamplingParams
import numpy as np
from .. import LanguageModel, GenerateOutput

class VLLMModel(LanguageModel):
    """`LanguageModel` implementation that delegates generation to a vLLM engine.

    Only text generation is supported; the logit / log-likelihood entry
    points raise `NotImplementedError`.
    """

    def __init__(
        self,
        model_pth: str,
        max_new_tokens: int = 256,
        tensor_parrallel_size: int = 4
    ):
        """Load the model into a vLLM engine.

        Args:
            model_pth: Model identifier forwarded to `vllm.LLM(model=...)`.
            max_new_tokens: Default generation length used by `generate`
                when the caller does not pass `max_new_tokens`.
            tensor_parrallel_size: Forwarded to vLLM as
                `tensor_parallel_size`.
        """
        # NOTE: the misspelled keyword `tensor_parrallel_size` is kept as-is
        # because callers may already pass it by name.
        super().__init__()
        self.model = LLM(model=model_pth, tensor_parallel_size=tensor_parrallel_size,
                         trust_remote_code=True)
        self.tokenizer = self.model.get_tokenizer()
        self.max_new_tokens = max_new_tokens

    def generate(
        self,
        inputs: list[str],
        max_new_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        num_return_sequences: int = 1,
        **kwargs,
    ) -> GenerateOutput:
        """Sample completions for every prompt in `inputs`.

        Args:
            inputs: Prompts to complete.
            max_new_tokens: Maximum number of tokens to generate per
                completion; falls back to `self.max_new_tokens` when None.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.
            num_return_sequences: Number of completions sampled per prompt.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A `GenerateOutput` whose first field is the flat list of
            completion texts: all completions of the first prompt, then all
            completions of the second, and so on. The second field is always
            None.
        """
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        # Generation always stops at a blank line and uses a fixed seed.
        sampling_params = SamplingParams(
            n=num_return_sequences,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop="\n\n",
            max_tokens=max_new_tokens,
            seed=5378)
        responses = self.model.generate(inputs, sampling_params, use_tqdm=False)
        # Collect completions for *every* prompt, not just the first one;
        # reading only responses[0] dropped the rest of the batch and raised
        # IndexError on an empty batch.
        texts = []
        for request_output in responses:
            texts.extend(completion.text for completion in request_output.outputs)
        return GenerateOutput(texts, None)

    def get_next_token_logits(self,
                              prompt: Union[str, list[str]],
                              candidates: Union[list[str], list[list[str]]],
                              **kwargs) -> list[np.ndarray]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Model does not support get_next_token_logits")

    def get_loglikelihood(self,
                    prompt: Union[str, list[str]],
                    **kwargs) -> list[np.ndarray]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        # Message names the method actually being called.
        raise NotImplementedError("Model does not support get_loglikelihood")