# coding=utf8
import os
import yaml
import json
import time
import requests
import sseclient
from venus_api_base.http_client import HttpClient
from venus_api_base.config import Config


class VenusClient(object):
    """Client for the Venus chat HTTP API (session-based and single-shot chat).

    Credentials and defaults are read from the ``venus_api`` section of a YAML
    config file; the keys used are ``sid``, ``skey``, ``group_id`` and
    ``model``.

    Models listed as supported by Venus (original author's note):
        gpt-3.5-turbo, gpt-3.5-turbo-1106, gpt-3.5-turbo-16k, gpt-4,
        gpt-4-1106-preview, gpt-4-0125-preview, gpt-4-vision-preview

    Prices (original author's note):
        gpt-3.5-turbo-1106: Input: $1.00/1M tokens, Output: $2.00/1M tokens
        gpt-3.5-turbo-0613: Input: $1.50/1M tokens, Output: $2.00/1M tokens
        gpt-3.5-turbo-16k: Input: $3.00/1M tokens, Output: $4.00/1M tokens
        gpt-3.5-turbo-0301: Input: $1.50/1M tokens, Output: $2.00/1M tokens
        gpt-3.5-turbo-0125: Input: $0.50/1M tokens, Output: $1.50/1M tokens
        gpt-4: Input: $30.00/1M tokens, Output: $60.00/1M tokens
        gpt-4-32k: Input: $60.00/1M tokens, Output: $120.00/1M tokens
        gpt-4-0125-preview: Input: $10.00/1M tokens, Output: $30.00/1M tokens
        gpt-4-1106-preview: Input: $10.00/1M tokens, Output: $30.00/1M tokens
        gpt-4-1106-vision-preview: Input: $10.00/1M tokens, Output: $30.00/1M tokens
    """

    # Config file read when the constructor is not given ``config_path``.
    DEFAULT_CONFIG_PATH = "/cfs/cfs-0s927vn5/haojinghuang/codes/human_preference_alignment_algorithms/src/user_config.yaml"

    # Class-level default so that __del__/del_session stay safe even when
    # __init__ failed before the instance attribute was assigned.
    session_id = None

    def __init__(self, use_sys="你善于分析小说和剧本等文案，回答各种分析推理的问题，以及推理其中的人物关系。", model="", config_path=None):
        """Load credentials from the YAML config and build the HTTP client.

        Args:
            use_sys: System prompt sent by ``single_chat``; a falsy value
                omits the system message.
            model: Model name. When empty, the ``model`` value from the
                config file is used.
            config_path: Path of the YAML config file. Defaults to
                ``DEFAULT_CONFIG_PATH``.
        """
        with open(config_path or self.DEFAULT_CONFIG_PATH, encoding="utf-8") as f:
            cfg = yaml.safe_load(f)["venus_api"]
        self.client = HttpClient(secret_id=cfg["sid"], secret_key=cfg["skey"], config=Config(read_timeout=400))
        self.group_id = cfg["group_id"]
        self.model = model or cfg["model"]
        # No server-side session exists until new_session() is called.
        self.session_id = None
        self.use_sys = use_sys
        print("venus chatgpt inited. New model.")
        # Report the model actually in use, not the (possibly empty) argument.
        print(f"The model is: {self.model}")

    def __del__(self):
        """Best-effort removal of this client's session on garbage collection."""
        try:
            self.del_session()
        except Exception:
            # Deliberately ignored: a finalizer must not raise, and the
            # session may already be gone or the client half-initialised.
            pass

    def _post_json(self, url, data):
        """POST ``data`` as a JSON body to ``url`` and return the client's response."""
        return self.client.post(url=url, header={"Content-Type": "application/json"}, body=json.dumps(data))

    def list_session(self):
        """Return the session ids listed for this app group and model.

        Only the first page (page size 100) is requested. An empty list is
        returned when the response does not have the expected
        ``data.content[*].sessionId`` structure.
        """
        url = f"http://v2.open.venus.oa.com/chat/session/list?appGroupId={self.group_id}&model={self.model}&pageSize=100&pageNum=1"
        resp = self.client.get(url=url, header={"Content-Type": "application/json"})
        try:
            return [content["sessionId"] for content in resp["data"]["content"]]
        except (KeyError, TypeError):
            # Missing keys or a non-subscriptable payload: treat as "no sessions".
            return []

    def clean_session(self):
        """Delete every session returned by ``list_session``."""
        print("删除所有会话")
        for session_id in self.list_session():
            print(session_id)
            self.del_session(session_id)

    def del_session(self, session_id=None):
        """Delete a session on the server.

        Args:
            session_id: Session to delete. Defaults to this client's current
                session; when there is none, the call is a no-op.
        """
        if session_id is None:
            session_id = self.session_id

        if session_id is None:
            return

        self._post_json("http://v2.open.venus.oa.com/chat/session/delete", {"sessionId": session_id})
        print(f"删除会话: {session_id}")
        # Forget our own session once it is deleted so that __call__ creates
        # a fresh one instead of reusing a dead id.
        if session_id == self.session_id:
            self.session_id = None

    def new_session(self):
        """Replace the current session (if any) with a newly created one.

        When the server answers with code 4001, all listed sessions are
        deleted and creation is retried once.
        NOTE(review): 4001 presumably means the session quota is exhausted —
        confirm against the Venus API documentation.

        Raises:
            KeyError / TypeError: if the final response lacks ``code`` or
                ``data.sessionId``.
        """
        self.del_session()
        url = "http://v2.open.venus.oa.com/chat/session/save"
        data = {"appGroupId": self.group_id, "model": self.model}
        resp = self._post_json(url, data)
        if resp["code"] == 4001:
            self.clean_session()
            resp = self._post_json(url, data)
        print(resp)
        self.session_id = resp["data"]["sessionId"]
        print(f"创建新会话: {self.session_id}")

    @staticmethod
    def with_requests(url, headers, body):
        """Get a streaming response for the given event feed using requests."""
        return requests.post(url, stream=True, headers=headers, data=json.dumps(body))

    def __call__(self, dialog: str, max_context=0):
        """Send ``dialog`` in the current session and return the final reply.

        A session is created first when none exists. The reply is read from
        the server-sent event stream; only events named ``[FIN]`` are parsed,
        and the last one wins.

        Args:
            dialog: The user message, sent as ``request``.
            max_context: Sent to the server as ``maxContext``.

        Returns:
            The ``response`` field of the final ``[FIN]`` event.

        Raises:
            Exception: if no non-empty ``response`` was received.
        """
        if self.session_id is None:
            self.new_session()
        body = {"sessionId": self.session_id, "request": dialog, "maxContext": max_context}

        url = "http://v2.open.venus.oa.com"
        path = "/chat/sse"
        timestamp = str(int(time.time()))
        sign = self.client.gen_venus_sign_with_timestamp(path, "", json.dumps(body), timestamp)

        headers = {
            "Content-Type": "application/json",
            "Venusopenapi-Request-Timestamp": timestamp,
            "Venusopenapi-Authorization": sign,
            # NOTE(review): the secret id is sent empty here; the server may
            # expect the configured secret id — confirm before relying on it.
            "Venusopenapi-Secret-Id": "",
        }

        response = self.with_requests(url + path, headers, body)
        result = {}
        try:
            for event in sseclient.SSEClient(response).events():
                if event.event == "[FIN]":
                    result = json.loads(event.data)
        finally:
            # Release the streamed connection even if parsing fails midway.
            response.close()

        res = result.get("response", "")
        if res is None or res == "":
            raise Exception("call chatgpt failed")
        return res

    def single_chat(self, query):
        """Answer ``query`` with one stateless request to ``/chat/single``.

        Note: this endpoint cannot pull conversation history from Venus.
        The system prompt ``self.use_sys`` is prepended when it is truthy.

        Args:
            query: The user message.

        Returns:
            The model's reply, or ``""`` when the request raises, the
            response code is not 0, or the response carries no reply.
        """
        msg = [{"role": "user", "content": query}]
        if self.use_sys:
            msg.insert(0, {"role": "system", "content": self.use_sys})

        body = {
            # Replace with your own appGroupId.
            "appGroupId": self.group_id,
            "model": self.model,
            # Custom context (instead of a plain "request" string).
            "messages": msg,
            "temperature": 0.2,
            "top_p": 0.1,
            "presence_penalty": 0,
            "frequency_penalty": 0,
        }

        try:
            ret = self._post_json("http://v2.open.venus.oa.com/chat/single", body)
        except Exception as e:
            print('venus err: ', e)
            return ""

        # A malformed payload (not a dict, or no "code") counts as a failure
        # rather than an unhandled KeyError/TypeError.
        if not isinstance(ret, dict) or ret.get("code") != 0:
            print(ret)
            return ""

        # "data" may be present but null; fall back to an empty reply.
        data = ret.get("data")
        response = data.get("response", "") if isinstance(data, dict) else ""
        return "" if response is None else response


if __name__ == "__main__":
    # Smoke test: send one single-shot query and report how long it took.
    venus = VenusClient()
    prompt = "介绍一下深圳"

    started = time.time()
    answer = venus.single_chat(prompt)
    finished = time.time()
    print("函数执行时间：{:.2f} 秒".format(finished - started))

    print(answer)

    # Drop the only reference so VenusClient.__del__ (session cleanup) can run.
    del venus