import openai
import os
import json
import xxhash
import logging

from random import Random
from sqlitedict import SqliteDict
from transformers import GPT2TokenizerFast

# Token budget used throughout this module: the default prompt size for
# truncate() / prompt_packer() and the upper bound applied to max_tokens in
# cached_completion().
MAX_PROMPT_LENGTH = 2048

# Setup
# Importing this module fails with KeyError if OPENAI_API_KEY is not set.
openai.api_key  = os.environ['OPENAI_API_KEY']
# GPT-2 tokenizer shared by the token-counting helpers below.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# OpenAI helper functions
def hash_32(v):
    """Return the 32-bit xxHash of ``v`` as an integer.

    ``v`` is serialized to JSON with sorted keys before hashing, so dicts
    holding the same items hash identically regardless of insertion order.
    ``v`` must therefore be JSON-serializable.
    """
    return xxhash.xxh32(json.dumps(v, sort_keys=True)).intdigest()

def hash_hexdigest(v):
    """Return the 32-bit xxHash of ``v`` as a hexadecimal string.

    ``v`` is serialized to JSON with sorted keys before hashing, so dicts
    holding the same items hash identically regardless of insertion order.
    ``v`` must therefore be JSON-serializable.
    """
    hasher = xxhash.xxh32(json.dumps(v, sort_keys=True))
    return hasher.hexdigest()

def get_response_text(response):
    """Return the first choice's text with surrounding whitespace stripped."""
    first_choice = response.choices[0]
    return first_choice.text.strip()

def get_response_texts(response):
    """Return every choice's text, whitespace-stripped, in ``response.choices`` order."""
    texts = []
    for choice in response.choices:
        texts.append(choice.text.strip())
    return texts

def count_tokens(texts):
    """Return the token count of the longest text in ``texts``.

    Args:
        texts: A single string or a list of strings. A single string is
            treated as a one-element list.

    Returns:
        The largest number of token ids the module-level ``tokenizer``
        produces for any one of the texts.

    Raises:
        ValueError: If the tokenizer yields no sequences (``max()`` of an
            empty iterable), e.g. for an empty list of texts.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Temporarily raise the tokenizer's logger to ERROR (presumably to silence
    # its warnings about over-long inputs). The try/finally guarantees the
    # previous level is restored even if tokenization raises.
    tokenizer_logger = logging.getLogger("transformers.tokenization_utils_base")
    old_log_level = tokenizer_logger.level
    tokenizer_logger.setLevel(logging.ERROR)
    try:
        return max(map(len, tokenizer(texts)['input_ids']))
    finally:
        tokenizer_logger.setLevel(old_log_level)

def _count_tokens(texts):
    """Return the token count of each text in ``texts``.

    Args:
        texts: A single string or a list of strings. A single string is
            treated as a one-element list.

    Returns:
        A list with one integer per input text: the number of token ids the
        module-level ``tokenizer`` produces for it.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Temporarily raise the tokenizer's logger to ERROR (presumably to silence
    # its warnings about over-long inputs), then put the previous level back.
    # The saved level used to be discarded, leaving the logger at ERROR for
    # the rest of the process.
    tokenizer_logger = logging.getLogger("transformers.tokenization_utils_base")
    old_log_level = tokenizer_logger.level
    tokenizer_logger.setLevel(logging.ERROR)
    try:
        return list(map(len, tokenizer(texts)['input_ids']))
    finally:
        tokenizer_logger.setLevel(old_log_level)

def compute_max_prompt_length(max_tokens):
    """Return the prompt budget left after setting aside ``max_tokens``.

    The result is ``MAX_PROMPT_LENGTH - max_tokens``; it is not clamped, so
    it is negative when ``max_tokens`` exceeds ``MAX_PROMPT_LENGTH``.
    """
    prompt_budget = MAX_PROMPT_LENGTH - max_tokens
    return prompt_budget

def truncate(text, max_length=MAX_PROMPT_LENGTH):
    """Shorten ``text`` so that it is roughly ``max_length`` tokens long.

    The cut is made on characters, not tokens: the text is sliced to the
    fraction ``max_length / token_count`` of its character length. The result
    is therefore an approximation and is not guaranteed to tokenize to at
    most ``max_length`` tokens.

    Args:
        text: The string to truncate.
        max_length: Target length in tokens.

    Returns:
        ``text`` unchanged if it already fits, otherwise a prefix of it.
    """
    text_tokens = count_tokens(text)
    if text_tokens <= max_length:
        return text
    # Divide by the measured token count. The previous code divided by the
    # ``count_tokens`` function object, which raised TypeError on every text
    # that actually needed truncating.
    keep_fraction = max_length / text_tokens
    return text[:int(len(text) * keep_fraction)]

def prompt_packer(example_template, examples, query_template, query, prompt_prefix='', separator = '\n', shuffle_examples=True, max_prompt_length=MAX_PROMPT_LENGTH):
    """Build a few-shot prompt that fits within a token budget.

    The prompt is laid out as: optional prefix, as many formatted examples as
    fit, then the formatted query, all joined by ``separator``.

    Args:
        example_template: ``%``-format string applied to each example.
        examples: Values formatted with ``example_template``. Must be
            JSON-serializable, because their hash seeds the shuffle.
        query_template: ``%``-format string applied to ``query``.
        query: Value formatted with ``query_template``.
        prompt_prefix: Optional text placed before the examples.
        separator: String inserted between the parts of the prompt.
        shuffle_examples: If true, shuffle the examples that were kept. The
            shuffle is seeded from ``examples``, so the same examples always
            produce the same order.
        max_prompt_length: Token budget for the whole prompt.

    Returns:
        The assembled prompt string. The prefix and the query are always
        included, even if they alone exceed the budget; only examples are
        dropped.
    """
    rnd = Random(hash_32(examples))

    # Setup
    separator_tokens = count_tokens(separator)
    tokens_remaining = max_prompt_length

    # Account for the prefix
    if prompt_prefix:
        tokens_remaining -= count_tokens(prompt_prefix) + separator_tokens
        prompt_prefix = prompt_prefix + separator

    # Fill in the query
    query_prompt = query_template % query
    tokens_remaining -= count_tokens(query_prompt)

    # Format every example up front so that a malformed example raises
    # regardless of whether it would have fit in the budget.
    example_prompts = [example_template % example for example in examples]

    # Keep examples, in order, until the first one that does not fit. Each
    # example is charged for one separator; the separator's token count is
    # loop-invariant, so it is computed once above rather than per example.
    final_example_prompts = []
    for example_prompt in example_prompts:
        example_prompt_tokens = count_tokens(example_prompt) + separator_tokens
        if example_prompt_tokens > tokens_remaining:
            break
        final_example_prompts.append(example_prompt)
        tokens_remaining -= example_prompt_tokens

    # Shuffle the final example prompts
    if shuffle_examples:
        rnd.shuffle(final_example_prompts)

    return (
        prompt_prefix +
        separator.join(final_example_prompts) +
        (separator if final_example_prompts else '') +
        query_prompt
    )

def prompt_packer_batch(example_templates, examples, query_templates, queries, prompt_prefix='', separator = '\n', shuffle_examples=True, max_prompt_length=MAX_PROMPT_LENGTH):
    """Build one packed prompt per query by calling ``prompt_packer`` in lockstep.

    ``example_templates``, ``examples``, ``query_templates`` and ``queries``
    are zipped together, so the i-th prompt is built from the i-th element of
    each and the output stops at the shortest of the four. The remaining
    keyword arguments are passed unchanged to every ``prompt_packer`` call.

    Returns:
        A list of prompt strings.
    """
    packed_prompts = []
    rows = zip(example_templates, examples, query_templates, queries)
    for template, row_examples, q_template, q in rows:
        packed = prompt_packer(
            template,
            row_examples,
            q_template,
            q,
            prompt_prefix=prompt_prefix,
            separator=separator,
            shuffle_examples=shuffle_examples,
            max_prompt_length=max_prompt_length,
        )
        packed_prompts.append(packed)
    return packed_prompts

def create_cache(cache_path):
    """Open the ``SqliteDict`` stored at ``cache_path`` with autocommit enabled.

    The returned object can be passed as ``cache`` to ``cached_completion``.
    """
    cache = SqliteDict(cache_path, autocommit=True)
    return cache

def cached_completion(engine, prompts, temperature=0.7, max_tokens=256, top_p=1.0, frequency_penalty=0.0, presence_penalty=0.0, stop=None, cache=None):
    """Complete one or more prompts via the OpenAI API, reusing cached results.

    Each prompt is keyed by a hash of the prompt together with every sampling
    parameter. Prompts whose key is already in ``cache`` are answered from
    it; the rest are sent to the API in a single request and their results
    are written back to the cache.

    Args:
        engine: Engine name passed to ``openai.Completion.create``.
        prompts: A single prompt string or a sequence of prompt strings.
        temperature, top_p, frequency_penalty, presence_penalty, stop:
            Passed through to the API and included in the cache key.
        max_tokens: Completion length limit; capped at ``MAX_PROMPT_LENGTH``.
        cache: Optional mapping from hash key to completion text (for
            example the object returned by ``create_cache``). ``None``
            disables caching.

    Returns:
        The stripped completion text for a single prompt string, otherwise a
        list of texts in the same order as ``prompts``.

    Raises:
        IndexError: If the API returns fewer choices than prompts were sent.
    """
    is_single_prompt = isinstance(prompts, str)

    # Make sure max_tokens is within range
    max_tokens = min(MAX_PROMPT_LENGTH, max_tokens)

    # Convert to list
    prompts = [prompts] if is_single_prompt else list(prompts)

    # Compute hashes of prompts
    # NOTE(review): hash_hexdigest is a 32-bit hash, so distinct requests can
    # collide in a large cache. Widening it would invalidate existing caches,
    # so it is left unchanged here.
    hash_keys = [
        hash_hexdigest({
            'method': 'cached_completion',
            'engine': engine,
            'prompt': prompt,
            'temperature': temperature,
            'max_tokens': max_tokens,
            'top_p': top_p,
            'frequency_penalty': frequency_penalty,
            'presence_penalty': presence_penalty,
            'stop': stop,
        }) for prompt in prompts
    ]

    # Look each key up exactly once and record which positions are missing.
    # Deciding "cached or not" in a single pass keeps the response list and
    # the list of prompts to send consistent even if the cache is modified
    # concurrently (e.g. a shared SQLite file), and avoids repeated queries.
    responses = [None] * len(prompts)
    missing = []
    for i, hash_key in enumerate(hash_keys):
        if cache is not None:
            try:
                responses[i] = cache[hash_key]
            except KeyError:
                pass
        if responses[i] is None:
            missing.append(i)

    # Send the rest to the API
    if missing:
        # Assumes the API returns one choice per prompt, in prompt order.
        api_responses = get_response_texts(
            openai.Completion.create(
                engine=engine,
                prompt=[prompts[i] for i in missing],
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=stop
            )
        )

        # Fill in the missing responses with responses from the API
        for position, i in enumerate(missing):
            responses[i] = api_responses[position]
            if cache is not None:
                cache[hash_keys[i]] = responses[i]

    # Return the responses
    return responses[0] if is_single_prompt else responses