from wpt_agent import WPTAgent
from planner import LaMPilotPlanner
from lang_agent import CodeGenerationAgent, VisionCGAgent
import carla
from PIL import Image
import os


def get_entry_point():
    """Return the class name of the agent defined in this module.

    Presumably consumed by an external loader that looks the agent class up
    by name — confirm against the caller.
    """
    agent_class_name = "DTAgent"
    return agent_class_name


class DTAgent(WPTAgent):
    """Waypoint-tracking agent driven by an LLM-generated policy.

    Whenever the route-level instruction changes, a vision-language
    code-generation agent (``VisionCGAgent``) is reset with the new instruction
    and the current front camera image. The policy it returns is executed on
    two ``LaMPilotPlanner`` instances (a "near" and a "far" planner). Their
    outputs are turned into a ``carla.VehicleControl`` via ``self._get_control``,
    which is not defined here and presumably comes from ``WPTAgent``.
    """

    def _init(self):
        """Lazily set up the LLM agent, the near/far planners and the output dir.

        Called from ``run_step`` while ``self.initialized`` is falsy. The parent
        ``_init`` runs first; ``self._vehicle`` and ``self._map`` are presumably
        set there (they are not assigned in this class).
        """
        super()._init()
        # Turns a natural-language instruction plus the front camera image into
        # a policy that the planners execute (see run_step). A text-only
        # alternative is CodeGenerationAgent(zero_shot=True).
        self.hai_agent = VisionCGAgent(
            model_name="gpt-4-vision-preview",
            zero_shot=True,
        )

        # Two planners execute the same policy. They differ only in
        # 'base_min_distance' (5.0 vs 7.5) and 'debug'. Both options are passed
        # through to LaMPilotPlanner, which defines their meaning.
        self._near_planner = LaMPilotPlanner(
            self._vehicle,
            opt_dict={'base_min_distance': 5.0, 'debug': 256},
            map_inst=self._map,
        )
        self._far_planner = LaMPilotPlanner(
            self._vehicle,
            opt_dict={'base_min_distance': 7.5, 'debug': 257},
            map_inst=self._map,
        )

        # Latest message reported by the near planner; refreshed every step.
        self.agent_message = ""
        # Set to a directory (e.g. "data/output") to dump every rendered HUD
        # frame as a PNG; an empty string disables saving.
        self.save_path = ""
        if self.save_path:
            # exist_ok avoids the race between checking for and creating the dir.
            os.makedirs(self.save_path, exist_ok=True)

    def run_step(self, input_data, timestamp):
        """Compute the vehicle control for one simulation step.

        Args:
            input_data: Raw sensor input, converted into ``tick_data`` by
                ``self.tick``.
            timestamp: Simulation timestamp, forwarded to the HUD display data.

        Returns:
            A ``carla.VehicleControl`` with steer, throttle and brake set.
        """
        if not self.initialized:
            self._init()
        tick_data = self.tick(input_data)

        # The LLM is queried only when the instruction changes. Between changes,
        # the planners keep executing the last generated policy.
        # TODO: Improve the instruction generator
        last_instruction = self._instruction_planner.command2instruct(
            self.town_id, tick_data, self._waypoint_planner.route)
        if self.current_instruction != last_instruction:
            self.current_instruction = last_instruction
            self.hai_agent.reset(
                command=last_instruction,
                image=tick_data['rgb_front'],
            )
            policy = self.hai_agent.step()
            self._near_planner.execute(policy)
            self._far_planner.execute(policy)

        near_node, near_cmd = self._near_planner.run_step()
        far_node, far_cmd = self._far_planner.run_step()
        # Kept on the instance, presumably for _prepare_display_data — confirm.
        self.near_cmd = near_cmd
        self.far_cmd = far_cmd
        self.agent_message = self._near_planner._agent_say

        # TODO: Improve the controller
        # The target speed returned by the controller is not used here.
        steer, throttle, brake, _target_speed = self._get_control(
            near_node, far_node, tick_data)
        control = carla.VehicleControl()
        # Coerce all three fields to plain float (brake already was) so that
        # non-float scalars are not passed straight to the carla binding.
        control.steer = float(steer)
        control.throttle = float(throttle)
        control.brake = float(brake)

        display_data = self._prepare_display_data(tick_data, timestamp, control)
        surface = self._hud.render(display_data)
        tick_data['display'] = surface

        if self.save_path:
            self.save(tick_data)

        return control

    def save(self, tick_data):
        """Write the rendered HUD frame of the current step to ``save_path``.

        Args:
            tick_data: Per-step data dict. ``tick_data['display']`` is the HUD
                render set in ``run_step`` and must be accepted by
                ``PIL.Image.fromarray``.
        """
        # NOTE(review): file names are offset by 20 steps, presumably to skip a
        # warm-up phase — confirm. Steps below 20 would produce negative names
        # such as "-0001.png".
        frame = self.step - 20
        filename = os.path.join(self.save_path, f"{frame:05d}.png")
        Image.fromarray(tick_data['display']).save(filename)