# from llama import Llama
from transformers import AutoTokenizer, AutoModelForCausalLM
import google.generativeai as genai
from openai import OpenAI
from typing import Union, List


def load_api_key(path, multiple_keys=False):
    """Read an API key (or several) from a UTF-8 text file.

    Args:
        path: Path to the key file.
        multiple_keys: When False, the whole file is treated as one key.
            When True, each non-blank line is a separate key.

    Returns:
        With ``multiple_keys=False``: the file contents with surrounding
        whitespace stripped.
        With ``multiple_keys=True``: a list of stripped keys, one per
        non-blank line, in file order (possibly empty).
    """
    with open(path, "r", encoding="utf-8") as f:
        if not multiple_keys:
            return f.read().strip()
        # Skip blank / whitespace-only lines (e.g. trailing empty lines) so
        # callers never receive "" as a key and an empty file yields [].
        return [line.strip() for line in f if line.strip()]


class LLaMa:
    """Local Hugging Face causal-LM wrapper exposing the same ``ask`` API as ChatGPT."""

    def __init__(self, version, checkpoint_path, max_tokens=2048, temperature=0.1):
        """Load the tokenizer and a 4-bit model for ``version``.

        Versions:
        Meta-Llama-3-70B-Instruct
        Meta-Llama-3-70B
        Meta-Llama-3-8B-Instruct
        Meta-Llama-3-8B

        Args:
            version: Hugging Face model id passed to ``from_pretrained``.
            checkpoint_path: Stored for interface compatibility; not used to
                load anything (weights are resolved from ``version``).
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            temperature: Passed to ``generate`` as ``temperature``.
        """
        self.version = version
        self.checkpoint_path = checkpoint_path  # kept for API compatibility; unused
        self.tokenizer = AutoTokenizer.from_pretrained(version)
        self.max_tokens = max_tokens
        self.temperature = temperature

        # NOTE(review): newer transformers releases deprecate the bare
        # ``load_in_4bit=`` kwarg in favour of
        # ``quantization_config=BitsAndBytesConfig(load_in_4bit=True)`` —
        # confirm against the installed version.
        self.model = AutoModelForCausalLM.from_pretrained(
            version,
            load_in_4bit=True,
            device_map="auto",
        )

    def ask(self, prompt: str, classes: Union[None, List[str]] = None) -> str:
        """Generate a reply to ``prompt`` sent as a single user turn.

        Args:
            prompt: The user message.
            classes: Accepted for interface parity with ``ChatGPT.ask`` but
                currently ignored. The free-text reply is returned either way.

        Returns:
            The newly generated text (prompt tokens excluded), decoded with
            special tokens removed.
        """
        messages = [{"role": "user", "content": prompt}]

        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        # Llama-3 instruct models end a turn with <|eot_id|>, which may differ
        # from the tokenizer's eos token; stop on either.
        terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.max_tokens,
            eos_token_id=terminators,
            temperature=self.temperature,
        )

        # generate() returns prompt + continuation; decode only the continuation.
        prompt_len = input_ids.shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


class ChatGPT:
    """OpenAI chat-completions wrapper with optional logprob-based classification."""

    def __init__(self, version, api_key_path, max_tokens=4096, temperature=1e-6):
        """Store model settings and read the API key from ``api_key_path``.

        Versions:
        gpt-4-0125-preview, gpt-4-1106-preview
        gpt-4, gpt-4-32k
        gpt-3.5-turbo-0125, gpt-3.5-turbo-instruct

        NOTE(review): gpt-3.5-turbo-instruct is served by the legacy
        completions endpoint, not chat.completions — confirm it works here.
        """
        self.version = version
        self.api_key = load_api_key(api_key_path)
        self.max_tokens = max_tokens
        self.temperature = temperature
        self._client = None  # OpenAI client, created lazily on first ask()

    def _get_client(self):
        """Return the cached OpenAI client, creating it on first use."""
        # Reuse one client (and its connection pool) across calls instead of
        # constructing a new client per request.
        client = getattr(self, "_client", None)
        if client is None:
            client = OpenAI(api_key=self.api_key)
            self._client = client
        return client

    def ask(self, prompt: str, classes: Union[None, List[str]] = None) -> str:
        """Send ``prompt`` as a single user turn and return the answer.

        Args:
            prompt: The user message.
            classes: Optional candidate labels. When falsy (None or empty),
                the model's free-text reply is returned. Otherwise each label
                is scored by the best top-20 logprob of the first generated
                token whose stripped, lower-cased text is a prefix of the
                lower-cased label, and the best-scoring label is returned.

        Returns:
            The reply text, or the chosen label from ``classes``. Ties
            (including the case where no token matches any label) go to the
            earliest label in ``classes``.
        """
        chat_completion = self._get_client().chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=self.version,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            logprobs=True,
            top_logprobs=20,
        )
        choice = chat_completion.choices[0]

        if not classes:
            return choice.message.content

        # Alternatives for the first generated token; loop-invariant.
        top_logprobs = choice.logprobs.content[0].top_logprobs

        class_logprobs = {}
        for candidate in classes:
            target = candidate.lower()
            best = -float("inf")
            for entry in top_logprobs:
                token = entry.token.strip().lower()
                # A whitespace-only token strips to "", which is a prefix of
                # every label; skip it so it cannot score all classes equally.
                if token and target.startswith(token) and entry.logprob > best:
                    best = entry.logprob
            class_logprobs[candidate] = best

        return max(class_logprobs, key=class_logprobs.get)


class Gemini:
    """google.generativeai wrapper that rotates round-robin through several API keys."""

    # Safety categories for which blocking is disabled on every request.
    # NOTE(review): HARM_CATEGORY_DANGEROUS looks like a PaLM-era category;
    # confirm the Gemini endpoint accepts it alongside DANGEROUS_CONTENT.
    _SAFETY_CATEGORIES = (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    # Reply returned when Gemini yields no usable text.
    _NO_RESPONSE = "None returned from Gemini"

    def __init__(self, version, api_key_path):
        """Store the model name and load API keys (one per line).

        Versions:
        gemini-1.0-pro, gemini-1.0-pro-001
        gemini-1.0-pro-latest, gemini-1.0-pro-vision-latest
        gemini-pro, gemini-pro-vision

        Raises:
            ValueError: If the key file contains no non-blank keys.
        """
        self.version = version
        # Drop blank entries so a file of empty lines is rejected below.
        self.api_keys = [
            key for key in load_api_key(api_key_path, multiple_keys=True) if key
        ]
        if not self.api_keys:
            raise ValueError("No API keys found.")
        self.cur_key_idx = 0

    def ask(self, prompt):
        """Send ``prompt`` to Gemini using the next API key in rotation.

        Returns:
            The response text, or ``"None returned from Gemini"`` when there
            are no candidates or the response text cannot be read.
        """
        safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in self._SAFETY_CATEGORIES
        ]

        genai.configure(api_key=self.api_keys[self.cur_key_idx])
        # Advance so the next call uses the next key (round-robin).
        self.cur_key_idx = (self.cur_key_idx + 1) % len(self.api_keys)

        model = genai.GenerativeModel(self.version)
        response = model.generate_content(prompt, safety_settings=safety_settings)

        # No candidates at all, e.g. when the prompt itself is blocked.
        if not response.candidates:
            return self._NO_RESPONSE
        try:
            return response.text
        except ValueError:
            # ``.text`` can raise ValueError even when candidates exist
            # (e.g. a candidate with no text parts); treat as no response.
            return self._NO_RESPONSE


class LLM:
    """Front end that picks a concrete model wrapper by (case-insensitive) name.

    Only the ``"chatgpt"`` backend and a ``"test"`` stub are currently enabled;
    the Gemini and LLaMa wrappers in this module are not wired up here.
    """

    def __init__(self, model: str, version: str, llm_path: str):
        """Select and construct the backend.

        Args:
            model: Backend name, matched case-insensitively.
            version: Model version forwarded to the backend.
            llm_path: Path to the API key file or to the model itself.

        Raises:
            ValueError: If ``model`` is not a supported backend.
        """
        backend = model.lower()
        self.name = backend
        self.model = None

        if backend not in ("chatgpt", "test"):
            raise ValueError(f"Unsupported language model: {self.name}")
        if backend == "chatgpt":
            self.model = ChatGPT(version, llm_path)
        # The "test" backend keeps self.model as None; generate() returns a canned reply.

    def generate(self, prompt: str, classes: Union[None, List[str]] = None) -> str:
        """Return the selected backend's answer to ``prompt``.

        ``classes`` is forwarded to the ChatGPT backend unchanged and ignored
        by the test stub.

        Raises:
            ValueError: If ``self.name`` is not a supported backend.
        """
        if self.name == "test":
            return "<answer>Test response</answer>"
        if self.name != "chatgpt":
            raise ValueError(f"Unsupported language model: {self.name}")
        return self.model.ask(prompt, classes)
