import os
import time
from typing import List, Dict

import openai
from dotenv import load_dotenv
from tenacity import wait_exponential, retry, stop_after_attempt, RetryCallState

from tmp import re_extract


def clean_query(s: str):
    """Remove one surrounding quote character from each end of *s*.

    At most one leading and one trailing ``'`` or ``"`` is removed; the two ends
    are handled independently, so the quotes need not match. Whitespace is not
    stripped.
    """
    quote_chars = ("'", '"')
    trimmed = s[1:] if s.startswith(quote_chars) else s
    return trimmed[:-1] if trimmed.endswith(quote_chars) else trimmed


def clean_queries(ss: List[str]):
    """Apply ``clean_query`` to every string in *ss* and return a new list."""
    return list(map(clean_query, ss))


def clean_response(s: str):
    """Extract the expanded queries from a raw model response.

    If the response contains the marker ``"Expand[Query]: "``, only the text after
    its last occurrence is kept. If the remaining text contains ``[`` and/or ``]``:
    the part after the last ``[`` (if any) is cut at its first ``]`` (if any).
    That part is then split on ``", "``, and each item has one surrounding quote
    stripped via ``clean_queries``. The result is a ``list`` of strings.

    NOTE(review): when neither bracket is present, the (possibly trimmed) text is
    returned as a plain ``str`` rather than a list, so callers must handle both.
    """
    marker = "Expand[Query]: "
    if marker in s:
        s = s.split(marker)[-1]

    if "[" not in s and "]" not in s:
        return s

    # Text after the last "[" (the whole string when there is none)...
    inner = s.rpartition("[")[2]
    # ...up to its first "]" (the whole remainder when there is none).
    inner = inner.partition("]")[0]
    return clean_queries(inner.split(", "))


def check_before(retry_state: RetryCallState):
    """Tenacity ``before`` hook: prepare global client state before each ``GPT.query`` attempt.

    ``retry_state.args[0]`` is the ``GPT`` instance whose ``query`` is retried.

    * On the first attempt, ``request_timeout`` is reset to 3 seconds. This undoes
      the increments ``log_attempt_number`` applies after timeouts in earlier calls.
    * Suppose the global ``openai`` client is still using the fallback OpenAI key
      and this is one of the first three attempts. The primary key is then
      restored, along with whichever of the instance's ``api_type`` / ``api_base``
      / ``api_version`` settings are configured.

    The optional settings are read with ``getattr``, so an instance created
    without the corresponding environment variables no longer makes this hook
    raise ``AttributeError``.
    """
    gpt = retry_state.args[0]
    if retry_state.attempt_number == 1:
        gpt.request_timeout = 3

    fallback_key = getattr(gpt, "api_key_openai", None)
    # Within a single call the fallback key is only installed after the 4th
    # failed attempt. Seeing it on attempts 1-3 therefore means an earlier call
    # left it behind, so switch back to the primary configuration.
    if (
        fallback_key is not None
        and openai.api_key == fallback_key
        and retry_state.attempt_number <= 3
    ):
        print("Restoring API key... to Azure")
        openai.api_key = gpt.api_key
        # Restore only settings the instance actually has. Unset ones keep the
        # values installed alongside the fallback key.
        for setting in ("api_type", "api_base", "api_version"):
            value = getattr(gpt, setting, None)
            if value is not None:
                setattr(openai, setting, value)


def log_attempt_number(retry_state: RetryCallState):
    """Tenacity ``after`` hook: adjust client settings after a failed ``GPT.query`` attempt.

    ``retry_state.args[0]`` is the ``GPT`` instance being retried, and
    ``retry_state.attempt_number`` is the attempt that just failed.

    * If the failure was ``openai.error.Timeout``, the instance's
      ``request_timeout`` is increased by one second and nothing else changes.
    * After the 2nd and 3rd failed attempts, the global key toggles between the
      secondary key (``api_key_sub``) and the primary key (``api_key``). Without
      a secondary key, the primary key is (re)installed.
    * After the 4th and later failed attempts, the client is pointed at the
      plain OpenAI endpoint with ``api_key_openai``, if such a key is configured.

    Optional keys are read with ``getattr``, so a missing or ``None`` key is
    skipped instead of raising ``AttributeError`` or installing ``None``.
    Returns nothing.
    """
    gpt = retry_state.args[0]
    attempt = retry_state.attempt_number

    if isinstance(retry_state.outcome.exception(), openai.error.Timeout):
        print("Increasing timeout...")
        gpt.request_timeout += 1
        return

    if 1 < attempt <= 3:
        sub_key = getattr(gpt, "api_key_sub", None)
        if sub_key is not None and openai.api_key != sub_key:
            print("Switching API key... => sub")
            openai.api_key = sub_key
        else:
            print("Switching API key... => main")
            openai.api_key = gpt.api_key

    elif attempt > 3:
        openai_key = getattr(gpt, "api_key_openai", None)
        if openai_key is not None:
            print("Switching API key... => openai")
            openai.api_key = openai_key
            openai.api_type = "open_ai"
            openai.api_base = "https://api.openai.com/v1/"
            openai.api_version = None

    print(f"Retrying: {attempt}...")


class GPT:
    """Thin wrapper around ``openai.ChatCompletion`` with key/endpoint fail-over.

    Configuration is read from the environment after ``load_dotenv()``:

    * ``OPENAI_API_KEY``: primary key, installed as ``openai.api_key``.
    * ``OPENAI_API_KEY_SUB``: optional secondary key the retry hooks may switch to.
    * ``OPENAI_API_KEY_OPENAI``: optional key for the plain OpenAI endpoint that
      the retry hooks fall back to after repeated failures.
    * ``OPENAI_API_TYPE`` / ``OPENAI_API_BASE`` / ``OPENAI_API_VERSION``:
      optional endpoint settings, installed on the ``openai`` module when set.
    * ``OPENAI_API_ENGINE``: deployment name, read only when the API type is
      ``"azure"``.

    Every optional setting is always present as an attribute (``None`` when
    unset).

    NOTE: settings are written to the module-global ``openai`` client, so all
    ``GPT`` instances in a process share (and overwrite) them.
    """

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo-0301",
        max_token_length: int = 200,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: List[str] = None,
        max_iter: int = 1,
    ):
        """Read configuration from the environment and configure the ``openai`` module.

        Args:
            model_name: chat model name, sent as ``model=`` whenever ``query``
                is not addressing an Azure deployment.
            max_token_length: ``max_tokens`` for each completion.
            temperature: sampling temperature.
            top_p: nucleus-sampling parameter.
            stop: stop sequences; ``None`` means ``["###"]``.
            max_iter: stored on the instance; not used by this class.
        """
        # Build the default per instance rather than sharing a mutable default.
        self.stop = ["###"] if stop is None else stop

        self.model_name = model_name
        self.max_token_length = max_token_length
        self.temperature = temperature
        self.top_p = top_p

        self.max_iter = max_iter

        load_dotenv()
        self.api_key = os.getenv("OPENAI_API_KEY")
        openai.api_key = self.api_key

        # Optional fail-over keys. Always define them (None when unset) so that
        # ``query`` and the retry hooks can compare against them safely.
        self.api_key_sub = os.getenv("OPENAI_API_KEY_SUB")
        self.api_key_openai = os.getenv("OPENAI_API_KEY_OPENAI")

        # Optional endpoint settings. Override the openai module only when set.
        self.api_type = os.getenv("OPENAI_API_TYPE")
        self.api_base = os.getenv("OPENAI_API_BASE")
        self.api_version = os.getenv("OPENAI_API_VERSION")
        if self.api_type is not None:
            openai.api_type = self.api_type
        if self.api_base is not None:
            openai.api_base = self.api_base
        if self.api_version is not None:
            openai.api_version = self.api_version

        # For the "azure" API type, ``query`` addresses a deployment via ``engine=``.
        self.engine_name = (
            os.getenv("OPENAI_API_ENGINE") if self.api_type == "azure" else None
        )

        # Per-request timeout in seconds. ``check_before`` resets it on the first
        # attempt and ``log_attempt_number`` grows it after timeouts.
        self.request_timeout = 3

    @retry(
        # 10 * 2**0 already equals max=10, so every wait is a flat 10 s and
        # min=5 never takes effect.
        wait=wait_exponential(multiplier=10, min=5, max=10),
        stop=stop_after_attempt(7),
        before=check_before,
        after=log_attempt_number,
    )
    def query(self, prompt: List[Dict[str, str]]):
        """Send a chat-completion request and return the first choice's text.

        Args:
            prompt: passed through as ``messages=`` (presumably a list of
                ``{"role": ..., "content": ...}`` dicts; TODO confirm with callers).

        Returns:
            ``response["choices"][0]["message"]["content"]``.

        Raises:
            tenacity.RetryError: when all 7 attempts fail (the ``retry``
                decorator is used without ``reraise``).
        """
        # Use ``model=`` on the plain OpenAI endpoint (fallback key active) or
        # when no Azure deployment is configured. Otherwise address the
        # deployment via ``engine=``. Without the ``engine_name is None`` case,
        # a non-Azure setup would send neither parameter with the primary key.
        use_model = (
            openai.api_key == self.api_key_openai or self.engine_name is None
        )
        response = openai.ChatCompletion.create(
            model=self.model_name if use_model else None,
            engine=None if use_model else self.engine_name,
            messages=prompt,
            max_tokens=self.max_token_length,
            temperature=self.temperature,
            top_p=self.top_p,
            stop=self.stop,
            request_timeout=self.request_timeout,
        )
        return response["choices"][0]["message"]["content"]
