from llama import Llama


class Llama2:
    """Thin wrapper around a Llama 2 checkpoint for batched generation.

    Variants whose last dash-separated name segment is ``chat`` are driven
    through ``chat_completion``; every other variant uses
    ``text_completion``.
    """

    # Sampling settings shared by the chat and the text completion paths.
    MAX_GEN_LEN = 64
    TEMPERATURE = 1.0
    TOP_P = 0.6

    def __init__(self, model_variant, batch_size):
        """Build the model for ``model_variant``.

        Args:
            model_variant: Name appended to the Llama 2 root path to form
                the checkpoint directory; also decides chat vs. text mode.
            batch_size: Forwarded to ``Llama.build`` as ``max_batch_size``
                and used as the largest number of prompts sent to the
                model in a single call.
        """
        llama_path = "/path/to/Llama-2/"
        self.model_variant = model_variant
        # Remembered so compute() can split oversized prompt lists.
        self.batch_size = batch_size
        self.model = Llama.build(
            ckpt_dir=llama_path + model_variant,
            tokenizer_path=llama_path + "tokenizer.model",
            max_seq_len=512,
            max_batch_size=batch_size,
        )

    def compute(self, prompts, sample_decimals=False):
        """Generate one completion per prompt.

        Args:
            prompts: Iterable of prompts (presumably strings). May be empty
                and may hold more entries than ``batch_size``.
            sample_decimals: Currently unused; kept so existing callers
                that pass it keep working.

        Returns:
            A list with the generated text for each prompt, collected
            batch by batch in the order the model returns them.
        """
        prompts = list(prompts)
        if not prompts:
            # The reference generation code cannot handle an empty batch,
            # so answer directly instead of calling the model.
            return []

        is_chat = self.model_variant.split("-")[-1] == "chat"
        sampling = {
            "max_gen_len": self.MAX_GEN_LEN,
            "temperature": self.TEMPERATURE,
            "top_p": self.TOP_P,
        }

        outputs = []
        # The model was built with max_batch_size=batch_size, so never hand
        # it more prompts than that at once.
        for start in range(0, len(prompts), self.batch_size):
            batch = prompts[start:start + self.batch_size]
            if is_chat:
                dialogs = [
                    [{"role": "user", "content": prompt}] for prompt in batch
                ]
                results = self.model.chat_completion(dialogs, **sampling)
                outputs.extend(
                    result["generation"]["content"] for result in results
                )
            else:
                results = self.model.text_completion(batch, **sampling)
                outputs.extend(result["generation"] for result in results)
        return outputs
