import time
from typing import Any, Callable, Hashable, List, Optional

import numpy as np
from substrate import LLMClient, SubstrateRateLimitError
# The `substrate` module provides the client used to call the LLM API and fetch its responses.

def get_unique_in_order(
    lst: List[Any], key: Optional[Callable[[Any], Hashable]] = None
):
    """Return the items of ``lst`` without duplicates, keeping first occurrences.

    Args:
        lst: Items to de-duplicate; iterated once, front to back.
        key: Optional function mapping an item to the hashable value used to
            decide whether two items are duplicates. When omitted, the item
            itself is used and must therefore be hashable.

    Returns:
        A new list holding, in their original relative order, the first item
        seen for each distinct key.
    """
    # Resolve the identity fallback once instead of branching per element.
    key_of = key if key is not None else (lambda item: item)

    already_seen = set()
    ordered = []
    for item in lst:
        marker = key_of(item)
        if marker in already_seen:
            continue
        already_seen.add(marker)
        ordered.append(item)
    return ordered


# Shared client instance, created once at import time and used by
# `_wrapped_llm_client_send_request` for every request.
LLM_CLIENT = LLMClient(use_dev_capacity=True)
# Pause applied after a rate-limit error before the request is retried.
DEFAULT_SLEEP = 120  # seconds


# Short model aliases accepted from callers, mapped to the model names that
# are actually passed to the client. Names not listed here pass through as-is.
_MODEL_ALIASES = {
    "gpt4-turbo": "dev-gpt-4-turbo",
    "gpt3.5-turbo": "dev-gpt-35-turbo",
    "phi": "dev-phi-3-medium-128k-instruct",
    "mistral": "dev-mistral-7b-instruct-v02",
    "gpt4-o": "dev-gpt-4o-2024-05-13",
}

# Total number of send attempts made before giving up on rate-limit errors.
_MAX_RATE_LIMIT_ATTEMPTS = 8


def _wrapped_llm_client_send_request(request_data, model):
    """Send ``request_data`` to ``model``, retrying when the client is rate limited.

    Args:
        request_data: Request payload handed unchanged to
            ``LLM_CLIENT.send_request``.
        model: Either one of the short aliases in ``_MODEL_ALIASES`` or a
            model name that is forwarded to the client as given.

    Returns:
        Whatever ``LLM_CLIENT.send_request`` returns on the first successful
        attempt, or ``None`` if every attempt raised
        ``SubstrateRateLimitError``.

    Only ``SubstrateRateLimitError`` is caught; any other exception raised by
    the client propagates to the caller.
    """
    model = _MODEL_ALIASES.get(model, model)

    for calls_left in reversed(range(_MAX_RATE_LIMIT_ATTEMPTS)):
        try:
            return LLM_CLIENT.send_request(model, request_data)
        except SubstrateRateLimitError as err:
            print(err)
            if calls_left == 0:
                # No attempt follows, so sleeping here would only delay the
                # caller without any chance of a different outcome.
                break
            seconds = DEFAULT_SLEEP
            print(f"Sleeping {seconds}s due to rate limit (tries left={calls_left})")
            time.sleep(seconds)
    return None


def get_llm_response(
    model: str,
    prompt: str,
    max_tokens: Optional[int] = 250,
    temperature: Optional[float] = 0.0,
    n: Optional[int] = 1,
    stop: Optional[List[str]] = None,
    logprobs: Optional[int] = 1,
    all_resp: Optional[int] = 0,
) -> Optional[List[str]]:
    """Get completion texts for a prompt.

    The request is issued up to three times; an attempt is repeated when the
    underlying call yields no response or a response without ``"choices"``.

    Args:
        model: Model alias or name, forwarded to
            ``_wrapped_llm_client_send_request``.
        prompt: Prompt text placed in the request payload.
        max_tokens: Value sent as ``"max_tokens"``.
        temperature: Value sent as ``"temperature"``.
        n: Value sent as ``"n"``.
        stop: Value sent as ``"stop"``; a falsy value is replaced by ``"\\n"``.
        logprobs: Value sent as ``"logprobs"``. The default (ranked) mode
            reads ``choice["logprobs"]["token_logprobs"]`` from each choice,
            so it presumably needs this to be enabled -- confirm against the
            API before passing ``None`` or ``0``.
        all_resp: When truthy, return the text of every choice exactly as
            received (no de-duplication, filtering, or ranking).

    Returns:
        With ``all_resp`` truthy: the ``"text"`` of each choice in response
        order. Otherwise: the distinct, non-blank choice texts ordered from
        highest to lowest mean token log-probability (possibly an empty
        list). ``None`` if no attempt produced a usable response.
    """
    stop = stop or "\n"

    for _ in range(3):
        # Built fresh for every attempt so each send gets its own payload.
        request_data = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": 1,
            "n": n,
            "stream": False,
            "logprobs": logprobs,
            "stop": stop,
        }
        raw_response = _wrapped_llm_client_send_request(request_data, model)

        if raw_response is None:
            # The wrapper returns None once its rate-limit retries are used
            # up. Checking this first keeps the membership tests below from
            # raising TypeError on None.
            continue

        if all_resp and "choices" in raw_response:
            return [choice["text"] for choice in raw_response["choices"]]

        if isinstance(raw_response, dict) and "choices" in raw_response:
            candidates = [
                choice
                for choice in get_unique_in_order(
                    raw_response["choices"], key=lambda x: x["text"]
                )
                if choice["text"].strip()  # drop blank / whitespace-only texts
            ]
            # Most confident completion first (stable for ties).
            candidates.sort(
                key=lambda x: np.mean(x["logprobs"]["token_logprobs"]),
                reverse=True,
            )
            return [choice["text"] for choice in candidates]

    return None


def testing():
    """Manual smoke test: send one fixed prompt to a model and print the result."""
    # Other aliases that can be swapped in for a quick check:
    # "gpt4-turbo", "gpt4-o", "phi".
    print(get_llm_response("mistral", "What is the meaning of life?"))


if __name__ == "__main__":
    # Run the manual smoke test when the file is executed as a script.
    # To debug a failure interactively, wrap this call in try/except and
    # invoke `pdb.post_mortem()` in the handler.
    testing()
