import yaml
import os
import tiktoken
import joblib
# from langchain import OpenAI, Wikipedia
# from langchain.chat_models import ChatOpenAI
from open_source_models.llama import load_llama
import openai
import google.generativeai as genai
import httpx
import logging

# Ten hex colour strings available to plotting code.
# NOTE(review): looks like Matplotlib's default "tab10" cycle — confirm.
default_colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]

def _load_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return the resulting object.

    The file is decoded as UTF-8 explicitly so the parsed result does not
    depend on the platform's locale encoding.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document, or ``None``
        when the document cannot be parsed (the parser error is printed and
        otherwise ignored).

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, "r", encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Best-effort loader: report the problem and fall through to None.
            print(exc)
    return None

# def _load_model(configs):
#     model_name = configs["model_name"]
#     if "gpt" in model_name:
#         temp = configs['temperature']
#         max_tokens = configs['max_tokens']
#         model_kwargs = configs['model_kwargs']
#         max_retries = configs['max_retries']
#         n = configs['n']
#         return ChatOpenAI(temperature=temp,
#             max_tokens=max_tokens,
#             model_name=model_name,
#             n = n,
#             model_kwargs=model_kwargs,
#             max_retries = max_retries,
#             openai_api_key=os.environ['OPENAI_API_KEY'])
#     elif "llama" in model_name:
#         ckpt_dir = model_name
#         n_gpus = configs['n_gpus']
#         devices_ids = configs['device_ids']
#         model,tokenizer = load_llama(ckpt_dir, True, n_gpus, devices_ids = devices_ids)
#         return (model, tokenizer)
#     elif len(model_name)==0:
#         return None
#     elif "text-davinci" in model_name:
#         temp = configs['temperature']
#         max_tokens = configs['max_tokens']
#         model_kwargs = configs['model_kwargs']
#         n = configs['n']
#         return OpenAI(temperature=temp,
#             max_tokens=max_tokens,
#             model_name=model_name,
#             n = n,
#             model_kwargs=model_kwargs,
#             openai_api_key=os.environ['OPENAI_API_KEY'])  
#     else:
#         temp = configs['temperature']
#         max_tokens = configs['max_tokens']
#         model_kwargs = configs['model_kwargs']
#         n = configs['n']
#         return OpenAI(temperature=temp,
#             max_tokens=max_tokens,
#             model_name=model_name,
#             n = n,
#             model_kwargs=model_kwargs,
#             openai_api_key=os.environ['OPENAI_API_KEY'])
def _load_model(configs, API_BASE="Default", API_KEY="Default", organization=False):
    """Build the API client that matches ``configs["model_name"]``.

    The backend is chosen by substring match on the model name, checked in
    this order:

    * contains ``"gpt"``: an ``openai.OpenAI`` client. ``API_BASE`` falls back
      to ``https://api.openai.com/v1`` and ``API_KEY`` to the
      ``BEN_OPENAI_API`` environment variable when left at ``"Default"``.
    * contains ``"gemini"``: a ``genai.GenerativeModel`` created with a fixed
      generation config and one safety threshold for four harm categories;
      ``API_KEY`` falls back to the ``GOOGLE_API`` environment variable.
    * contains ``"llama"``, ``"mistral"`` or ``"mixtral"`` (case-insensitive):
      an ``openai.OpenAI`` client pointed at ``configs["server"]`` on port
      8000 when that value is not the empty string, otherwise at
      ``https://api.together.xyz/v1``.

    Args:
        configs: Mapping holding at least ``"model_name"``. The
            llama/mistral/mixtral branch also reads ``"key"`` and
            ``"server"``, and rewrites ``"model_name"`` in place
            (``"8x7B"`` -> ``"8X7B"``) when a server is configured.
        API_BASE: Base URL for the ``"gpt"`` branch; ``"Default"`` selects the
            fallback described above.
        API_KEY: API key for the ``"gpt"`` and ``"gemini"`` branches;
            ``"Default"`` selects the environment-variable fallback.
        organization: When truthy, the ``"gpt"`` client is also given the
            organization read from ``OPENAI_DUKE_ORG_KEY``.

    Returns:
        A two-element list ``[client, None]``; ``client`` is ``None`` when the
        model name matches none of the branches.
    """
    model_id = configs["model_name"]
    handle = None

    if "gpt" in model_id:
        if API_BASE == "Default":
            API_BASE = "https://api.openai.com/v1"
        if API_KEY == "Default":
            API_KEY = os.getenv("BEN_OPENAI_API")
        openai_kwargs = {"base_url": API_BASE, "api_key": API_KEY}
        # The organization argument is only passed when the caller asks for it.
        if organization:
            openai_kwargs["organization"] = os.getenv("OPENAI_DUKE_ORG_KEY")
        handle = openai.OpenAI(
            timeout=httpx.Timeout(300.0, read=50.0, write=20.0, connect=10.0),
            **openai_kwargs,
        )
    elif "gemini" in model_id:
        if API_KEY == "Default":
            API_KEY = os.getenv("GOOGLE_API")
        genai.configure(api_key=API_KEY)
        block_level = "BLOCK_MEDIUM_AND_ABOVE"  # alternative threshold: 'BLOCK_LOW_AND_ABOVE'
        harm_categories = (
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
        handle = genai.GenerativeModel(
            model_name=model_id,
            generation_config={
                "temperature": 0,
                "top_p": 1,
                "top_k": 1,
                "max_output_tokens": 2048,
            },
            safety_settings=[
                {"category": category, "threshold": block_level}
                for category in harm_categories
            ],
        )
    elif any(tag in model_id.lower() for tag in ("llama", "mistral", "mixtral")):
        logging.info(f"Loading TogetherAI API client: key={configs['key']}, server={configs['server']}")
        if configs["server"] == "":
            # No server configured: use the Together endpoint; the env var
            # holding the key depends on configs["key"].
            key_env = "TOGETHER_API" if configs["key"] == "junlin" else "BEN_TOGETHER_API"
            handle = openai.OpenAI(
                base_url="https://api.together.xyz/v1",
                api_key=os.getenv(key_env),
                timeout=httpx.Timeout(3000.0, read=600.0, write=600.0, connect=600.0),
                max_retries=0,
            )
        else:
            # A server is configured: rewrite the model name in place and
            # talk to that host on port 8000 with a placeholder key.
            configs["model_name"] = configs["model_name"].replace("8x7B", "8X7B")
            handle = openai.OpenAI(
                api_key="EMPTY",
                base_url=f"http://{configs['server']}:8000/v1",
                timeout=httpx.Timeout(6000.0),
            )

    return [handle, None]
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Returns the number of tokens used by a list of messages.

    Only each message's ``content`` attribute is encoded; a fixed per-message
    overhead is added for every message, plus three tokens for the reply
    priming.

    Args:
        messages: Iterable of objects exposing a ``content`` attribute that
            the tokenizer can encode (presumably a string — confirm against
            callers).
        model: Model name. ``"gpt-3.5-turbo"`` and ``"gpt-4"`` are counted as
            ``"gpt-3.5-turbo-0301"`` and ``"gpt-4-0314"`` respectively.

    Returns:
        The total token count as an int.

    Raises:
        NotImplementedError: If ``model`` is not one of the supported names.
    """
    # Floating aliases may change over time, so they are counted as a fixed
    # dated snapshot.
    pinned_snapshots = {
        "gpt-3.5-turbo": "gpt-3.5-turbo-0301",
        "gpt-4": "gpt-4-0314",
    }
    model = pinned_snapshots.get(model, model)

    overhead_per_message = {
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        "gpt-3.5-turbo-0301": 4,
        "gpt-4-0314": 3,
    }
    # Validate before touching the tokenizer so unsupported models fail fast
    # and the encoding is loaded exactly once for supported ones.
    if model not in overhead_per_message:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    tokens_per_message = overhead_per_message[model]

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to tiktoken: fall back to the cl100k_base encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
    for message in messages:
        num_tokens += tokens_per_message + len(encoding.encode(message.content))
    return num_tokens

def num_tokens_from_string(string, model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by ``string`` sent as a single message.

    The count is the encoded length of ``string`` plus a fixed per-message
    overhead for the model and three tokens for the reply priming.
    Adapted from
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        string: Text to tokenize.
        model: Model name. Names that are not an exact known snapshot but
            contain ``"gpt-3.5-turbo"`` or ``"gpt-4"`` are counted as
            ``"gpt-3.5-turbo-0613"`` / ``"gpt-4-0613"`` after printing a
            warning.

    Returns:
        The token count as an int.

    Raises:
        NotImplementedError: If ``model`` matches no supported name.
    """
    three_token_models = {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }
    four_token_model = "gpt-3.5-turbo-0301"

    # Resolve the model before loading the tokenizer so unsupported names
    # fail fast and the encoding is loaded exactly once.
    if model != four_token_model and model not in three_token_models:
        if "gpt-3.5-turbo" in model:
            print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
            model = "gpt-3.5-turbo-0613"
        elif "gpt-4" in model:
            print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
            model = "gpt-4-0613"
        else:
            raise NotImplementedError(
                f"""num_tokens_from_string() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
            )

    # every message follows <|start|>{role/name}\n{content}<|end|>\n
    tokens_per_message = 4 if model == four_token_model else 3

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # The trailing 3: every reply is primed with <|start|>assistant<|message|>
    return tokens_per_message + len(encoding.encode(string)) + 3

def gpt_usage(prompt_tokens, completion_tokens, backend="gpt-4"):
    """Estimate the cost of a request from its token counts.

    Rates are hard-coded per 1000 tokens for each supported backend
    (presumably US dollars — confirm against the provider's price list).

    Args:
        prompt_tokens: Number of prompt (input) tokens.
        completion_tokens: Number of completion (output) tokens.
        backend: Model name used to pick the rates.

    Returns:
        ``{"cost": <number>}``; the cost is ``0`` for any backend without a
        rate entry.
    """
    if backend in ("gpt-4", "gpt-4-0314", "gpt-4-0613"):
        cost = completion_tokens / 1000 * 0.06 + prompt_tokens / 1000 * 0.03
    elif backend in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613"):
        cost = prompt_tokens / 1000 * 0.0015 + completion_tokens / 1000 * 0.002
    elif backend in (
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-16k-0613",
        # Legacy spelling accepted for backward compatibility with callers
        # that passed the previously hard-coded (transposed) name.
        "gpt-3.5-turbo-0613-16k",
    ):
        cost = prompt_tokens / 1000 * 0.003 + completion_tokens / 1000 * 0.004
    elif backend == "text-davinci-003":
        cost = (prompt_tokens + completion_tokens) / 1000 * 0.02
    else:
        # Unknown backends are reported as free rather than raising.
        cost = 0
    return {"cost": cost}


def save_agents(agents, dir: str):
    """Dump every agent to its own ``<index>.joblib`` file inside ``dir``.

    The directory (including missing parents) is created when absent. Files
    are named after each agent's zero-based position in ``agents``.

    Args:
        agents: Iterable of objects to serialize with ``joblib.dump``.
        dir: Destination directory path.
    """
    os.makedirs(dir, exist_ok=True)
    for position, member in enumerate(agents):
        destination = os.path.join(dir, f'{position}.joblib')
        joblib.dump(member, destination)