import asyncio
import os
import random
import time

import openai
from loguru import logger
from tqdm import tqdm

# Placeholder API key list; replace "Your keys" with real keys before use.
KEYs = ["Your keys"]


class ChatGPT(object):
    """Batch client for OpenAI chat completions (pre-1.0 ``openai`` module API).

    Requests within a batch are issued concurrently on an asyncio event loop,
    each retried with randomized exponential backoff, and every batch's
    outputs are handed to ``writer.write_all``.
    """

    def __init__(self, key, parser, batch_size, max_retries=10):
        """Configure the module-global ``openai`` client and batching policy.

        Args:
            key: API key used when the ``OPENAI_KEY`` environment variable is unset.
            parser: callable applied to each completion's message text.
            batch_size: number of requests sent concurrently per batch.
            max_retries: retries allowed per request after the first attempt.
        """
        # NOTE: these are module-level attributes of ``openai``, so every
        # instance shares (and overwrites) the same key and base URL.
        openai.api_key = os.getenv("OPENAI_KEY", default=key)
        openai.api_base = "your base"
        self.batch_size = batch_size
        self.parser = parser
        self.max_retries = max_retries

    def batch_predict(self, samples, writer):
        """Send ``samples`` in batches and write each batch's outputs.

        Args:
            samples: sliceable sequence of ``messages`` payloads, one per request.
            writer: object whose ``write_all(list)`` is called once per batch
                with one output per sample, in the batch's order.
        """
        # Use a private loop instead of the process-wide default one, so that
        # closing it afterwards does not break later calls or other asyncio
        # users; ``finally`` guarantees it is closed even on errors.
        loop = asyncio.new_event_loop()
        try:
            for start in tqdm(range(0, len(samples), self.batch_size)):
                batch = samples[start:start + self.batch_size]
                ret = loop.run_until_complete(self._run_batch(batch))
                writer.write_all(ret)
                time.sleep(1)
        finally:
            loop.close()

    async def _run_batch(self, batch):
        """Run ``send`` concurrently for every item of ``batch``.

        Returns a list holding exactly one output per item, in ``batch`` order;
        it is sized by the batch itself because the final batch may be shorter
        than ``batch_size``.
        """
        ret = [None] * len(batch)
        # ``gather`` accepts coroutines directly (``asyncio.wait`` no longer
        # does since Python 3.11).
        pairs = await asyncio.gather(
            *(self.send(j, content) for j, content in enumerate(batch))
        )
        for idx, output in pairs:
            ret[idx] = output
        return ret

    async def send(self, idx, content, exponential_base=2):
        """Request one chat completion, retrying on API/decoding errors.

        Args:
            idx: position of ``content`` within its batch; the request start
                is also delayed by ``idx`` seconds to stagger the batch.
            content: ``messages`` payload for ``ChatCompletion.acreate``.
            exponential_base: base multiplier for retry-delay growth.

        Returns:
            ``(idx, output)`` where ``output`` is ``{"result": parsed}`` on
            success, ``{"error": "<message>"}`` once retries are exhausted, or
            ``{"error": True, "content": raw_text}`` when extracting or
            parsing the reply fails.
        """
        # Non-blocking sleeps: ``time.sleep`` here would freeze the whole event
        # loop and serialize every concurrent request in the batch.
        await asyncio.sleep(idx)
        delay = 1
        num_retry = 0
        while True:
            try:
                response = await openai.ChatCompletion.acreate(
                    model="gpt-3.5-turbo",
                    messages=content,
                    temperature=0.5,
                )
                break
            except (openai.OpenAIError, UnicodeDecodeError) as error:
                num_retry += 1
                if num_retry > self.max_retries:
                    return idx, {"error": str(error)}

                logger.warning(error)
                logger.warning("Retrying...")
                # Randomized exponential backoff: grow by base * [1, 2).
                # NOTE(review): the delay is uncapped, so late retries can wait
                # a very long time with the default max_retries.
                delay *= exponential_base * (1 + random.random())
                await asyncio.sleep(delay)

        result = None
        try:
            result = response.choices[0].message.content
            return idx, {"result": self.parser(result)}
        except Exception as exception:
            # ``Exception`` rather than ``BaseException`` so task cancellation
            # and KeyboardInterrupt still propagate.
            logger.warning(exception)
            return idx, {"error": True, "content": result}
