import requests
import re
import time
import Evaluation.resources_usage as resources_usage


def query(prompt, model_key, max_new_tokens=64,
          api_url="https://api-inference.huggingface.co/models/bigscience/bloom",
          timeout=120):
    """POST a BLOOM text-generation request to the Hugging Face Inference API.

    Args:
        prompt: Value sent as the request's ``inputs`` field (a string, or
            presumably a list of strings for a batched request).
        model_key: Hugging Face API token, sent as a ``Bearer`` token.
        max_new_tokens: Maximum number of tokens the server should generate.
        api_url: Inference endpoint to POST to (defaults to the hosted
            ``bigscience/bloom`` model).
        timeout: Seconds ``requests`` waits for the server before raising a
            timeout exception. ``None`` restores the old behaviour of waiting
            indefinitely.

    Returns:
        The decoded JSON body of the response, returned as-is. No HTTP-status
        or error-payload check is done here; callers must inspect it.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            # Ask the server for the continuation only, not prompt + continuation.
            "return_full_text": False,
            "temperature": 1.0,
            "repetition_penalty": 100.0,
            "top_p": 0.6,
        },
    }
    headers = {"Authorization": f"Bearer {model_key}"}
    # requests has no default timeout; without one a stalled connection would
    # hang the caller forever.
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()


class BLOOM:
    """BLOOM text generator backed by the hosted Inference API or a local Alpa model.

    Attributes set by ``__init__``:
        type: ``"API"`` when a ``model_key`` is given, otherwise ``"Model"``.
        model_key: The Hugging Face API token (API backend only).
        model: The Alpa model returned by ``get_model`` (local backend only).
        tokenizer: The ``bigscience/bloom`` tokenizer (local backend only).
    """

    def __init__(self, model_key, batch_size):
        """Choose a backend and, for the local one, load the tokenizer and model.

        Args:
            model_key: Hugging Face API token. When not ``None`` the hosted
                Inference API is used and nothing is loaded locally.
            batch_size: Batch size passed to Alpa's ``get_model``. Only used
                by the local backend.
        """
        if model_key is not None:
            self.model_key = model_key
            self.type = "API"
        else:
            # Imported lazily so the API backend does not need transformers/Alpa.
            from transformers import AutoTokenizer
            from llm_serving.model.wrapper import get_model
            tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
            tokenizer.add_bos_token = False
            # Load the model. Alpa automatically downloads the weights to the specified path.
            t0 = time.time()
            model = get_model(model_name="alpa/bloom", path="/work/ML/Bloom", batch_size=batch_size)
            print(f"Done with Load the model: {(time.time() - t0) / 60} minutes")
            resources_usage.print_cpu_usage()
            resources_usage.print_memory_usage()
            resources_usage.print_GPU_usage()
            self.model = model
            self.tokenizer = tokenizer
            self.type = "Model"

    @staticmethod
    def _parse_api_response(response):
        """Return the ``generated_text`` strings from an Inference API JSON reply.

        Raises:
            RuntimeError: If the reply is a JSON object rather than a list.
        """
        # A JSON object instead of a list is presumably an error payload such
        # as {"error": ...}. Iterating it would only raise an opaque TypeError,
        # so surface its message instead.
        if isinstance(response, dict):
            raise RuntimeError(
                f"BLOOM inference API error: {response.get('error', response)}"
            )
        texts = []
        for item in response:
            # Batched requests may come back as one list of candidates per
            # prompt (presumably, as with transformers pipelines; TODO confirm).
            # Take the first candidate.
            if isinstance(item, list):
                item = item[0]
            texts.append(item["generated_text"])
        return texts

    def compute(self, prompts, sample_decimals=False):
        """Generate text for ``prompts`` with the configured backend.

        Args:
            prompts: Passed to the tokenizer (local backend) or sent as the
                API ``inputs`` field.
            sample_decimals: Accepted for interface compatibility; unused.

        Returns:
            A list of generated strings.

        Raises:
            RuntimeError: If the Inference API replies with a JSON object
                (e.g. an error payload) instead of a list of generations.
        """
        if self.type == "API":
            return self._parse_api_response(query(prompts, self.model_key))

        input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids
        outputs = self.model.generate(
            input_ids=input_ids,
            # Allow up to 64 tokens beyond the (padded) prompt length.
            max_length=len(input_ids[0]) + 64,
            temperature=1.0,
            top_p=0.6,
            do_sample=True,
        )
        # NOTE(review): the API path requests return_full_text=False, but here
        # the whole generated sequence is decoded. With an HF-style generate
        # that presumably still includes the prompt. Confirm whether callers
        # expect the two backends to differ.
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
