import openai
import os
import time
from common.registry import registry
import pdb
import tiktoken
from openai import OpenAI

import anthropic
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

@registry.register_llm("gpt")
class OPENAI_GPT:
    """Chat-completion wrapper dispatching to OpenAI, Anthropic (Claude) or Mistral.

    Registered with the LLM registry under the name ``"gpt"``. The model used for
    a request is chosen per call by the ``model_name`` argument of
    :meth:`generate` / :meth:`llm_inference`; within this class ``engine`` is only
    used as the default model of :meth:`num_tokens_from_messages`.

    API keys are read from the ``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY`` and
    ``MISTRAL_API_KEY`` environment variables.
    """

    # Model names routed to each provider by ``llm_inference``.
    _OPENAI_MODELS = frozenset({
        'gpt-4-1106-preview', 'gpt-4', 'gpt-4-32k', 'gpt-3.5-turbo-0301', 'gpt-4-0613',
        'gpt-4-32k-0613', 'gpt-3.5-turbo-16k-0613',
    })
    _CLAUDE_MODELS = frozenset({
        "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-haiku-20240307",
    })
    _MISTRAL_MODELS = frozenset({'open-mixtral-8x7b', "mistral-large-latest"})
    _SUPPORTED_MODELS = _OPENAI_MODELS | _CLAUDE_MODELS | _MISTRAL_MODELS

    def __init__(self,
                 engine="gpt-3.5-turbo-0631",  # NOTE(review): likely meant "-0613"; confirm
                 temperature=0,
                 max_tokens=200,
                 use_azure=False,
                 top_p=1,
                 stop=None,
                 retry_delays=60,  # in seconds
                 max_retry_iters=7,
                 context_length=16384,
                 system_message=''
                 ):
        """Store generation settings and configure the ``openai`` module credentials.

        Args:
            engine: default model name for :meth:`num_tokens_from_messages`.
            temperature, max_tokens, top_p, stop, context_length, system_message:
                stored on the instance. Note that :meth:`llm_inference` currently
                uses fixed sampling settings rather than these values.
            use_azure: if true, configure the Azure API type/version from the
                environment instead of ``OPENAI_API_KEY``.
            retry_delays: seconds to wait between failed attempts in :meth:`generate`.
            max_retry_iters: maximum number of attempts in :meth:`generate`.

        Raises:
            Exception: if ``use_azure`` is false and ``OPENAI_API_KEY`` is unset.
            KeyError: if ``use_azure`` is true and the Azure variables are unset.
        """
        self.engine = engine
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.use_azure = use_azure
        self.top_p = top_p
        # A fresh list per instance; a list literal default would be shared by all.
        self.stop = ["\n"] if stop is None else stop
        self.retry_delays = retry_delays
        self.max_retry_iters = max_retry_iters
        self.context_length = context_length
        self.system_message = system_message
        self.init_api_key()

    def init_api_key(self):
        """Configure module-level ``openai`` credentials from environment variables.

        Raises:
            KeyError: (Azure) if ``OPENAI_API_TYPE`` or ``OPENAI_API_VERSION`` is unset.
            Exception: (non-Azure) if ``OPENAI_API_KEY`` is unset.
        """
        if self.use_azure:
            openai.api_type = os.environ['OPENAI_API_TYPE']
            openai.api_version = os.environ['OPENAI_API_VERSION']
            return
        if 'OPENAI_API_KEY' not in os.environ:
            raise Exception("OPENAI_API_KEY environment variable not set.")
        openai.api_key = os.environ['OPENAI_API_KEY']

    def llm_inference(self, messages_input, model_name, system_message):
        """Send one chat request to the provider serving ``model_name``.

        Args:
            messages_input: the conversation without a system turn. It is a list of
                ``{"role": ..., "content": ...}`` dicts for OpenAI/Claude models, or
                a list of ``ChatMessage`` objects for Mistral models.
            model_name: a name from ``_OPENAI_MODELS``, ``_CLAUDE_MODELS`` or
                ``_MISTRAL_MODELS``.
            system_message: system prompt. It is prepended as a system message for
                OpenAI/Mistral and passed via ``system=`` for Claude.

        Returns:
            The text content of the first returned completion.

        Raises:
            ValueError: if ``model_name`` is not a supported model.
            Any exception raised by the provider SDK (auth, network, rate limit, ...).
        """
        if model_name in self._OPENAI_MODELS:
            # Keys come from the environment; never hard-code secrets in source.
            client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "system", "content": system_message}] + messages_input,
                temperature=0.0,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
            )
            return response.choices[0].message.content

        if model_name in self._CLAUDE_MODELS:
            client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
            message = client.messages.create(
                model=model_name,
                max_tokens=1000,
                temperature=0.0,
                system=system_message,
                messages=messages_input,
            )
            return message.content[0].text

        if model_name in self._MISTRAL_MODELS:
            client = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY"))
            # No streaming
            chat_response = client.chat(
                model=model_name,
                messages=[ChatMessage(role="system", content=system_message)] + messages_input,
            )
            return chat_response.choices[0].message.content

        raise ValueError(f"Unsupported model: {model_name!r}")

    def generate(self, system_message, prompt, model_name_testLLM):
        """Ask ``model_name_testLLM`` a single-turn ``prompt``, retrying on errors.

        Makes at most ``self.max_retry_iters`` attempts. It sleeps
        ``self.retry_delays`` seconds between failed attempts.

        Returns:
            ``(True, completion_text)`` on success. It returns ``(False, '')`` if
            the model is unsupported or every attempt failed.
        """
        if model_name_testLLM not in self._SUPPORTED_MODELS:
            # Retrying cannot fix an unknown model name, so fail fast.
            print(f"Unsupported model: {model_name_testLLM}")
            return False, ''

        # The Mistral SDK expects ChatMessage objects; the other SDKs take plain dicts.
        if model_name_testLLM in self._MISTRAL_MODELS:
            messages = [ChatMessage(role="user", content=prompt)]
        else:
            messages = [{"role": "user", "content": prompt}]

        for attempt in range(self.max_retry_iters):
            try:
                action = self.llm_inference(messages, model_name_testLLM, system_message)
            except Exception as e:  # provider SDK errors: network, auth, rate limits, ...
                print(f"Error on attempt {attempt + 1}: {e}")
                if attempt < self.max_retry_iters - 1:  # not the last attempt
                    time.sleep(self.retry_delays)
                else:
                    print("Failed to get completion after multiple attempts.")
                continue
            print(f'\nLLM generated action: {action}\n')
            return True, action  # success, completion

        return False, ''

    def num_tokens_from_messages(self, messages, model=None):
        """Return the number of tokens used by a list of messages.

        Args:
            messages: list of dicts whose values are strings. Each value is encoded
                with tiktoken.
            model: model whose tokenizer and per-message overhead to use. It
                defaults to ``self.engine``.

        Returns:
            Total token count, including 3 tokens for priming the reply.
        """
        if model is None:
            model = self.engine
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            print("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        # Models not listed here get zero per-message/per-name overhead, so the
        # count is a lower bound for them.
        tokens_per_message = 0
        tokens_per_name = 0
        if model in {
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "gpt-4-0314",
            "gpt-4-32k-0314",
            "gpt-4-0613",
            "gpt-4-32k-0613",
        }:
            tokens_per_message = 3
            tokens_per_name = 1

        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    @classmethod
    def from_config(cls, config):
        """Build an instance from a config mapping, filling missing keys with defaults."""
        return cls(engine=config.get("engine", "gpt-35-turbo"),
                   temperature=config.get("temperature", 0),
                   max_tokens=config.get("max_tokens", 100),
                   use_azure=config.get("use_azure", True),
                   top_p=config.get("top_p", 1),
                   retry_delays=config.get("retry_delays", 10),
                   max_retry_iters=config.get("max_retry_iters", 7),
                   system_message=config.get("system_message", "You are a helpful assistant."),
                   context_length=config.get("context_length", 4096),
                   stop=config.get("stop", ["\n"]))
