from typing import Union
import os
import time

try:
    from . import llm_cache
except:
    import llm_cache

def generate_logprob(client, prompts: list[str], **kwargs):
    """Route a prompt-scoring request to the backend that matches ``client``.

    The backend is chosen by searching the string form of every class in the
    client's MRO for a class-name fragment. "AI21BedrockClient" is checked
    before "OpenAI".

    Args:
        client: an AI21BedrockClient or OpenAI client (or a subclass of one).
        prompts: the prompts to score.
        **kwargs: forwarded unchanged to the backend helper
            (e.g. ``completion_kwargs``).

    Returns:
        The backend's result, a tuple
        ``(texts, text_offsets, token_logprobs, tokens)``.

    Raises:
        ValueError: if neither backend name appears in the client's MRO.
    """
    class_names = [str(klass) for klass in client.__class__.__mro__]

    def _mentions(fragment):
        return any(fragment in name for name in class_names)

    if _mentions("AI21BedrockClient"):
        return generate_bedrock_logprob(client, prompts, **kwargs)
    if _mentions("OpenAI"):
        return generate_openai_logprob(client, prompts, **kwargs)
    raise ValueError("client must be an instance of AI21BedrockClient or OpenAI")


def generate_openai_logprob(
    client,
    prompts: list[str],
    completion_kwargs: Union[dict, None] = None,
):
    """Score ``prompts`` with an OpenAI completions client and return token logprobs.

    All prompts go out in a single request with ``echo=True`` and
    ``max_tokens=0``, so each returned choice carries the echoed prompt and
    its per-token logprobs.

    Args:
        client: OpenAI client instance.
        prompts: prompts to score, sent together as one batched request.
        completion_kwargs: extra keyword arguments forwarded to the request.
            Defaults to ``{"model": None}``.

    Returns:
        ``(texts, text_offsets, token_logprobs, tokens)``. Each is a list with
        one entry per returned choice, ordered by the choice ``index`` when
        the choice has one:
          texts: echoed text of each choice.
          text_offsets: ``logprobs.text_offset`` of each choice.
          token_logprobs: ``logprobs.token_logprobs`` of each choice.
          tokens: ``logprobs.tokens`` of each choice.

    Raises:
        AssertionError: if the echoed texts differ from ``prompts`` and no
            choice finished with ``"content_filter"``, or if a
            content-filtered result fails the consistency checks below.
            Explicit raises are used instead of ``assert`` so the checks
            still run under ``python -O``.
    """
    if completion_kwargs is None:
        completion_kwargs = {"model": None}
    completion = llm_cache.cached_client_completion(
        client,
        prompt=prompts,
        logprobs=0,
        max_tokens=0,
        echo=True,
        **completion_kwargs,
    )
    # For batched prompts the API does not guarantee that choices come back in
    # prompt order, so re-align them by their ``index`` field. sorted() is
    # stable, so choices without an index keep their original order.
    choices = sorted(
        completion.choices, key=lambda c: getattr(c, "index", None) or 0
    )
    texts = [c.text for c in choices]
    text_offsets = [c.logprobs.text_offset for c in choices]
    token_logprobs = [c.logprobs.token_logprobs for c in choices]
    tokens = [c.logprobs.tokens for c in choices]

    if texts != prompts:
        if not any(c.finish_reason == "content_filter" for c in choices):
            raise AssertionError(
                "The returned texts are not the same as the input prompts. %s != %s"
                % (
                    texts,
                    prompts,
                )
            )
        # Some text was censored, but logprobs and text_offsets may still be
        # useful. Accept the result only if it is still structurally consistent.
        if [len(t) for t in text_offsets] != [len(p) for p in token_logprobs]:
            raise AssertionError(
                "text_offsets and token_logprobs have different lengths"
            )
        if len(text_offsets) != len(prompts):
            raise AssertionError(
                "got %d choices for %d prompts" % (len(text_offsets), len(prompts))
            )
        for offsets, prompt in zip(text_offsets, prompts):
            # The largest token offset must reach within 10 characters of the
            # prompt's end, i.e. the echo was not cut short by much.
            # default=-1 turns an empty offset list into a failed check
            # instead of a ValueError from max().
            if max(offsets, default=-1) < len(prompt) - 10:
                raise AssertionError(
                    "censored echo does not cover the prompt: %r" % (prompt,)
                )
    return texts, text_offsets, token_logprobs, tokens


def generate_bedrock_logprob(
    client, prompts: list[str], completion_kwargs: Union[dict, None] = None
):
    """Score ``prompts`` with an AI21 Bedrock client and return prompt-token logprobs.

    Each prompt is sent as its own request through
    ``llm_cache.cached_client_completion``. With more than one prompt (or
    none), the requests are issued on a thread pool. Only the prompt side of
    each response (``response.prompt``) is read.

    Args:
        client: AI21BedrockClient instance.
        prompts: prompts to score.
        completion_kwargs: extra keyword arguments forwarded to every request.
            Defaults to ``{"model": None}``.

    Returns:
        ``(texts, text_offsets, token_logprobs, tokens)``. Each is a list with
        one entry per prompt, in the same order as ``prompts``:
          texts: ``response.prompt.text`` for each prompt.
          text_offsets: ``textRange.start`` of each prompt token.
          token_logprobs: ``generatedToken.raw_logprob`` of each prompt token.
          tokens: ``generatedToken.token`` of each prompt token.

    Raises:
        AssertionError: if an echoed text's length differs from its prompt's
            length. An explicit raise is used instead of ``assert`` so the
            check still runs under ``python -O``.
    """
    from concurrent.futures import ThreadPoolExecutor

    if completion_kwargs is None:
        completion_kwargs = {"model": None}

    if len(prompts) == 1:
        responses = [
            llm_cache.cached_client_completion(
                client,
                prompt=prompts[0],
                # cannot use 0. at least 1
                max_tokens=1,
                **completion_kwargs,
            )
        ]
    else:
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(
                    llm_cache.cached_client_completion,
                    client,
                    prompt=p,
                    # cannot use 0. at least 1,
                    # but we will handle it later in `check_is_random`
                    # NOTE(review): the single-prompt branch sends
                    # max_tokens=1, but this branch sends 0. The same prompt
                    # therefore produces a different request (and presumably
                    # a different cache key) depending on batch size.
                    # Confirm which value is intended.
                    max_tokens=0,
                    **completion_kwargs,
                )
                for p in prompts
            ]
            # Collected in submission order, so results line up with prompts;
            # .result() re-raises any exception raised in a worker.
            responses = [f.result() for f in futures]

    texts = [r.prompt.text for r in responses]
    # Only lengths are compared, not texts: Jurassic may swap characters such
    # as '。' -> '.', so exact equality with the prompts does not always hold.
    if [len(t) for t in texts] != [len(p) for p in prompts]:
        raise AssertionError(
            "The returned texts are not the same as the input prompts. %s != %s"
            % (
                texts,
                prompts,
            )
        )
    text_offsets = [
        [t["textRange"]["start"] for t in r.prompt.tokens] for r in responses
    ]
    token_logprobs = [
        [t["generatedToken"]["raw_logprob"] for t in r.prompt.tokens] for r in responses
    ]
    tokens = [
        [t["generatedToken"]["token"] for t in r.prompt.tokens] for r in responses
    ]
    return texts, text_offsets, token_logprobs, tokens


def generate_qa(client, **kwargs):
    """Answer a question with the backend that matches ``client``.

    Only OpenAI clients are supported. A client qualifies when "OpenAI"
    appears in the string form of any class in its MRO.

    Args:
        client: OpenAI client instance.
        **kwargs: forwarded unchanged to ``generate_openai_qa``
            (``prompt``, ``messages``, ``completion_kwargs``).

    Returns:
        The response text (str) produced by ``generate_openai_qa``.

    Raises:
        ValueError: if the client is not recognised as an OpenAI client.
    """
    is_openai = any("OpenAI" in str(base) for base in client.__class__.__mro__)
    if not is_openai:
        raise ValueError("Unsupported client for QA")
    return generate_openai_qa(client, **kwargs)


def generate_openai_qa(
    client,
    prompt: Union[str, None] = None,
    messages: Union[list[dict], None] = None,
    completion_kwargs: Union[dict, None] = None,
):
    """Get one free-text answer from an OpenAI client.

    Exactly one of ``prompt`` (completions style) or ``messages`` (chat
    style) must be given.

    Args:
        client: OpenAI client instance.
        prompt: completion prompt; the answer is ``choices[0].text``.
        messages: chat messages; the answer is ``choices[0].message.content``.
        completion_kwargs: extra keyword arguments forwarded to the request.
            Defaults to ``{"model": None}``. A ``max_tokens`` entry here
            overrides the default of 512.

    Returns:
        The text of the first choice.

    Raises:
        ValueError: if both or neither of ``prompt`` and ``messages`` are given.
    """
    if prompt is not None and messages is not None:
        raise ValueError("Only one of prompt or messages should be provided.")
    if prompt is None and messages is None:
        raise ValueError("Either prompt or messages should be provided.")
    if completion_kwargs is None:
        completion_kwargs = {"model": None}
    # Merge into one dict rather than passing two **-expansions: previously a
    # caller-supplied max_tokens raised "got multiple values for keyword
    # argument"; now it overrides the default.
    request_kwargs = {"max_tokens": 512, **completion_kwargs}
    if prompt is not None:
        completion = llm_cache.cached_client_completion(
            client,
            prompt=prompt,
            **request_kwargs,
        )
        return completion.choices[0].text
    completion = llm_cache.cached_client_completion(
        client,
        messages=messages,
        **request_kwargs,
    )
    return completion.choices[0].message.content
