import ast
import base64
import re
import astunparse
import numpy as np

from langchain.chat_models import ChatOpenAI
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage, ChatMessage, BaseMessage
from langchain.llms import Replicate, VertexAI
import dt.utils as U
from PIL import Image
import io


class CodeGenerationAgent:
    """LLM agent that turns a natural-language command into a Python policy function.

    Usage: ``reset(command, context_info)`` builds the system + human messages,
    then ``step()`` queries the LLM, extracts a function from the ```python
    fenced code in the reply, and returns its source together with a line that
    instantiates it as ``policy``.
    """

    # Supported model names -> (LLM wrapper class, provider model identifier).
    llms = {
        "gpt-3.5-turbo": (ChatOpenAI, "gpt-3.5-turbo"),
        "gpt-4": (ChatOpenAI, "gpt-4"),
        "gpt-4-turbo-preview": (ChatOpenAI, "gpt-4-turbo-preview"),
    }

    # Fenced ```python ... ``` blocks in an LLM reply; DOTALL lets ``.`` span newlines.
    _CODE_BLOCK_PATTERN = re.compile(r"```python(.*?)```", re.DOTALL)

    def __init__(self,
                 model_name: str = "gpt-4-turbo-preview",
                 temperature: float = 0.0,
                 request_timout: int = 120,
                 zero_shot: bool = True,
                 ):
        """Create the agent and its LLM client.

        Args:
            model_name: Key into ``llms``.
            temperature: Sampling temperature forwarded to ``ChatOpenAI``.
            request_timout: Forwarded to ``ChatOpenAI`` as ``request_timeout``
                (the misspelled parameter name is kept for caller compatibility).
            zero_shot: Selects the zero-shot ("zs") or few-shot ("fs") system
                prompt template.

        Raises:
            RuntimeError: If ``model_name`` is unknown or maps to an unsupported
                LLM class.
        """
        if model_name not in self.llms:
            raise RuntimeError(f"Unknown model name: {model_name}")

        llm, model = self.llms[model_name]
        if llm is ChatOpenAI:
            self.llm = ChatOpenAI(
                model_name=model,
                temperature=temperature,
                request_timeout=request_timout,
            )
        else:
            raise RuntimeError("Unknown LLM")
        self.zero_shot = zero_shot

    def render_system_message(self):
        """Build the system message.

        Uses the ``cg_template_zs`` / ``cg_template_fs`` prompt (per
        ``self.zero_shot``), filled with the API listing and response format
        loaded through ``dt.utils``.
        """
        system_template = U.load_prompt(f"cg_template_{'zs' if self.zero_shot else 'fs'}")
        apis = U.load_apis()
        response_format = U.load_prompt("cg_response_format")
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
        system_message = system_message_prompt.format(
            apis=apis,
            response_format=response_format,
        )
        assert isinstance(system_message, SystemMessage)
        return system_message

    @staticmethod
    def render_human_message(
            command: str = "",
            context_info: str = "",
    ):
        """Build the human message containing the command and context.

        Raises:
            RuntimeError: If ``command`` is empty or None.
        """
        if not command:
            raise RuntimeError("Command is empty.")
        message = f"Command: {command}\n"
        message += f"Context Info: {context_info}\n\n"
        return HumanMessage(content=message)

    @staticmethod
    def _empty_result():
        """Return the result dict used when no usable function could be extracted."""
        return {
            "program_code": "",
            "program_name": "",
            "exec_code": "",
        }

    @staticmethod
    def process_ai_message(message):
        """Extract the policy function from an LLM reply.

        Args:
            message: A langchain ``BaseMessage`` (its ``content`` is used) or a
                plain string.

        Returns:
            dict with ``program_code`` (source of the selected function),
            ``program_name`` (its name) and ``exec_code``
            (``"policy = <name>()"``). All three are empty strings if the
            extracted code does not parse or has no top-level function.

        Raises:
            RuntimeError: If ``message`` is neither a ``BaseMessage`` nor a ``str``.
        """
        if isinstance(message, BaseMessage):
            message = message.content
        elif not isinstance(message, str):
            raise RuntimeError("Unknown message type")

        # Concatenate every ```python fenced block; "" parses to an empty module.
        code = "\n".join(CodeGenerationAgent._CODE_BLOCK_PATTERN.findall(message))

        try:
            parsed_code = ast.parse(code)
        except SyntaxError as e:
            print(f"SyntaxError: {e}")
            return CodeGenerationAgent._empty_result()

        # Only module-level functions are callable by ``exec_code``. ast.walk
        # would also yield nested helpers/methods (in unspecified order), so
        # ``functions[-1]`` could otherwise pick a nested function.
        functions = CodeGenerationAgent.analyze_ast(parsed_code, top_level_only=True)
        if not functions:
            print("Warning: No function in the code")
            return CodeGenerationAgent._empty_result()
        if len(functions) != 1:
            print(f"Warning: Expected only one function in the code, but got {len(functions)}")

        # The last top-level function is used as the entry point.
        main_function = functions[-1]
        return {
            "program_code": main_function["code"],
            "program_name": main_function["name"],
            "exec_code": f"policy = {main_function['name']}()",
        }

    def reset(self, command: str, context_info: str = ""):
        """Start a new episode: build fresh messages and clear ``code``/``conversations``.

        Returns:
            The ``[system_message, human_message]`` list stored in ``self.messages``.
        """
        self.command = command
        self.code = {}
        system_message = self.render_system_message()

        human_message = self.render_human_message(
            command=command,
            context_info=context_info,
        )
        self.messages = [system_message, human_message]
        print(f"\033[32m****CG Agent human message****\n{human_message.content}\033[0m")
        assert len(self.messages) == 2
        self.conversations = []
        return self.messages

    def step(self):
        """Query the LLM once with the messages from ``reset`` and parse the reply.

        Side effects: appends ``(system text, human text, reply)`` to
        ``self.conversations`` and records the parsed program in ``self.code``.

        Returns:
            dict with ``reused_code`` (always ``""``) and ``new_code`` (function
            source, a newline, then the ``policy = ...`` line; just ``"\\n"``
            when nothing could be extracted).
        """
        if isinstance(self.llm, ChatOpenAI):
            ai_message = self.llm(self.messages)
        else:
            raise RuntimeError("Unknown LLM")

        if isinstance(ai_message, BaseMessage):
            ai_message = ai_message.content
        print(f"\033[34m****CG Agent ai message****\n{ai_message}\033[0m")
        self.conversations.append((self.messages[0].content, self.messages[1].content, ai_message))
        parsed_result = self.process_ai_message(ai_message)
        assert isinstance(parsed_result, dict)
        self.code.update({
            "program_code": parsed_result["program_code"],
            "program_name": parsed_result["program_name"],
        })
        return {
            'reused_code': "",
            'new_code': parsed_result["program_code"] + "\n" + parsed_result["exec_code"],
        }

    @staticmethod
    def analyze_ast(node, top_level_only: bool = False):
        """Collect ``def`` functions found under an AST node.

        Args:
            node: An ``ast`` node, typically the ``Module`` from ``ast.parse``.
            top_level_only: If True, consider only the direct children in
                ``node.body``. If False (default), consider every descendant
                yielded by ``ast.walk``, including nested functions and methods.

        Returns:
            A list of ``{"name": ..., "code": ...}`` dicts, with ``code``
            rendered by ``astunparse.unparse``.
        """
        candidates = getattr(node, "body", []) if top_level_only else ast.walk(node)
        return [
            {
                "name": subnode.name,
                "code": astunparse.unparse(subnode),
            }
            for subnode in candidates
            if isinstance(subnode, ast.FunctionDef)
        ]


class VisionCGAgent(CodeGenerationAgent):
    """Code-generation agent that also sends an image to a vision-capable model.

    Inherits ``step``/``process_ai_message`` from ``CodeGenerationAgent``. It
    differs in its supported models, its system prompt template
    (``vcg_template_*``) and its human message, which carries the image as a
    base64 JPEG data URL.
    """

    # Supported model names -> (LLM wrapper class, provider model identifier).
    llms = {
        "gpt-4-vision-preview": (ChatOpenAI, "gpt-4-vision-preview"),
    }

    def __init__(self,
                 model_name: str = "gpt-4-vision-preview",
                 temperature: float = 0.0,
                 request_timout: int = 120,
                 zero_shot: bool = True,
                 max_tokens: int = 1024,
                 ):
        """Create the agent and its vision LLM client.

        The default ``model_name`` used to be ``"gpt-4-turbo-preview"``, which
        is not in ``llms``, so constructing with defaults always raised.

        Args:
            model_name: Key into ``llms``.
            temperature: Sampling temperature forwarded to ``ChatOpenAI``.
            request_timout: Forwarded to ``ChatOpenAI`` as ``request_timeout``
                (the misspelled parameter name is kept for caller compatibility).
            zero_shot: Selects the zero-shot ("zs") or few-shot ("fs") template.
            max_tokens: Completion token limit passed to ``ChatOpenAI``.

        Raises:
            RuntimeError: If ``model_name`` is unknown or maps to an unsupported
                LLM class.
        """
        if model_name not in self.llms:
            raise RuntimeError(f"Unknown model name: {model_name}")

        llm, model = self.llms[model_name]
        if llm is ChatOpenAI:
            self.llm = ChatOpenAI(
                model_name=model,
                temperature=temperature,
                request_timeout=request_timout,
                max_tokens=max_tokens,
            )
        else:
            raise RuntimeError("Unknown LLM")
        self.zero_shot = zero_shot

    def render_system_message(self):
        """Build the system message.

        Uses the ``vcg_template_zs`` / ``vcg_template_fs`` prompt (per
        ``self.zero_shot``), filled with the API listing and response format
        loaded through ``dt.utils``.
        """
        system_template = U.load_prompt(f"vcg_template_{'zs' if self.zero_shot else 'fs'}")
        apis = U.load_apis()
        response_format = U.load_prompt("cg_response_format")
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
        system_message = system_message_prompt.format(
            apis=apis,
            response_format=response_format,
        )
        assert isinstance(system_message, SystemMessage)
        return system_message

    # noinspection PyMethodOverriding
    def reset(self, command: str, image: np.ndarray, context_info: str = ""):
        """Start a new episode with a command and an image.

        Args:
            command: Natural-language command (must be non-empty).
            image: Array accepted by ``PIL.Image.fromarray`` (presumably an
                HxW or HxWxC uint8 image - confirm with callers).
            context_info: Extra context text for the human message.

        Returns:
            The ``[system_message, human_message]`` list stored in ``self.messages``.
        """
        self.command = command
        self.code = {}
        system_message = self.render_system_message()

        pil_image = Image.fromarray(image)
        # JPEG cannot store alpha/palette modes (e.g. RGBA from a 4-channel
        # array), so normalize to RGB before encoding.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
        image_byte_array = io.BytesIO()
        pil_image.save(image_byte_array, format="JPEG")
        image_encoded = base64.b64encode(image_byte_array.getvalue()).decode("utf-8")

        human_message = self.render_human_message(
            command=command,
            image_encoded=image_encoded,
            context_info=context_info,
        )
        self.messages = [system_message, human_message]
        assert len(self.messages) == 2
        self.conversations = []
        return self.messages

    @staticmethod
    def render_human_message(
            command: str = "",
            image_encoded: str = "",
            context_info: str = "",
    ):
        """Build a multimodal human message: command/context text plus a JPEG image.

        Args:
            command: Natural-language command.
            image_encoded: Base64-encoded JPEG bytes (without a data-URL prefix).
            context_info: Extra context text.

        Raises:
            RuntimeError: If ``command`` is empty or None.
        """
        if not command:
            raise RuntimeError("Command is empty.")
        message = f"Command: {command}\n"
        message += f"Context Info: {context_info}\n\n"
        return HumanMessage(content=[
            {"type": "text", "text": message},
            # OpenAI's documented image content part wraps the URL in an object.
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64," + image_encoded}},
        ])
