from diskcache import Cache
import os

# Global on/off switch for the disk cache, read as ``use_cache[0]`` by
# ``cached_client_completion``. The one-element list lets other code toggle
# caching by mutation (``use_cache[0] = False``) without rebinding a global —
# presumably the reason for the wrapper.
use_cache = [True]

try:
    import experiments

    # Keep the cache in an ``llm_cache`` directory that sits next to the
    # ``experiments`` package directory.
    cache = Cache(
        os.path.join(os.path.dirname(experiments.__file__), "..", "llm_cache"),
        size_limit=200 * 1024**3,  # 200 GiB max
    )
except Exception:
    # Importing ``experiments`` or opening the cache at that location failed:
    # fall back to a cache directory relative to the current working
    # directory. ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
    # and SystemExit propagate.
    cache = Cache("./test_cache")


def is_openai_client(client):
    """Return True if ``client``'s class or any base class looks like an OpenAI client.

    Each class in the MRO is rendered with ``str()`` (e.g. ``"<class 'pkg.Name'>"``)
    and checked for the substring ``"OpenAI"``. Subclasses and any class whose
    name contains ``OpenAI`` therefore match too. The module path is part of the
    rendered string, so it can also produce a match.
    """
    # Generator (not a list) so ``any`` stops at the first hit.
    return any("OpenAI" in str(cls) for cls in type(client).__mro__)


def is_ai21bedrock_client(client):
    """Return True if ``client``'s class or any base class looks like an AI21 Bedrock client.

    Each class in the MRO is rendered with ``str()`` and checked for the
    substring ``"AI21BedrockClient"``. Subclasses match as well.
    """
    # Generator (not a list) so ``any`` stops at the first hit.
    return any("AI21BedrockClient" in str(cls) for cls in type(client).__mro__)


def check_is_random(client, **kwargs):
    """Return whether a completion request should be treated as non-deterministic.

    ``cached_client_completion`` only memoizes requests for which this
    returns a falsy value.

    Rules, in order:

    * An explicit, non-None ``seed`` -> deterministic (False).
    * AI21 Bedrock or OpenAI-style client -> deterministic only when
      ``max_tokens == 0``; otherwise (including when ``max_tokens`` is
      absent) random (True).
    * Any other client -> deterministic (False).

    ``kwargs`` is only inspected, never modified.
    """
    if kwargs.get("seed") is not None:
        return False
    if is_ai21bedrock_client(client) or is_openai_client(client):
        # AI21 Bedrock cannot accept max_tokens=0, so such a request has to be
        # sent with max_tokens=1. That makes it not strictly pure, but it is
        # still treated as cacheable here. The substitution must be done by
        # the caller: ``**kwargs`` builds a fresh dict for every call, so
        # assigning into it here would never reach the outgoing request.
        return kwargs.get("max_tokens") != 0
    # Unrecognised client: treated as deterministic.
    return False


def _real_client_completion(client, **kwargs):
    """Send a completion request directly to the client's SDK, without caching.

    AI21 Bedrock clients use ``client.completion.create``. For OpenAI-style
    clients, a ``messages`` kwarg routes to ``client.chat.completions.create``;
    otherwise a ``prompt`` kwarg routes to ``client.completions.create``.

    Raises:
        ValueError: if an OpenAI-style client receives neither ``messages``
            nor ``prompt``, or if the client type is not recognised.
    """
    if is_ai21bedrock_client(client):
        return client.completion.create(**kwargs)
    if not is_openai_client(client):
        raise ValueError("Unknown client type.")
    # ``messages`` wins when both keys are supplied.
    if "messages" in kwargs:
        endpoint = client.chat.completions
    elif "prompt" in kwargs:
        endpoint = client.completions
    else:
        raise ValueError("Unknown request type. Please provide prompt or messages.")
    return endpoint.create(**kwargs)


from . import change_model, timed_llm

# Build the uncached request pipeline. ``timed_llm`` wraps the raw dispatcher
# and ``change_model`` wraps that result, so a call enters the ``change_model``
# wrapper first. What each wrapper does is defined in those modules.
# NOTE(review): presumably model substitution and call timing, going by the
# module names — confirm there.
real_client_completion = change_model.wrap_client_completion(
    timed_llm.wrap_client_completion(_real_client_completion)
)


# Disk-memoized completion call. ``client`` is excluded from the cache key
# (ignored both as positional argument 0 and by keyword name), so entries are
# keyed on ``_client_type`` plus the request ``kwargs``. diskcache also folds
# this function's module-qualified name into the key, so renaming it or
# reshaping its parameters would orphan existing cache entries.
@cache.memoize(ignore=[0, "client"])
def _cached_client_completion(client, _client_type, **kwargs):
    """Return the cached result for this request, computing it on a miss.

    ``_client_type`` is unused in the body. It exists only so the client's
    class contributes to the cache key while the (unhashable or unstable)
    client object itself does not. On a cache hit, ``real_client_completion``
    and its wrappers are not invoked at all.
    """
    return real_client_completion(client, **kwargs)


def cached_client_completion(client, **kwargs):
    """Run a completion request, serving it from the disk cache when possible.

    While ``use_cache[0]`` is truthy, requests that ``check_is_random``
    classifies as deterministic go through ``_cached_client_completion``,
    which is memoized on the client's class name plus ``kwargs``. All other
    requests call ``real_client_completion`` directly.

    For AI21 Bedrock clients, ``max_tokens=0`` is replaced by ``max_tokens=1``
    before the request is sent, because that client cannot use 0. The
    caller's own dict is not modified, since ``kwargs`` is a local copy.
    """
    # Classify on the caller's original arguments: a max_tokens=0 AI21 request
    # counts as deterministic, which would no longer hold after the rewrite.
    is_random = check_is_random(client, **kwargs)
    if is_ai21bedrock_client(client) and kwargs.get("max_tokens") == 0:
        kwargs["max_tokens"] = 1
    if not is_random and use_cache[0]:
        return _cached_client_completion(client, str(client.__class__), **kwargs)
    return real_client_completion(client, **kwargs)
