import random
from collections import deque
from typing import List, Tuple

import carla
import numpy as np

from agents.navigation.local_planner import RoadOption
from baselines.carla_agent.navigation.local_planner import _retrieve_options
from carla_tools.plotter import Plotter
from baselines.carla_agent.tools.misc import get_speed


class LaMPilotPlanner:
    """Local waypoint planner driven by LLM-generated policy code.

    The planner keeps a queue of ``(waypoint, road_option)`` pairs ahead of the
    ego-vehicle. It extends the queue by sampling successor waypoints, choosing a
    branch at junctions from ``road_option_queue``. It can replace the queue with a
    lane-change path. ``run_step`` returns the next target converted by
    ``_waypoint2gps`` (``np.array([-y, x])``) together with its road option.
    """

    def __init__(self, vehicle, opt_dict=None, map_inst=None):
        """
        :param vehicle: the ego-vehicle; its ``get_world()`` and ``get_location()`` are used.
        :param opt_dict: optional settings. Recognised keys are ``'base_min_distance'``
            and ``'debug'`` (the latter is passed to ``Plotter``). ``None`` means no
            overrides.
        :param map_inst: optional ``carla.Map``. Anything that is not a ``carla.Map``
            is ignored with a warning, and the world's map is used instead.
        """
        self._vehicle = vehicle
        self._world = self._vehicle.get_world()
        if map_inst:
            if isinstance(map_inst, carla.Map):
                self._map = map_inst
            else:
                print("Warning: Ignoring the given map as it is not a 'carla.Map'")
                self._map = self._world.get_map()
        else:
            self._map = self._world.get_map()

        self._min_waypoint_queue_length = 10
        self._sampling_radius = 2.0
        self._base_min_distance = 5.0
        self._max_distance = 50.0
        self._debug = None

        if opt_dict:
            if 'base_min_distance' in opt_dict:
                self._base_min_distance = opt_dict['base_min_distance']
            if 'debug' in opt_dict:
                self._debug = Plotter(opt_dict['debug'])

        # Initialize the waypoints queue
        self._waypoints_queue = deque(maxlen=10000)
        self._reset_waypoints_queue()

        # Holds at most one requested junction manoeuvre; a new request replaces the old one.
        self.road_option_queue = deque(maxlen=1)
        self._reset_policy()
        self._agent_say = ""

    def _reset_waypoints_queue(self):
        """Restart the route from the waypoint under the vehicle (lane-follow)."""
        self._waypoints_queue.clear()
        current_waypoint = self._map.get_waypoint(self._vehicle.get_location())
        self.target_waypoint, self.target_road_option = current_waypoint, RoadOption.LANEFOLLOW
        self._waypoints_queue.append((self.target_waypoint, self.target_road_option))

    def reset_vehicle(self):
        """Reset the ego-vehicle reference to ``None``.

        Methods that query the vehicle (``run_step``, ``_lane_change``,
        ``_reset_waypoints_queue``) fail while it is ``None``.
        """
        self._vehicle = None

    def run_step(self):
        """Advance along the waypoint queue and return the next target.

        The method works in three steps:

        1. Top the queue up to ``_min_waypoint_queue_length`` entries.
        2. Drop the waypoints the vehicle has already reached.
        3. Return the entry that follows the new head of the queue.

        :return: ``(point, road_option)``, where ``point`` is the target waypoint
            converted by ``_waypoint2gps`` (``np.array([-y, x])``).
        """
        if self._debug:
            self._debug.clear()

        # Add more waypoints if too few remain in the horizon
        if len(self._waypoints_queue) < self._min_waypoint_queue_length:
            self._compute_next_waypoints(k=self._min_waypoint_queue_length)

        if len(self._waypoints_queue) == 1:
            # Nothing ahead (e.g. a dead end): target the only waypoint, converted the
            # same way as the regular return below, so callers always get a point.
            waypoint, road_option = self._waypoints_queue[0]
            return self._waypoint2gps(waypoint), road_option

        # Purge the queue of obsolete waypoints
        veh_location = self._vehicle.get_location()
        self._min_distance = self._base_min_distance

        to_pop = 0
        farthest_in_range = -np.inf
        cumulative_distance = 0.0

        # Scan roughly the first ``_max_distance`` of route (measured along the queue).
        # The farthest waypoint that is still within ``_min_distance`` of the vehicle
        # becomes the new head of the queue.
        for i in range(1, len(self._waypoints_queue)):
            if cumulative_distance > self._max_distance:
                break

            waypoint, road_option = self._waypoints_queue[i]
            location = waypoint.transform.location
            cumulative_distance += location.distance(
                self._waypoints_queue[i - 1][0].transform.location)
            distance = location.distance(veh_location)

            if self._min_distance >= distance > farthest_in_range:
                farthest_in_range = distance
                to_pop = i

            if self._debug:
                # Red channel: beyond _min_distance; green channel: lane-follow.
                # White therefore marks lane-follow waypoints further ahead.
                r = 255 * int(distance > self._min_distance)
                g = 255 * int(road_option.value == RoadOption.LANEFOLLOW.value)
                b = 255
                self._debug.dot(self._loc2gps(veh_location),
                                self._waypoint2gps(waypoint),
                                (r, g, b))

        # Always keep at least two entries so that index 1 below stays valid
        for _ in range(to_pop):
            if len(self._waypoints_queue) > 2:
                self._waypoints_queue.popleft()

        if self._debug:
            self._debug.dot(self._loc2gps(veh_location),
                            self._waypoint2gps(self._waypoints_queue[0][0]),
                            (0, 255, 0))  # Green for the current waypoint
            self._debug.dot(self._loc2gps(veh_location),
                            self._waypoint2gps(self._waypoints_queue[1][0]),
                            (255, 0, 0))  # Red for the next waypoint
            self._debug.dot(self._loc2gps(veh_location),
                            self._loc2gps(veh_location),
                            (0, 0, 255))  # Blue for the vehicle's position
            self._debug.show()

        waypoint, road_option = self._waypoints_queue[1]
        return self._waypoint2gps(waypoint), road_option

    def _compute_next_waypoints(self, k=1):
        """
        Add new waypoints to the trajectory queue.

        At a junction, the requested option in ``road_option_queue`` is used when it
        is available; otherwise a random branch is taken.

        :param k: how many waypoints to compute (capped by the queue's free capacity)
        """
        available_entries = self._waypoints_queue.maxlen - len(self._waypoints_queue)
        k = min(k, available_entries)

        for _ in range(k):
            last_waypoint = self._waypoints_queue[-1][0]
            next_waypoints = list(last_waypoint.next(self._sampling_radius))

            if not next_waypoints:
                # If the list of waypoints is empty, the vehicle is probably at a dead end
                break
            if len(next_waypoints) == 1:
                # only one option available ==> lane-following
                next_waypoint = next_waypoints[0]
                road_option = RoadOption.LANEFOLLOW
            else:
                road_options_list = _retrieve_options(next_waypoints, last_waypoint)
                # Options are compared by ``.value`` because ``_retrieve_options``
                # comes from a different module than ``RoadOption``. This presumes a
                # parallel enum with the same values; confirm.
                road_options_value_list = [option.value for option in road_options_list]
                if self.road_option_queue:
                    # The request is peeked, not consumed: it applies to every junction
                    # until a new request replaces it.
                    desired_road_option = self.road_option_queue[0]
                    if desired_road_option.value not in road_options_value_list:
                        print(
                            f"\033[31m Warning: Invalid road option {desired_road_option}, randomly choose from {road_options_list} \033[0m")
                        road_option_value = random.choice(road_options_value_list)
                    else:
                        road_option_value = desired_road_option.value
                else:
                    road_option_value = random.choice(road_options_value_list)
                    print(
                        f"\033[31m Warning: Randomly choose road option from {road_options_list} \033[0m")

                next_waypoint = next_waypoints[road_options_value_list.index(road_option_value)]
                # Record the branch actually taken. Previously ``road_option`` was left
                # unassigned here (UnboundLocalError, or a stale LANEFOLLOW).
                road_option = RoadOption(road_option_value)

            self._waypoints_queue.append((next_waypoint, road_option))

    def _set_global_plan(self, current_plan):
        """Replace the waypoint queue with ``current_plan``.

        :param current_plan: sequence of ``(waypoint, road_option)`` pairs. Plans
            with fewer than two entries are rejected with a warning, and the queue is
            left unchanged.
        """
        if len(current_plan) < 2:
            print("Warning: The global plan has less than two waypoints")
            return

        # The old route is discarded, so only the new plan has to fit. Grow the
        # deque when the plan is longer than its current capacity.
        if len(current_plan) > self._waypoints_queue.maxlen:
            self._waypoints_queue = deque(current_plan, maxlen=len(current_plan))
        else:
            self._waypoints_queue.clear()
            self._waypoints_queue.extend(current_plan)

    def _lane_change(self,
                     direction,
                     same_lane_time=1,
                     other_lane_time=2,
                     lane_change_time=2,
                     lane_changes=1,
                     ):
        """
        Change the path so that the vehicle performs a lane change.

        :param direction: ``'left'`` or ``'right'``.
        :param same_lane_time: seconds of travel, at the current speed, to stay in the lane.
        :param other_lane_time: seconds of travel, at the current speed, in the target lane.
        :param lane_change_time: seconds of travel, at the current speed, for the manoeuvre.
        :param lane_changes: how many lanes to cross.

        If no path is found, the current route is kept.
        """
        speed = get_speed(self._vehicle) / 3.6  # m/s
        path = self._generate_lane_change_path(
            self._map.get_waypoint(self._vehicle.get_location()),
            direction=direction,
            distance_same_lane=same_lane_time * speed,
            distance_other_lane=other_lane_time * speed,
            lane_change_distance=lane_change_time * speed,
            lane_changes=lane_changes,
        )
        if not path:
            print("WARNING: Ignoring the lane change as no path was found")
            return
        self._set_global_plan(path)

    @staticmethod
    def _extend_along_lane(plan, length, step_distance):
        """Append lane-following waypoints to ``plan`` until ``length`` is covered.

        :return: False if the lane ends before ``length`` is reached, True otherwise.
        """
        distance = 0
        while distance < length:
            # next() returns one waypoint per possible deviation; take the first one
            next_wps = plan[-1][0].next(step_distance)
            if not next_wps:
                return False
            next_wp = next_wps[0]
            distance += next_wp.transform.location.distance(plan[-1][0].transform.location)
            plan.append((next_wp, RoadOption.LANEFOLLOW))
        return True

    @staticmethod
    def _generate_lane_change_path(waypoint,
                                   direction='left',
                                   distance_same_lane=10,
                                   distance_other_lane=25,
                                   lane_change_distance=25,
                                   check=True,
                                   lane_changes=1,
                                   step_distance=2
                                   ):
        """
        Generate a path that results in a lane change.

        If the lane change is impossible, the returned path is empty.

        :param waypoint: the ``carla.Waypoint`` to start from.
        :param direction: ``'left'`` or ``'right'``; any other value yields ``[]``.
        :param distance_same_lane: distance to travel in the current lane first (min 0.1).
        :param distance_other_lane: distance to travel in the target lane afterwards (min 0.1).
        :param lane_change_distance: total forward distance for all lane changes (min 0.1),
            split evenly between them.
        :param check: when True, refuse the change unless the waypoint's ``lane_change``
            permission is ``'Both'`` or matches ``direction``.
        :param lane_changes: number of lanes to cross; values <= 0 yield ``[]``.
        :param step_distance: sampling step used while following a lane.
        :return: list of ``(waypoint, road_option)`` pairs, or ``[]`` on failure.
        """
        if lane_changes <= 0:
            # Guards the division below (previously a ZeroDivisionError for 0)
            print(f"\033[31m Warning: Invalid number of lane changes {lane_changes} \033[0m")
            return []
        if direction == 'left':
            option = RoadOption.CHANGELANELEFT
        elif direction == 'right':
            option = RoadOption.CHANGELANERIGHT
        else:
            print(f"\033[31m Warning: Invalid direction {direction} \033[0m")
            return []

        distance_same_lane = max(0.1, distance_same_lane)
        distance_other_lane = max(0.1, distance_other_lane)
        lane_change_distance = max(0.1, lane_change_distance) / lane_changes

        plan: List[Tuple[carla.Waypoint, RoadOption]] = [(waypoint, RoadOption.LANEFOLLOW)]  # initial waypoint

        # Same lane
        if not LaMPilotPlanner._extend_along_lane(plan, distance_same_lane, step_distance):
            return []

        # Lane change(s)
        lane_changes_done = 0
        while lane_changes_done < lane_changes:
            # Move forward
            next_wps = plan[-1][0].next(lane_change_distance)
            if not next_wps:
                return []
            next_wp = next_wps[0]

            # Get the equivalent waypoint in the adjacent lane, if one exists
            if direction == 'left':
                if check and str(next_wp.lane_change) not in ['Both', 'Left']:
                    print("Warning: Left lane change not possible")
                    return []
                side_wp = next_wp.get_left_lane()
            else:
                if check and str(next_wp.lane_change) not in ['Both', 'Right']:
                    print("Warning: Right lane change not possible")
                    return []
                side_wp = next_wp.get_right_lane()

            if not side_wp or side_wp.lane_type != carla.LaneType.Driving:
                print("Warning: No side lane found")
                return []

            plan.append((side_wp, option))
            lane_changes_done += 1

        # Other lane
        if not LaMPilotPlanner._extend_along_lane(plan, distance_other_lane, step_distance):
            return []

        return plan

    @staticmethod
    def _waypoint2gps(waypoint):
        """Convert a waypoint's location with ``_loc2gps``."""
        loc = waypoint.transform.location
        return LaMPilotPlanner._loc2gps(loc)

    @staticmethod
    def _loc2gps(location):
        """Return the 2-D point ``np.array([-location.y, location.x])``."""
        return np.array([-location.y, location.x])

    # TODO: Separate the following methods into a new class
    def execute(self, code: dict):
        """Run generated policy code and install the resulting ``policy``.

        :param code: dict with ``'reused_code'`` and ``'new_code'`` source strings.
            Both run with the ``_api_*`` methods exposed under their short names
            (e.g. ``say``). Names defined by ``reused_code`` are visible to
            ``new_code``. ``new_code`` is expected to bind ``policy`` (presumably an
            iterator or generator; confirm against the prompt). On any error, the
            message is printed and the policy is reset to an empty iterator.
        """
        # SECURITY: exec() runs arbitrary code with full interpreter privileges. The code
        # here is model-generated; only run it in a trusted/sandboxed environment.
        api_names = [
            "say",
            "autopilot",
            "change_lane_left",
            "change_lane_right",
            "turn_left_at_next_intersection",
            "go_straight_at_next_intersection",
            "turn_right_at_next_intersection",
        ]
        apis = {api: getattr(self, f"_api_{api}") for api in api_names}
        local_vars = {}
        try:
            # Both snippets are guarded, so errors in reused code are handled the same way
            exec(code['reused_code'], apis)
            exec(code['new_code'], apis, local_vars)
            self.policy = local_vars.get("policy", iter([]))
        except Exception as e:
            print(f"\033[31m Error in execute: {e} \033[0m")
            self._reset_policy()

    def _reset_policy(self):
        """Replace the current policy with an empty iterator."""
        self.policy = iter([])

    #################### APIs ####################

    def _api_change_lane_left(self, times=1):
        """Plan a lane change to the left, crossing ``times`` lanes."""
        self._lane_change('left', lane_changes=times)

    def _api_change_lane_right(self, times=1):
        """Plan a lane change to the right, crossing ``times`` lanes."""
        self._lane_change('right', lane_changes=times)

    def _api_turn_left_at_next_intersection(self):
        """Restart the route from the vehicle and request LEFT at junctions."""
        self._reset_waypoints_queue()
        self.road_option_queue.append(RoadOption.LEFT)

    def _api_go_straight_at_next_intersection(self):
        """Restart the route from the vehicle and request STRAIGHT at junctions."""
        self._reset_waypoints_queue()
        self.road_option_queue.append(RoadOption.STRAIGHT)

    def _api_turn_right_at_next_intersection(self):
        """Restart the route from the vehicle and request RIGHT at junctions."""
        self._reset_waypoints_queue()
        self.road_option_queue.append(RoadOption.RIGHT)

    def _api_autopilot(self):
        """No-op: the planner keeps the currently queued route."""
        pass

    def _api_say(self, text: str):
        """Print ``text`` and remember it as the agent's latest utterance."""
        print(f"Agent says: {text}")
        self._agent_say = text


if __name__ == "__main__":
    # Scratch check: is a raw integer option value one of the junction/lane-follow
    # option values?
    candidate_value = 1
    known_options = (RoadOption.LANEFOLLOW, RoadOption.LEFT, RoadOption.RIGHT, RoadOption.STRAIGHT)
    known_values = list(map(lambda opt: opt.value, known_options))
    for line in (candidate_value, known_values, candidate_value in known_values):
        print(line)
