import openai
import os
import time 
import json
from openai import OpenAI
import base64
import requests
import transformers
import openai
import torch

class ChatGPT:
    """Wrapper around the OpenAI chat-completions API that tracks token usage.

    Every successful request adds the reported prompt/completion token counts
    to ``self.prompt_token`` / ``self.gen_token``; ``get_cost`` turns those
    running totals into a cost using the per-model prices below.
    """

    # (prompt price, completion price) per model. ``get_cost`` divides by
    # 1,000,000, so these are prices per one million tokens (presumably USD;
    # confirm against the provider's price list). Models missing from this
    # table are priced at 0.
    _TOKEN_PRICES = {
        "gpt-3.5-turbo-0613": (1.5, 2),
        "gpt-3.5-turbo-1106": (1, 2),
        "gpt-3.5-turbo-0125": (0.5, 1.5),
        "gpt-4": (30, 60),
        "gpt-4-0613": (30, 60),
        "gpt-4-0125-preview": (10, 30),
        "gpt-4-1106-preview": (10, 30),
        "gpt-4-vision-preview": (10, 30),
    }

    def __init__(self, model_name):
        """Create a client for ``model_name`` using the OPENAI_API_KEY env var.

        Args:
            model_name: Model identifier sent with every chat request and used
                to look up the token prices.
        """
        self.model_name = model_name
        self.prompt_token = 0
        self.gen_token = 0
        # Initialised for callers that read it; no method here updates it.
        # Use get_cost() for the actual figure.
        self.cost = 0

        api_key = os.getenv("OPENAI_API_KEY")
        # Module-level key is kept for backward compatibility with code that
        # relies on it; the methods of this class only use ``self.client``.
        openai.api_key = api_key
        self.api_key = api_key
        # Seconds to wait before the single retry of a failed request.
        self.sleep_time = 0.5
        self.client = OpenAI(
            api_key=api_key,
        )

        self.prompt_token_cost, self.gen_token_cost = self._TOKEN_PRICES.get(
            self.model_name, (0, 0)
        )

    def _create_completion(self, **request):
        """Send one chat-completions request, retrying once on failure.

        The first failure is swallowed, followed by a ``self.sleep_time``
        pause and a second attempt; an error from the second attempt
        propagates to the caller. ``except Exception`` (rather than a bare
        ``except``) lets KeyboardInterrupt/SystemExit through immediately.
        """
        try:
            return self.client.chat.completions.create(**request)
        except Exception:
            time.sleep(self.sleep_time)
            return self.client.chat.completions.create(**request)

    def _record_usage(self, usage):
        """Add the token counts of one response to the running totals."""
        self.prompt_token += usage.prompt_tokens
        self.gen_token += usage.completion_tokens

    def generate(self, prompt, sys_prompt=None):
        """Return the model's reply to a single user prompt.

        Args:
            prompt: Content of the user message.
            sys_prompt: Optional system message placed before the user message.

        Returns:
            The message content of the first choice.
        """
        message_list = []
        if sys_prompt is not None:
            message_list.append({"role": "system", "content": sys_prompt})
        message_list.append({"role": "user", "content": prompt})

        response_object = self._create_completion(
            model=self.model_name,
            messages=message_list,
        )

        response = response_object.choices[0].message.content
        self._record_usage(response_object.usage)
        return response

    def generate_role_force(self, sysprompt, dialog_history, seeker_is_user=True):
        """Return the model's reply to a dialogue with forced alternating roles.

        Args:
            sysprompt: Content of the leading system message.
            dialog_history: Sequence of utterances; roles are assigned by
                position, alternating between the two roles.
            seeker_is_user: If true the first utterance is sent as "user",
                otherwise as "assistant".

        Returns:
            The message content of the first choice.
        """
        messages = [{"role": "system", "content": sysprompt}]

        if seeker_is_user:
            role_list = ["user", "assistant"]
        else:
            role_list = ["assistant", "user"]

        for turn, utterance in enumerate(dialog_history):
            messages.append({
                "role": role_list[turn % 2],
                "content": utterance,
            })

        # Uses the v1 client like the other methods: the legacy
        # ``openai.ChatCompletion`` entry point and dict-style responses do
        # not exist in the SDK version that provides ``OpenAI``.
        response_object = self._create_completion(
            model=self.model_name,
            messages=messages,
        )

        response = response_object.choices[0].message.content
        self._record_usage(response_object.usage)
        return response

    def generate_with_function(self, prompt, functions, function_call="auto"):
        """Ask the model to call one of ``functions`` and return its arguments.

        Args:
            prompt: Content of the user message.
            functions: Function definitions forwarded as ``functions``.
            function_call: Forwarded as ``function_call``.

        Returns:
            The ``arguments`` attribute of the first choice's function call.

        Raises:
            AttributeError: If the first choice carries no function call
                (``message.function_call`` is ``None``).
        """
        messages = [
            {"role":"system","content":"Only use the valid functions and function paramters you are provided with."},
            {"role": "user", "content": prompt}
        ]

        response_object = self._create_completion(
            model=self.model_name,
            messages=messages,
            functions=functions,
            function_call=function_call,
        )

        self._record_usage(response_object.usage)
        return response_object.choices[0].message.function_call.arguments

    def encode_image(self, image_path):
        """Return the file at ``image_path`` as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def get_image_to_text(self, prompt, image_path):
        """Send ``prompt`` plus an image to the vision model and return its text.

        The request always targets "gpt-4-vision-preview", independent of
        ``self.model_name``, but its tokens are added to the same counters, so
        ``get_cost`` prices them at this instance's model rates.

        Args:
            prompt: Text part of the user message.
            image_path: Path of the image, embedded as a JPEG data URL.

        Returns:
            The reply text, or "" if both attempts fail.
        """
        base64_image = self.encode_image(image_path)

        headers = {
          "Content-Type": "application/json",
          "Authorization": f"Bearer {self.api_key}"
        }

        payload = {
          "model": "gpt-4-vision-preview",
          "messages": [
            {
              "role": "user",
              "content": [
                {
                  "type": "text",
                  "text": prompt
                },
                {
                  "type": "image_url",
                  "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                  }
                }
              ]
            }
          ],
          "max_tokens": 600
        }

        response = None
        response_text = None
        for attempt in range(2):
            try:
                reply = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
                response = reply.json()
                # An error payload has no 'choices'; the KeyError lands in the
                # handler below and is treated like any other failed attempt.
                response_text = response['choices'][0]['message']['content']
                break
            except Exception:
                if attempt == 0:
                    time.sleep(self.sleep_time)
                else:
                    print("Failed to process",image_path)
                    return ""

        self.prompt_token += response['usage']['prompt_tokens']
        self.gen_token += response['usage']['completion_tokens']
        return response_text

    def get_cost(self):
        """Return the cost of all tokens recorded so far on this instance."""
        cost = (self.prompt_token_cost*self.prompt_token+self.gen_token_cost*self.gen_token)/1000000

        return cost


class FastChat:
    """Chat client for an OpenAI-compatible HTTP endpoint.

    By default it targets ``http://localhost:8000/v1/`` with the placeholder
    API key "EMPTY" (presumably a locally running FastChat server; confirm
    against the deployment).
    """

    def __init__(self, model_name, base_url="http://localhost:8000/v1/"):
        """Create a client for ``model_name`` served at ``base_url``.

        Args:
            model_name: Model identifier sent with every chat request.
            base_url: Root URL of the OpenAI-compatible API.
        """
        self.model_name = model_name
        # Module-level configuration is kept for backward compatibility with
        # code that relies on it after constructing this class.
        openai.api_key = "EMPTY"
        openai.base_url = base_url
        # Requests go through a dedicated client so that later changes to the
        # module-level key or URL (e.g. by another wrapper in this process)
        # cannot redirect them or attach a different API key.
        self.client = OpenAI(api_key="EMPTY", base_url=base_url)

    def generate(self, prompt, sys_prompt=None):
        """Return the model's reply to a single user prompt.

        Args:
            prompt: Content of the user message.
            sys_prompt: Optional system message placed before the user message.

        Returns:
            The message content of the first choice.
        """
        message_list = []
        if sys_prompt is not None:
            message_list.append({"role": "system", "content": sys_prompt})
        message_list.append({"role": "user", "content": prompt})

        response_object = self.client.chat.completions.create(
            model=self.model_name,
            messages=message_list,
        )

        return response_object.choices[0].message.content

    def get_cost(self):
        """Return 0; no cost is tracked for this backend."""
        return 0

class LLAMA:
    """Chat wrapper around a ``transformers`` text-generation pipeline."""

    def __init__(self, model_name):
        """Build a bfloat16 text-generation pipeline for ``model_name`` on CUDA."""
        self.model_name = model_name
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cuda",
        )

    def generate(self, prompt, sys_prompt=None):
        """Return the text the model generates in reply to ``prompt``.

        Args:
            prompt: Content of the user message.
            sys_prompt: Optional system message placed before the user message.

        Returns:
            The generated text with the rendered prompt prefix stripped off.
        """
        user_turn = {"role": "user", "content": prompt}
        if sys_prompt is None:
            conversation = [user_turn]
        else:
            conversation = [{"role": "system", "content": sys_prompt}, user_turn]

        tokenizer = self.pipeline.tokenizer

        # Render the conversation to a plain string using the tokenizer's
        # chat template, ending with the generation prompt.
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Generation stops at either the tokenizer's EOS id or "<|eot_id|>".
        stop_ids = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        results = self.pipeline(
            rendered,
            max_new_tokens=256,
            eos_token_id=stop_ids,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        full_text = results[0]["generated_text"]
        # Drop the prompt prefix so only the newly generated text is returned.
        return full_text[len(rendered):]

    def get_cost(self):
        """Return 0; no cost is tracked for this backend."""
        return 0

        
