# -*- coding: utf-8 -*-
__author__ = "Yash Kumar Lal, Github@ykl7"

import os
import openai
from openai import OpenAI
import pickle
import time
import random
import pandas as pd
import hashlib
import atexit
from config import config

# Fix the global RNG state at import time for reproducible runs. Nothing in this
# file draws from `random`, so this only affects code sharing the global RNG.
random.seed(a=1234)

class OpenAICommunicator():
    """Wrapper around the OpenAI Chat Completions API with an on-disk response cache.

    Replies are cached in memory, keyed on a hash of the prompt plus every sampling
    parameter. The cache is pickled to ``cache_path`` when the interpreter exits.
    """

    # HTTP statuses (besides 5xx) worth retrying; mirrors the retry policy of the
    # official openai-python client. Other 4xx errors cannot succeed on retry.
    _RETRYABLE_STATUS_CODES = (408, 409)
    # Seconds to wait before retrying a failed request.
    _RETRY_WAIT_SECONDS = 60

    def __init__(self, options):
        """Configure the API client and load any existing response cache.

        Args:
            options: dict with required keys ``model_name`` and ``max_tokens``, and
                optional keys ``cache_path``, ``temperature``, ``top_p``,
                ``frequency_penalty`` and ``presence_penalty``.
        """
        self.client = OpenAI(api_key=config["OPENAI_API_KEY"])
        self.model_name = options["model_name"]
        self.max_tokens = options["max_tokens"]
        self.cache_path = options.get("cache_path", '../data/cache/openai_cache.pkl')
        self.temp = options.get("temperature", 0.0)
        self.top_p = options.get("top_p", 1.0)
        self.frequency_penalty = options.get("frequency_penalty", 0.0)
        self.presence_penalty = options.get("presence_penalty", 0.0)
        self.cached_responses = self.load_cache_if_exists()
        # Register the exit hook only once the cache is loaded. Registering it first
        # meant that any failure earlier in __init__ left behind a hook that crashed
        # at exit on the missing ``cached_responses`` attribute.
        atexit.register(self.cleanup)

    def load_cache_if_exists(self):
        """Return the pickled cache stored at ``cache_path``, or ``{}`` if there is none.

        If the file does not exist, the parent directory is created (when the path
        has one) so that ``cleanup`` can later write the cache there.
        """
        if os.path.exists(self.cache_path):
            # NOTE: unpickling can execute arbitrary code. Only point cache_path at
            # files written by this class, never at untrusted data.
            with open(self.cache_path, 'rb') as handle:
                return pickle.load(handle)
        cache_dir = os.path.dirname(self.cache_path)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        return {}

    def cleanup(self):
        """Persist the in-memory cache to ``cache_path``. Registered with ``atexit``.

        Works like a destructor, but runs while builtins such as ``open`` are still
        available. The data is first written to a temporary file and then moved over
        the real path with ``os.replace``, so an interrupted save does not truncate
        the existing cache.
        """
        print("\nFinal cleanup cache saving..", end="...\n")
        tmp_path = self.cache_path + ".tmp"
        with open(tmp_path, 'wb') as handle:
            pickle.dump(self.cached_responses, handle)
        os.replace(tmp_path, self.cache_path)

    @classmethod
    def _is_retryable_status(cls, status_code):
        """Return True for HTTP statuses that may succeed on retry (408, 409, 5xx)."""
        return status_code in cls._RETRYABLE_STATUS_CODES or status_code >= 500

    def make_openai_chat_completions_api_call(self, prompt):
        """Send ``prompt`` to the Chat Completions API and return the parsed reply.

        Connection errors and transient HTTP statuses are retried indefinitely, with
        a fixed pause between attempts. Retries use a loop rather than recursion, so
        a long outage cannot exhaust the stack.

        Args:
            prompt: list of chat messages, passed as ``messages`` to the API.

        Returns:
            Tuple ``(reply_text, raw_response)``, as produced by
            ``parse_chat_completions_api_response``.

        Raises:
            SystemExit: with status 1 on a rate-limit or model-not-found error.
            openai.APIStatusError: on a non-retryable status (e.g. 400 or 401).
        """
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=prompt,
                    temperature=self.temp,
                    max_tokens=self.max_tokens,
                    top_p=self.top_p,
                    frequency_penalty=self.frequency_penalty,
                    presence_penalty=self.presence_penalty
                )
            except openai.APIConnectionError as e:
                print("The server could not be reached")
                print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            except openai.RateLimitError:
                # RateLimitError subclasses APIStatusError, so it must be caught first.
                print("Rate limit error hit")
                raise SystemExit(1)
            except openai.NotFoundError:
                print("Model not found")
                raise SystemExit(1)
            except openai.APIStatusError as e:
                print("Another non-200-range status code was received")
                print(e.status_code)
                print(e.response)
                if not self._is_retryable_status(e.status_code):
                    raise
            else:
                return self.parse_chat_completions_api_response(response)
            time.sleep(self._RETRY_WAIT_SECONDS)

    def parse_chat_completions_api_response(self, response):
        """Return ``(content of the first choice's message, response)``."""
        return response.choices[0].message.content, response

    def run_inference(self, prompt=None, use_cache=True):
        """Return the model's reply text for ``prompt``, checking the cache first.

        Args:
            prompt: list of chat messages. ``None`` (the default) means an empty list.
            use_cache: when False, always call the API and neither read nor write
                the cache.

        Returns:
            The reply text: the first choice's ``message.content``.
        """
        if prompt is None:
            prompt = []

        if not use_cache:
            response_text, _ = self.make_openai_chat_completions_api_call(prompt)
            return response_text

        # The key covers every parameter that affects the output, so changing any
        # of them misses the cache instead of returning a stale answer.
        hashed_prompt = hashlib.sha256(str(prompt).encode("utf-8")).hexdigest()
        cache_key = (hashed_prompt, self.model_name, self.max_tokens, self.temp,
                     self.top_p, self.frequency_penalty, self.presence_penalty)
        if cache_key in self.cached_responses:
            return self.cached_responses[cache_key]['text']

        response_text, _ = self.make_openai_chat_completions_api_call(prompt)
        # Written to disk by cleanup() when the interpreter exits.
        self.cached_responses[cache_key] = {'text': response_text}
        return response_text

if __name__ == '__main__':

    # Smoke test: send a one-message conversation and print the model's reply.
    options = {
        "model_name": "gpt-3.5-turbo-0301",  # NOTE(review): dated snapshot; confirm it is still served.
        "max_tokens": 50,
    }

    openai_communicator = OpenAICommunicator(options)
    # The Chat Completions API requires at least one message, so the previous
    # empty-prompt call could never succeed.
    sample_prompt = [{"role": "user", "content": "Say hello in one short sentence."}]
    print(openai_communicator.run_inference(sample_prompt))
