import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import utils.gpt3

# Shared inference state. All three handles stay None until load_engine()
# fills them in; `generator` is only assigned on the causal-LM path.
tokenizer = model = generator = None

# Module-wide prompt-length setting; load_engine() lowers it to 700 for the
# T0 / flan-t5 / gpt2 engine families.
SOFT_MAX_PROMPT_LENGTH = 2000

def load_engine(engine):
    """Load the tokenizer and model for ``engine`` into the module globals.

    The call is a no-op when a tokenizer is already loaded, so only the first
    engine requested is ever loaded; later calls with a different ``engine``
    name are ignored.

    Engines whose name contains ``'bigscience/T0'`` or ``'google/flan-t5'``
    are loaded with ``AutoModelForSeq2SeqLM`` (float32, ``device_map="auto"``)
    and no pipeline is created. Every other engine is loaded with
    ``AutoModelForCausalLM`` (float16) and wrapped in a ``text-generation``
    pipeline on device index 0, which is stored in ``generator``.

    For the seq2seq engines and for engines whose name contains ``'gpt2'``,
    ``utils.gpt3.MAX_PROMPT_LENGTH`` is set to 1024 and
    ``SOFT_MAX_PROMPT_LENGTH`` to 700.

    Args:
        engine: Model identifier passed to ``from_pretrained``.

    Raises:
        Whatever ``from_pretrained`` / ``transformers.pipeline`` raise. In that
        case no module state is modified, so the call can simply be retried.
    """
    global tokenizer
    global model
    global generator
    global SOFT_MAX_PROMPT_LENGTH

    # Already initialised: keep whatever was loaded first.
    if tokenizer is not None:
        return

    is_seq2seq = 'bigscience/T0' in engine or 'google/flan-t5' in engine
    uses_short_limits = is_seq2seq or 'gpt2' in engine

    # Load everything into locals first. Publishing the tokenizer before the
    # model is ready would make a failed model load unrecoverable, because the
    # guard above would skip loading on every subsequent call.
    loaded_tokenizer = AutoTokenizer.from_pretrained(engine)
    if is_seq2seq:
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
            engine,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32,
        )
        loaded_generator = None
    else:
        loaded_model = AutoModelForCausalLM.from_pretrained(engine, torch_dtype=torch.float16)
        loaded_generator = transformers.pipeline(
            task="text-generation",
            model=loaded_model,
            tokenizer=loaded_tokenizer,
            device=torch.device(0),
        )

    # Everything loaded successfully: now apply the side effects.
    if uses_short_limits:
        utils.gpt3.MAX_PROMPT_LENGTH = 1024
        SOFT_MAX_PROMPT_LENGTH = 700
    tokenizer = loaded_tokenizer
    model = loaded_model
    if loaded_generator is not None:
        # Only the causal-LM path assigns `generator`; the seq2seq path
        # leaves it untouched.
        generator = loaded_generator

def completion(prompts, temperature=0.7, max_tokens=256, top_p=1.0, frequency_penalty=None, stop=None):
    """Sample a completion for one prompt or for each prompt in a list.

    Uses the module-level ``generator`` pipeline when one is set, otherwise
    calls ``model.generate`` directly with inputs encoded by ``tokenizer``.
    ``load_engine`` must have populated those globals beforehand.

    Args:
        prompts: A single prompt string, or an iterable of prompt strings.
        temperature: Sampling temperature forwarded to generation.
        max_tokens: Generation budget. On the pipeline path ``max_length`` is
            the prompt's token count plus ``max_tokens``; on the direct
            ``model.generate`` path ``max_length`` is ``max_tokens`` itself.
        top_p: Nucleus-sampling threshold forwarded to generation.
        frequency_penalty: Forwarded unchanged as ``repetition_penalty``.
            NOTE(review): Hugging Face's ``repetition_penalty`` is
            multiplicative (1.0 means no penalty), which differs from an
            OpenAI-style additive frequency penalty; confirm callers pass
            values on the Hugging Face scale.
        stop: A stop string or an iterable of stop strings. The generation is
            cut just before each stop string found, checked in the given order.

    Returns:
        The generated text as a ``str`` when ``prompts`` was a string,
        otherwise a list with one generated string per prompt.
    """
    is_single_prompt = isinstance(prompts, str)

    # Convert to list
    if is_single_prompt:
        prompts = [prompts]

    # A bare string is one stop sequence; iterating it directly would treat
    # every character as its own stop sequence.
    if isinstance(stop, str):
        stop = [stop]

    generations = []
    for prompt in prompts:
        if generator is not None:
            prompt_length = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
            r = generator(
                prompt,
                max_length=prompt_length + max_tokens,
                use_cache=True,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=frequency_penalty,
            )
            generation = r[0]['generated_text']
            # Drop only the echoed prompt at the start. A global replace would
            # also delete any repetition of the prompt inside the generated
            # continuation.
            if generation.startswith(prompt):
                generation = generation[len(prompt):]
        else:
            # Send the inputs to wherever the model was actually placed rather
            # than assuming device index 0.
            inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                inputs,
                max_length=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=frequency_penalty,
            )
            generation = tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True)
        if stop:
            for stop_chars in stop:
                cut = generation.find(stop_chars)
                if cut != -1:
                    generation = generation[:cut]
        generations.append(generation)

    # Return the responses
    return generations[0] if is_single_prompt else generations