from transformers import AutoTokenizer
import transformers
import torch

from samplers import DecimalWarper


class PipelineGeneration:
    """Sample text continuations from a Hugging Face causal LM pipeline.

    Builds a ``transformers`` "text-generation" pipeline for one of the
    supported model families and exposes :meth:`compute`, which returns one
    sampled continuation per prompt.
    """

    # Hugging Face Hub organisation hosting each supported model family;
    # the full model id is "<org>/<model_variant>".
    _MODEL_ORGS = {
        "llama2": "meta-llama",
        "falcon": "tiiuae",
    }

    def __init__(self, eval_model, model_variant):
        """Load the tokenizer and build the generation pipeline.

        Args:
            eval_model: Model family key, one of ``"llama2"`` or ``"falcon"``.
            model_variant: Repository name under the family's Hub
                organisation (the part after ``"meta-llama/"`` or
                ``"tiiuae/"``).

        Raises:
            ValueError: If ``eval_model`` is not a supported family.
        """
        try:
            org = self._MODEL_ORGS[eval_model]
        except KeyError:
            supported = ", ".join(sorted(self._MODEL_ORGS))
            raise ValueError(
                f"Unsupported eval_model {eval_model!r}; "
                f"expected one of: {supported}"
            ) from None

        self.eval_model = eval_model
        model = org + "/" + model_variant

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def compute(self, prompts, sample_decimals):
        """Generate one sampled continuation for each prompt.

        Args:
            prompts: A list of prompt strings, or a single prompt string.
            sample_decimals: If true, generate a single new token with a
                ``DecimalWarper`` logits processor attached; otherwise
                generate up to 64 new tokens with no extra processor.

        Returns:
            A list containing the generated text for each prompt (the prompt
            itself is excluded because ``return_full_text=False``), in input
            order. A single string prompt yields a one-element list.
        """
        # For a single string the pipeline returns a flat list of dicts, but
        # for a list of prompts it returns a list of per-prompt lists.
        # Normalise to the list form so the extraction below is uniform.
        if isinstance(prompts, str):
            prompts = [prompts]

        max_new_tokens, logits_processor_list = self.get_sampling_method(
            sample_decimals, self.eval_model
        )
        outputs = self.pipeline(
            prompts,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1.0,
            top_p=0.6,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            return_full_text=False,
            logits_processor=logits_processor_list,
        )
        # num_return_sequences=1, so each prompt has exactly one candidate.
        return [candidates[0]["generated_text"] for candidates in outputs]

    def get_sampling_method(self, sample_decimals, eval_model):
        """Choose the generation length and logits processors.

        Args:
            sample_decimals: Whether to sample a single constrained token.
            eval_model: Model family key, forwarded to ``DecimalWarper``.

        Returns:
            ``(1, LogitsProcessorList)`` holding one ``DecimalWarper`` when
            ``sample_decimals`` is true, otherwise ``(64, None)``.
        """
        if not sample_decimals:
            return 64, None

        # NOTE(review): DecimalWarper comes from the local `samplers` module;
        # presumably it restricts the next token to decimal-number tokens —
        # confirm against its implementation.
        logits_processor_list = transformers.LogitsProcessorList()
        logits_processor_list.append(DecimalWarper(self.tokenizer, eval_model))
        return 1, logits_processor_list
