from .base_environment import BaseEnvironment
import sys
import os

curr_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(f'{curr_dir}/../')

from collections import defaultdict
from  unity_simulator import comm_unity as comm_unity
# from . import utils as utils_environment
from evolving_graph import utils
import atexit
import random
import pdb
import ipdb
import random
import cv2
import json
import numpy as np
from PIL import Image

def REL_EQUAL(rel1, rel2):
    """Compare the first three fields of two relation triples, ignoring case.

    e.g. ``REL_EQUAL(('TV', 'is', 'ON'), ('tv', 'IS', 'on'))`` is True.
    Each field must support ``.lower()``.
    """
    return all(rel1[pos].lower() == rel2[pos].lower() for pos in range(3))

class UnityEnvironment(BaseEnvironment):
    def __init__(self,
                num_agents=1,
                max_episode_length=500,
                observation_types=None,
                use_editor=False,
                base_port=8080,
                port_id=0,
                executable_args=None,
                recording_options=None,
                seed=123,
                ):
        """Create the environment wrapper and open a Unity communication client.

        :param num_agents: number of characters added on ``reset``.
        :param max_episode_length: step count at which ``done`` is forced.
        :param observation_types: per-agent observation type; defaults to
            ``'partial'`` for every agent.
        :param use_editor: connect through the Unity Editor; note that both
            branches currently construct the client the same way.
        :param base_port: base port; ``port_number`` is ``base_port + port_id``
            when not using the editor.
        :param port_id: offset added to ``base_port``.
        :param executable_args: extra launch arguments, stored on the instance.
            A fresh dict is used when omitted.
        :param recording_options: dict read for its ``'recording'``,
            ``'cameras'`` and ``'modality'`` keys; the defaults below are used
            when None.
        :param seed: seeds the instance RNG (``self.rnd``) and NumPy's global RNG.
        """
        self.seed = seed
        self.prev_reward = 0.
        self.rnd = random.Random(seed)
        np.random.seed(seed)

        if recording_options is None:
            self.recording_options = {'recording': False,
                                      'output_folder': None,
                                      'file_name_prefix': None,
                                      'cameras': 'PERSON_FROM_BACK',
                                      'modality': 'normal'}
        else:
            self.recording_options = recording_options

        self.steps = 0
        self.env_id = 0
        self.max_ids = {}  # env id -> largest node id, filled by reset()

        self.in_graph_states = defaultdict(list)
        self.num_agents = num_agents
        self.max_episode_length = max_episode_length
        self.base_port = base_port
        self.port_id = port_id
        # A shared ``{}`` default would be one dict mutated across every
        # instance created without explicit args; create a fresh one instead.
        self.executable_args = {} if executable_args is None else executable_args

        # Observation parameters
        self.num_camera_per_agent = 6
        self.CAMERA_NUM = [2, 3]
        self.default_image_width = 640
        self.default_image_height = 480

        if observation_types is not None:
            self.observation_types = observation_types
        else:
            self.observation_types = ['partial' for _ in range(num_agents)]

        # Agent index -> character resource passed to comm.add_character().
        self.agent_info = {
            0: 'Chars/Male1'
        }

        # Change flags, set True again whenever the scene is modified.
        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True
        self.rooms = None
        self.id2node = None
        self.num_static_cameras = None
        self.success_condition = None
        self.rgb_to_instance = None
        self.prev_observable_objects = {}
        self.visible_objects_cache = {}
        self.image_cache = None

        self.plan = []  # actions appended by the step methods; cleared by reset()
        self.on = []  # object ids targeted by 'find' actions

        # NOTE(review): port_number is computed but never passed to
        # UnityCommunication, so both branches create an identical client —
        # confirm whether the port should be forwarded.
        if use_editor:
            # Use Unity Editor
            self.port_number = 8080
            self.comm = comm_unity.UnityCommunication()
        else:
            # Launch the executable
            self.port_number = self.base_port + port_id
            self.comm = comm_unity.UnityCommunication()

        # Close the connection when the interpreter exits.
        atexit.register(self.close)

    def close(self):
        """Shut down the Unity communication client (also run via atexit)."""
        communication = self.comm
        communication.close()

    def relaunch(self):
        """Close the current Unity client and replace it with a new one."""
        previous_comm = self.comm
        previous_comm.close()
        self.comm = comm_unity.UnityCommunication()

    def set_task(self, success_condition):
        """
        :param success_condition: List[Tuple[str]]
        success condition for graph edges
        e.g. if success_condition is [('tv', 'is', 'on')],
        the task is success when ('tv', 'is', 'on') in self.get_graph['edges']
        :return:
        """
        self.success_condition = success_condition

    def find_nodes(self, **kwargs):
        if len(kwargs) == 0:
            return None
        else:
            k, v = next(iter(kwargs.items()))
            return [n for n in self.graph['nodes'] if n[k] == v]

    # NOTE(review): dead code. A second ``clean_graph`` is defined later in
    # this class body, and Python keeps only the last definition of a name in
    # a class, so this version is never the one that gets called. Consider
    # removing one of the two definitions.
    def clean_graph(self, graph):
        """Return a copy of *graph* with ``bounding_box`` removed from every node.

        Node dicts are shallow-copied and the edge list is copied, so *graph*
        itself is left unmodified.
        """
        new_nodes = []
        for n in graph['nodes']:
            nc = dict(n)
            if 'bounding_box' in nc:
                del nc['bounding_box']
            new_nodes.append(nc)
        return {'nodes': new_nodes, 'edges': list(graph['edges'])}
    
    def remove_edges(self, n_id, fr=True, to=True):
        # n_id = n['id']
        new_edges = [e for e in self.graph['edges'] if 
                    (e['from_id'] != n_id or not fr) and (e['to_id'] != n_id or not to)]
        self.graph['edges'] = new_edges
    
  
    def add_node(self,  n):
        self.graph['nodes'].append(n)

    def add_edge(self,  fr_id, rel, to_id):
        self.graph['edges'].append({'from_id': fr_id, 'relation_type': rel, 'to_id': to_id})

    def clean_graph(self):
        new_nodes = []
        for n in self.graph['nodes']:
            nc = dict(n)
            if 'bounding_box' in nc:
                del nc['bounding_box']
            new_nodes.append(nc)
        return {'nodes': new_nodes, 'edges': list(self.graph['edges'])}


    def reward(self):
        # Define here your reward
        if self.success_condition is None:
            reward = 0
            done = False
            info = {}
        else:
            info = {'success_condition': self.success_condition}
            success = [False] * len(self.success_condition)

            reward = 0
            entire_graph = self.get_graph(mode='triples')
            agent_graph = self.get_agent_graph(mode='triples')


            for edge in entire_graph['edges']:
                for idx, suc_cond in enumerate(self.success_condition):
                    if REL_EQUAL(edge, suc_cond):
                        success[idx] = True

            for edge in agent_graph['edges']:
                for idx, suc_cond in enumerate(self.success_condition):
                    if REL_EQUAL(edge, suc_cond):
                        success[idx] = True

            for i in range(len(success)):
                if success[i]:
                    reward += 1
            if reward == len(self.success_condition):
                info['is_success'] = True
                done = True
            else:
                info['is_success'] = False
                done = False
        return reward, done, info

    def step(self, action_dict):
        """Execute one textual action through Unity and return the transition.

        :param action_dict: action string such as ``"walk kitchen"`` (used as a
            ``str`` despite its name); ``"put in"`` is rewritten to ``"putin"``
            before grounding with ``convert_action``.
        :return: ``(obs, reward, done, info)``. ``info`` also carries
            ``finished`` (task done), ``success`` (whether a script was
            rendered successfully) and ``myroom``. Returns ``False`` if Unity
            reports that rendering the script failed.
        """
        myroom = self.get_agent_room()

        if "put in" in action_dict:
            action_dict = action_dict.replace("put in", "putin")

        script_list = self.convert_action(action_dict, self.get_visible_objects()[1])
        # Bound up front so info['success'] below cannot hit an unbound name
        # when there is nothing to render.
        success = False
        if script_list and len(script_list[0]) > 0:
            # NOTE(review): recording=False is passed even when recording is
            # enabled — confirm this is intended.
            if self.recording_options['recording']:
                success, message = self.comm.render_script(script_list,
                                                           recording=False,
                                                           skip_animation=False,
                                                           find_solution=True,
                                                           camera_mode=self.recording_options['cameras'])
            else:
                success, message = self.comm.render_script(script_list,
                                                           recording=False,
                                                           find_solution=False,
                                                           skip_animation=True)
            if not success:
                print(message)
                return False
            # The scene changed: flag graph, visibility and image as changed.
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True

        self.plan.append(action_dict)

        reward, done, info = self.reward()
        self.steps += 1

        obs = self.get_observations()
        info['finished'] = done
        # >= so an episode that somehow overshoots the limit still terminates.
        if self.steps >= self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom

        return obs, reward, done, info


    def baseline_step(self, action_dict):
        """Ground candidate skills with ``affordance`` and execute the chosen one.

        :param action_dict: iterable of candidate skill strings forwarded to
            ``affordance`` (the first one that grounds is used).
        :return: ``(obs, info, action)`` where *action* is the action string
            returned by ``affordance``; ``info`` includes ``finished``,
            ``success``, ``myroom`` and ``visible``. Returns ``False`` if Unity
            reports that rendering the script failed.
        """
        myroom = self.get_agent_room()

        _avail_actions, avail_objects = self.actions_available()

        script, action = self.affordance(action_dict, self.get_visible_objects()[1])

        # Bound up front so info['success'] cannot hit an unbound name when
        # there is nothing to render.
        success = False
        if script and len(script[0]) > 0:
            if self.recording_options['recording']:
                success, message = self.comm.render_script(script,
                                                           recording=False,
                                                           skip_animation=False,
                                                           find_solution=False)
            else:
                success, message = self.comm.render_script(script,
                                                           recording=False,
                                                           find_solution=False,
                                                           skip_animation=True)
            if not success:
                print(message)
                return False
            # The scene changed: flag graph, visibility and image as changed.
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True

        self.plan.append(action)

        # The numeric reward is not part of this method's return value.
        _reward, done, info = self.reward()
        self.steps += 1

        obs = self.get_observations()
        info['finished'] = done
        self.prev_observable_objects = avail_objects
        if self.steps >= self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom
        info['visible'] = self.get_visible_objects()[1]

        return obs, info, action


    def prog_step(self, action_dict):
        """Execute one ProgPrompt-style action grounded by ``progprompt_action``.

        :param action_dict: action string; ``"put in"`` is rewritten to
            ``"putin"`` first.
        :return: ``(obs, info, action)``; ``info`` includes ``finished``,
            ``success``, ``myroom`` and ``visible``. Returns ``False`` if Unity
            reports that rendering the script failed.
        """
        myroom = self.get_agent_room()

        if "put in" in action_dict:
            action_dict = action_dict.replace("put in", "putin")

        _avail_actions, avail_objects = self.actions_available()

        script, action = self.progprompt_action(action_dict, self.get_visible_objects()[1])

        # Bound up front so info['success'] cannot hit an unbound name when
        # there is nothing to render.
        success = False
        if script and len(script[0]) > 0:
            if self.recording_options['recording']:
                # NOTE(review): self.task_id is not initialised in __init__;
                # the caller must set it before recording is enabled.
                success, message = self.comm.render_script(script,
                                                           recording=False,
                                                           skip_animation=False,
                                                           find_solution=True,
                                                           camera_mode=self.recording_options['cameras'],
                                                           file_name_prefix='task_{}'.format(self.task_id),
                                                           image_synthesis=self.recording_options['modality'])
            else:
                success, message = self.comm.render_script(script,
                                                           recording=False,
                                                           find_solution=False,
                                                           skip_animation=True)
            if not success:
                # Report the failure like step() and baseline_step() do.
                print(message)
                return False
            # The scene changed: flag graph, visibility and image as changed.
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True

        self.plan.append(action)

        # The numeric reward is not part of this method's return value.
        _reward, done, info = self.reward()
        self.steps += 1

        obs = self.get_observations()
        info['finished'] = done
        self.prev_observable_objects = avail_objects
        if self.steps >= self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom
        info['visible'] = self.get_visible_objects()[1]

        return obs, info, action


    def reset(self, environment_graph=None, environment_id=None, init_rooms=None,init_pos=None,clear_space=None, angle=None):
        """Reset the Unity scene, expand it with objects, and place the agents.

        :param environment_graph: graph passed to ``comm.expand_scene``. When
            None, the graph from ``self.get_graph()`` is expanded instead (see
            the review note in that branch).
        :param environment_id: scene id forwarded to ``comm.reset``; when None,
            ``comm.reset()`` is called with no id.
        :param init_rooms: rooms where the agents are initialized. Used only if
            its first entry is 'kitchen', 'bedroom', 'livingroom' or
            'bathroom'; otherwise two distinct rooms are sampled with
            ``self.rnd``.
        :param init_pos: if truthy, agents listed in ``self.agent_info`` are
            added at this position rather than in a room.
        :param clear_space: if truthy (and environment_graph is None), passed
            to ``self.remove_objects`` to obtain node ids to leave out.
        :param angle: not used by this method.
        :return: the result of ``self.get_observations()``.
        """
        self.env_id = environment_id
        # print("Resetting env", self.env_id)

        if self.env_id is not None:
            self.comm.reset(self.env_id)
        else:
            self.comm.reset()
        self.plan = []

        # Mark graph, visibility and image views as changed after the reset.
        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True

        # The returned status flag (spelled "sucess") is never checked.
        sucess, self.graph = self.comm.environment_graph()

        # Remember the largest node id per scene the first time it is seen;
        # max_id itself is not used further in this method.
        if self.env_id not in self.max_ids.keys():
            max_id = max([node['id'] for node in self.graph['nodes']])
            self.max_ids[self.env_id] = max_id

        if environment_graph is None:
            # TODO: this should be modified to extend well
            # updated_graph = utils.separate_new_ids_graph(environment_graph, max_id)
            entire_graph = self.get_graph()
            classwise_obj = []  # NOTE(review): filled below but never read
            node_ids = []
            updated_graph = {'nodes': [], 'edges': []}

            _,class_name_to_id = self.get_nodes()  # NOTE(review): class_name_to_id is unused
            ids_to_remove = []
            if clear_space:
                ids_to_remove = self.remove_objects(clear_space)
            # ids_to_remove = self.remove_objects(tgt_place)

            # Keep structural nodes (rooms, walls, floors, doors, ...) as-is;
            # keep every other node unless its id is in ids_to_remove.
            for node in entire_graph['nodes']:
                if ((node['category'] not in ['Rooms', 'Walls', 'Ceiling', 'Lamps', 'Characters', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows'] )):
                    if node['id'] not in ids_to_remove:
                        classwise_obj.append(node['class_name'])
                        updated_graph['nodes'].append(node)
                        node_ids.append(node['id'])
                else:
                    updated_graph['nodes'].append(node)
                    node_ids.append(node['id'])
            # Keep only edges whose two endpoints both survived.
            for edge in entire_graph['edges']:
                if edge['from_id'] in node_ids and edge['to_id'] in node_ids:
                    updated_graph['edges'].append(edge)

            # Scene 0 only: add a 'cat' node (id 1000) ON the first
            # kitchentable in self.graph.
            # NOTE(review): raises IndexError if there is no kitchentable, and
            # the cat only reaches entire_graph if get_graph() returns
            # self.graph itself — confirm.
            if self.env_id == 0:
                kitchentable = self.find_nodes(class_name='kitchentable')


                self.add_node({'class_name': 'cat', 
                   'category': 'Animals', 
                   'id': 1000, 
                   'properties': [], 
                   'states': []})
                self.add_edge(1000, 'ON', kitchentable[0]['id'])

            # NOTE(review): updated_graph (with the clear_space removals) is
            # built above, but entire_graph is what gets expanded, so those
            # removals only take effect if remove_objects()/get_graph() apply
            # them to the graph themselves — confirm which graph is intended.
            success, m = self.comm.expand_scene(entire_graph)
        else:
            updated_graph = environment_graph
            success, m = self.comm.expand_scene(updated_graph)

        # An expansion failure is only printed; reset carries on regardless.
        if not success:
            print("Error expanding scene")
            # pdb.set_trace()
            # return None
        self.num_static_cameras = self.comm.camera_count()[1]

        # NOTE(review): an empty init_rooms raises IndexError on init_rooms[0].
        if init_rooms is None or init_rooms[0] not in ['kitchen', 'bedroom', 'livingroom', 'bathroom']:
            rooms = self.rnd.sample(['kitchen', 'bedroom', 'livingroom', 'bathroom'], 2)
        else:
            rooms = list(init_rooms)
        # print(self.comm.character_cameras())
        # Agents with an agent_info entry get that character, placed at
        # init_pos or in rooms[i]; any other agent uses add_character() defaults.
        for i in range(self.num_agents):
            if i in self.agent_info:
                if init_pos:
                    self.comm.add_character(self.agent_info[i], position=init_pos)
                else:
                    self.comm.add_character(self.agent_info[i], initial_room=rooms[i])
            else:
                self.comm.add_character()

        # Reconfigure camera index num_static_cameras + CAMERA_NUM[0]
        # (presumably a character camera added above — TODO confirm).
        self.comm.update_camera(self.num_static_cameras + self.CAMERA_NUM[0] , field_view=80, position=[0, 1.6, 0])

        # Graph snapshot taken after the characters were added.
        _, self.init_unity_graph = self.comm.environment_graph()

        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True

        self.prev_observable_objects = {}
        graph = self.get_graph()
        self.rooms = {node['id']: node['class_name'] for node in graph['nodes'] if node['category'] == 'Rooms'}  # room node id -> room class name
        self.room_ids = {v:k for k,v in self.rooms.items()}  # room class name -> room node id

        self.id2node = {node['id']: node for node in graph['nodes']}
        # object id -> id of the room it is INSIDE; from_id 1 is skipped
        # (presumably the character node — TODO confirm).
        self.objinroom = {edge['from_id']: edge['to_id'] for edge in graph['edges'] if edge['relation_type'] == 'INSIDE' and edge['from_id'] != 1 and edge['to_id'] in self.rooms}

        self.get_rgb_to_instance()

        self.in_graph_states = defaultdict(list)
        obs = self.get_observations()
        self.steps = 0
        self.prev_reward = 0.
        return obs

    def convert_action(self, action, avail_objects):
        avail_objects = {v:k for k,v in avail_objects.items()}
        action_list = action.split()
        

        if len(action_list) == 1:
            current_script = ['<char{}> [{}]'.format(0, action)]
            if action_list[0] == 'turnright' or action_list[0] == 'turnleft':
                current_script.append('<char{}> [{}]'.format(0, action))
                current_script.append('<char{}> [{}]'.format(0, action))
            elif action_list[0] == 'walkforward':
                current_script.append('<char{}> [{}]'.format(0, action))


        elif len(action_list) == 2:

            if action_list[1] in avail_objects.keys():
                current_script = ['<char{}> [{}] <{}> ({})'.format(0, action_list[0], action_list[1], avail_objects[action_list[1]])]

            if action_list[0] == 'walk':
                nodes,_ = self.get_nodes()
                # print(nodes.values())
                obj_to_id = {obj:id for id,obj in nodes.items()}
                objid = obj_to_id[action_list[1]]
                current_script = ['<char{}> [{}] <{}> ({})'.format(0,'walk',action_list[1],objid)]
            
            
            elif action_list[0] == 'find':
                current_script = []
                nodes,_ = self.get_nodes()
                # print(nodes.values())
                obj_to_id = {obj:id for id,obj in nodes.items()}
                
                objid = obj_to_id[action_list[1]]
                action1 = '<char{}> [{}] <{}> ({})'.format(0,'walk',action_list[1],objid)
                
                current_script.append(action1)
                
                self.on.append(objid)
                
                try:
                    action2 = '<char{}> [{}] <{}> ({})'.format(0,'lookat',action_list[1],objid)

                    current_script.append(action2)
                except: pass
                

            elif action_list[0] == 'grab':
                if action_list[1] in avail_objects.keys():
                    current_script = ['<char{}> [{}] <{}> ({})'.format(0,'grab',action_list[1], avail_objects[action_list[1]])]
                else:  current_script = ['<char0> [no_action]']

            elif action_list[0] == 'sit':
                if action_list[1] in avail_objects.keys():
                    current_script = ['<char{}> [{}] <{}> ({})'.format(0,'sit',action_list[1], avail_objects[action_list[1]])]
                else:  current_script = ['<char0> [no_action]']
                
            
            elif action_list[0] == 'open':
                if action_list[1] in avail_objects.keys():
                    current_script = ['<char{}> [{}] <{}> ({})'.format(0,'open',action_list[1], avail_objects[action_list[1]])]
                else:  current_script = ['<char0> [no_action]']
                
            elif action_list[0] == 'close':
                if action_list[1] in avail_objects.keys():
                    current_script = ['<char{}> [{}] <{}> ({})'.format(0,'close',action_list[1], avail_objects[action_list[1]])]
                else:  current_script = ['<char0> [no_action]']

            else:
                current_script = ['<char0> [no_action]']


        elif len(action_list) == 3:
            # put 
            if action_list[1] in avail_objects.keys() and action_list[2] in avail_objects.keys():
                current_script = ['<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, action_list[0], action_list[1], avail_objects[action_list[1]],action_list[2], avail_objects[action_list[2]])]
            
            if action_list[0] == 'put':
                print(action_list)
                agent_graph = self.get_agent_graph()
                grab_dict = {}
                for agent_edge in agent_graph["edges"]:
                    if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                        item_id = agent_edge['to_id']
                        nodes,_ = self.get_nodes()
                        id_to_obj = {id:obj for id,obj in nodes.items()}    
                        if item_id in id_to_obj.keys():
                            item = id_to_obj[item_id]
                            
                            grab_dict[item] = item_id
                        

                if action_list[1] in grab_dict.keys() and action_list[2] in avail_objects.keys():
                    current_script = []
                    action1 = '<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, 'put', action_list[1], grab_dict[action_list[1]], action_list[2], avail_objects[action_list[2]])
                    current_script.append(action1)
                    # action3 =  '<char{}> [{}] <{}> ({})'.format(0,'lookat',action_list[1], grab_dict[action_list[1]])
                    # current_script.append(action3)

                else: current_script = ['<char0> [no_action]']
                
                
            elif action_list[0] == 'putin':
                agent_graph = self.get_agent_graph()
                grab_dict = {}
                for agent_edge in agent_graph["edges"]:
                    if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                        item_id = agent_edge['to_id']
                        nodes,_ = self.get_nodes()
                        id_to_obj = {id:obj for id,obj in nodes.items()}    
                        if item_id in id_to_obj.keys():
                            item = id_to_obj[item_id]
                            
                        grab_dict[item] = item_id
                
                
                
                obj2 = self.find_nodes(class_name=action_list[-1])
                obj2_ids = [obj_state["states"] for obj_state in obj2]

                if ["OPEN"] in obj2_ids:
                    put_avail_objects = avail_objects
                else:
                    put_avail_objects = avail_objects.remove(action_list[2])
       
                                
                if action_list[1] in grab_dict.keys() and action_list[2] in put_avail_objects.keys():
                    current_script = []
                    action1 = '<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, 'putin', action_list[1], grab_dict[action_list[1]], action_list[2], put_avail_objects[action_list[2]])
                    current_script.append(action1)
     
                else: current_script = ['<char0> [no_action]']

            elif action_list[0] == "switch":
                action = f"switch{action_list[1]}"
                current_script = ['<char{}> [{}] <{}> ({})'.format(0, action, action_list[2], avail_objects[action_list[2]])]
            else:
                current_script = ['<char0> [no_action]']
        if current_script == ['<char0> [no_action]']: 
            pass


        return current_script

    def remove_prefix(self, text):
        import re
        return re.sub(r'^\d+\.\s*', '', text)

    def affordance(self, skill_list, avail_objects):
        
        with open('../../../Task_Data/vh_config.json', 'r') as f:
            vh_config = json.load(f)        
        vh_skill = vh_config["VH_SKILL"]
    
        avail_objects = {v:k for k,v in avail_objects.items()}
        
        for action in skill_list:
            if '.' in action:
                action = self.remove_prefix(action)
                
            if 'done' in action: action = action.replace(', done', '')
            print("Action: ", action)

            if action not in vh_skill:
                print("not in vh skills")
                pass
            else:
                
                if "put in" in action:
                    action = action.replace("put in", "putin")
                action_list = action.split()
                

                if len(action_list) == 1:
                    current_script = ['<char{}> [{}]'.format(0, action)]
                    if action_list[0] == 'turnright' or action_list[0] == 'turnleft':
                        current_script.append('<char{}> [{}]'.format(0, action))
                        current_script.append('<char{}> [{}]'.format(0, action))
                    elif action_list[0] == 'walkforward':
                        current_script.append('<char{}> [{}]'.format(0, action))

                elif len(action_list) == 2:
                    print(action_list)

                    if action_list[1] in avail_objects.keys():
                        current_script = ['<char{}> [{}] <{}> ({})'.format(0, action_list[0], action_list[1], avail_objects[action_list[1]])]

                    if action_list[0] == 'walk':
                        nodes,_ = self.get_nodes()
                        # print(nodes.values())
                        obj_to_id = {obj:id for id,obj in nodes.items()}
                        objid = obj_to_id[action_list[1]]
                        current_script = ['<char{}> [{}] <{}> ({})'.format(0,'walk',action_list[1],objid)]
                    
                    elif action_list[0] == 'find':
                        current_script = []
                        nodes,_ = self.get_nodes()
                        # print(nodes.values())
                        obj_to_id = {obj:id for id,obj in nodes.items()}
                        
                        objid = obj_to_id[action_list[1]]
                        action1 = '<char{}> [{}] <{}> ({})'.format(0,'walk',action_list[1],objid)
                        
                        current_script.append(action1)
                        
                        self.on.append(objid)
                        
                        try:
                            action2 = '<char{}> [{}] <{}> ({})'.format(0,'lookat',action_list[1],objid)

                            current_script.append(action2)
                        except: pass
                        
                    elif action_list[0] == 'grab':
                        if action_list[1] in avail_objects.keys():
                            current_script = ['<char{}> [{}] <{}> ({})'.format(0,'grab',action_list[1], avail_objects[action_list[1]])]
                        else:  current_script = ['<char0> [no_action]']

                    elif action_list[0] == 'sit':
                        if action_list[1] in avail_objects.keys():
                            current_script = ['<char{}> [{}] <{}> ({})'.format(0,'sit',action_list[1], avail_objects[action_list[1]])]
                        else:  current_script = ['<char0> [no_action]']
                        
                    elif action_list[0] == 'open':
                        if action_list[1] in avail_objects.keys():
                            current_script = ['<char{}> [{}] <{}> ({})'.format(0,'open',action_list[1], avail_objects[action_list[1]])]
                        else:  current_script = ['<char0> [no_action]']
                        
                        
                    elif action_list[0] == 'close':
                        if action_list[1] in avail_objects.keys():
                            current_script = ['<char{}> [{}] <{}> ({})'.format(0,'close',action_list[1], avail_objects[action_list[1]])]
                        else:  current_script = ['<char0> [no_action]']

                    else:
                        current_script = ['<char0> [no_action]']

                elif len(action_list) == 3:
                    # put 
                    if action_list[1] in avail_objects.keys() and action_list[2] in avail_objects.keys():
                        current_script = ['<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, action_list[0], action_list[1], avail_objects[action_list[1]],action_list[2], avail_objects[action_list[2]])]
                    
                    if action_list[0] == 'put':
                        agent_graph = self.get_agent_graph()
                        grab_dict = {}
                        for agent_edge in agent_graph["edges"]:
                            if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                                item_id = agent_edge['to_id']
                                nodes,_ = self.get_nodes()
                                id_to_obj = {id:obj for id,obj in nodes.items()}    
                                if item_id in id_to_obj.keys():
                                    item = id_to_obj[item_id]
                                    
                                grab_dict[item] = item_id
                                

                        if action_list[1] in grab_dict.keys() and action_list[2] in avail_objects.keys():
                            current_script = []
                            action1 = '<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, 'put', action_list[1], grab_dict[action_list[1]], action_list[2], avail_objects[action_list[2]])
                            current_script.append(action1)
                                                
                        else: current_script = ['<char0> [no_action]']
                        
                        
                    elif action_list[0] == 'putin':
                        agent_graph = self.get_agent_graph()
                        grab_dict = {}
                        for agent_edge in agent_graph["edges"]:
                            if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                                item_id = agent_edge['to_id']
                                nodes,_ = self.get_nodes()
                                id_to_obj = {id:obj for id,obj in nodes.items()}    
                                if item_id in id_to_obj.keys():
                                    item = id_to_obj[item_id]
                                    
                                grab_dict[item] = item_id
                        
                        
                        obj2 = self.find_nodes(class_name=action_list[-1])
                        obj2_ids = [obj_state["states"] for obj_state in obj2]
                        if ["OPEN"] in obj2_ids:
                            put_avail_objects = avail_objects
                        else:
                            put_avail_objects = avail_objects.remove(action_list[2])

                                        
                        if action_list[1] in grab_dict.keys() and action_list[2] in put_avail_objects.keys():
                            current_script = []
                            action1 = '<char{}> [{}] <{}> ({}) <{}> ({})'.format(0, 'putin', action_list[1], grab_dict[action_list[1]], action_list[2], avail_objects[action_list[2]])
                            current_script.append(action1)
            
                        else: current_script = ['<char0> [no_action]']

                    elif action_list[0] == "switch":
                        action = f"switch{action_list[1]}"
                        current_script = ['<char{}> [{}] <{}> ({})'.format(0, action, action_list[2], avail_objects[action_list[2]])]


                    else:
                        current_script = ['<char0> [no_action]']
                        
                if current_script == ['<char0> [no_action]']: 
                    pass
                else:
                    break
            
        return current_script, action

    def progprompt_action(self, action, visible_objects):
        """Translate a ProgPrompt-style action string into a VirtualHome script.

        Args:
            action: space-separated action such as ``"turnleft"``,
                ``"grab apple"`` or ``"put in apple fridge"`` (rewritten to
                ``"putin apple fridge"``).
            visible_objects: container checked with ``in`` by the
                ``switchon`` branch only.

        Returns:
            ``(script, action)``. ``script`` is a list of ``<char0>`` command
            strings. ``action`` is the (possibly rewritten) action string.
            Unknown verbs, unknown object names, or an unsupported number of
            tokens yield ``['<char0> [no_action]']``.
        """
        id_to_name, _ = self.get_nodes()
        # Later ids overwrite earlier ones when several nodes share a class name.
        name_to_id = {name: node_id for node_id, name in id_to_name.items()}
        idle = ['<char0> [no_action]']

        if "put in" in action:
            _, _, first, second = action.split()
            action = f"putin {first} {second}"

        tokens = action.split()

        if len(tokens) == 1:
            verb = tokens[0]
            # Turns are issued three times and walkforward twice per request.
            repeat = {'turnright': 3, 'turnleft': 3, 'walkforward': 2}.get(verb, 1)
            return [f'<char0> [{verb}]'] * repeat, action

        if len(tokens) == 2:
            verb, target = tokens
            if target not in name_to_id:
                return idle, action
            target_id = name_to_id[target]

            if verb in ('walk', 'grab', 'sit', 'open', 'close'):
                return [f'<char0> [{verb}] <{target}> ({target_id})'], action

            if verb == 'find':
                # Remember the object that was looked at.
                self.on.append(target_id)
                return [
                    f'<char0> [walk] <{target}> ({target_id})',
                    f'<char0> [lookat] <{target}> ({target_id})',
                ], action

            if verb == 'switchon':
                # NOTE(review): these checks test fixed names ("microwave",
                # "toaster") rather than `target`, and the print looks like a
                # debugging leftover -- kept verbatim, confirm the intent.
                script = []
                if "microwave" not in visible_objects:
                    script.append(f'<char0> [find] <{target}> ({target_id})')
                if "toaster" not in visible_objects:
                    script.append(f'<char0> [switchon] <{target}> ({target_id})')
                    print(script)
                return script, action

            return idle, action

        if len(tokens) == 3:
            verb, first, second = tokens
            if first in name_to_id and second in name_to_id:
                return [
                    f'<char0> [{verb}] <{first}> ({name_to_id[first]}) <{second}> ({name_to_id[second]})'
                ], action
            return idle, action

        return idle, action

    def actions_available(self):
        """Return the primitive action vocabulary and the visible-object index.

        Returns:
            ``(avail_action, name_to_id)``. ``avail_action`` is the list of
            action names the agent may issue. ``name_to_id`` maps each visible
            object's class name to its id, as reported by
            ``get_visible_objects()``. When several visible objects share a
            class name, the one enumerated last wins.
        """
        visible = self.get_visible_objects()

        name_to_id = {}
        if visible[0]:
            for obj_id, name in visible[1].items():
                name_to_id[name] = obj_id

        # Each entry needs its own comma: adjacent string literals are
        # concatenated by Python, which previously fused 'putin' and 'put'
        # into a single bogus 'putinput' action.
        avail_action = [
            'walkforward',
            'walktowards',
            'turnleft',
            'turnright',
            'no_action',
            'walk',
            'switchon',
            'switchoff',
            'putin',
            'put',
            'open',
            'close',
            'grab',
        ]
        return avail_action, name_to_id

    def in_graph_action(self, action, object_id):
        """Apply or query a graph-only action on one object.

        ``plugin``/``plugout`` are not rendered in Unity. They are recorded in
        ``self.in_graph_states[object_id]``. ``check`` reports the object's
        properties, its simulator states and the recorded in-graph states.

        Args:
            action: ``'plugin'``, ``'plugout'``, ``'check'``, or a sequence
                ``('check', key)`` whose second item selects the
                ``in_graph_states`` entry to report.
            object_id: id of the target node in the environment graph.

        Returns:
            ``(object_id, list_of_states)``, or None when the node does not
            exist or the action is not recognised.
        """
        target = next(
            (node for node in self.get_graph()['nodes'] if node['id'] == object_id),
            None,
        )
        if target is None:
            return None

        if action == 'plugin':
            self.in_graph_states[object_id].append("PLUGIN")
            return object_id, ["PLUGIN"]
        if action == 'plugout':
            self.in_graph_states[object_id].append("PLUGOUT")
            return object_id, ["PLUGOUT"]

        if action == 'check':
            # Plain-string form, consistent with 'plugin'/'plugout'. It used to
            # fall through because `'check'[0]` is 'c'.
            state_key = object_id
        elif not isinstance(action, str) and action and action[0] == 'check':
            # Legacy sequence form: ('check', key).
            state_key = action[1]
        else:
            return None

        return object_id, target['properties'] + target['states'] + self.in_graph_states[state_key]

    def get_graph(self, mode='json', with_properties=False):
        """Return the environment graph, refreshing it from Unity when stale.

        Args:
            mode: ``'json'`` returns the cached graph dict (``'nodes'`` and
                ``'edges'`` lists) as reported by the simulator. ``'triples'``
                returns ``{'nodes': [class names], 'edges': [(subj, rel, obj)]}``.
                That view drops structural categories (walls, floors, doors,
                ...) and FACING edges, and turns every node state into
                ``(class_name, "is", state)``.
            with_properties: in ``'triples'`` mode, also emit
                ``(class_name, "is", property)`` for each node property.

        Raises:
            RuntimeError: if the simulator fails to return the graph.
            NotImplementedError: for any other ``mode``.
        """
        if self.changed_graph:
            ok, graph = self.comm.environment_graph()
            if not ok:
                # Previously dropped into pdb, which hangs unattended runs.
                raise RuntimeError('Unity failed to return the environment graph')
            self.graph = graph
            self.changed_graph = False

        if mode == 'json':
            return self.graph
        if mode != 'triples':
            raise NotImplementedError()

        hidden_categories = {'Walls', 'Ceiling', 'Lamps', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows'}
        nodes, _ = self.get_nodes()
        # A set keeps the per-node / per-edge membership tests O(1).
        kept_ids = {int(x) for x in nodes}
        graph = self.graph

        triples = {'nodes': [], 'edges': []}
        id_to_class_name = {}

        for node in graph['nodes']:
            id_to_class_name[node['id']] = node['class_name']
            if node['id'] not in kept_ids:
                continue
            if node['category'] in hidden_categories:
                kept_ids.discard(node['id'])
                continue
            name = node['class_name']
            triples['nodes'].append(name)
            for state in node['states'] or []:
                triples['edges'].append((name, "is", state))
            if with_properties:
                for prop in node['properties'] or []:
                    triples['edges'].append((name, "is", prop))

        for edge in graph['edges']:
            if (edge['from_id'] in kept_ids and edge['to_id'] in kept_ids
                    and edge['relation_type'] != "FACING"):
                triples['edges'].append((id_to_class_name[edge['from_id']],
                                         edge['relation_type'],
                                         id_to_class_name[edge['to_id']]))

        return triples

    def get_nodes(self, entire_nodes=True):
        """Index the environment graph's nodes by id and by class name.

        Args:
            entire_nodes: when True (the default) every node is indexed. When
                False, decor/structural categories (walls, floors, doors, ...)
                are skipped.

        Returns:
            ``(id_to_name, name_to_ids)``: ``{node id: class_name}`` and
            ``{class_name: [node ids in graph order]}``.
        """
        skipped = ('Decor', 'Walls', 'Ceiling', 'Lamps', 'Floor', 'Floors', 'Doors', 'Windows')
        id_to_name = {}
        name_to_ids = {}

        for node in self.get_graph()['nodes']:
            if not entire_nodes and node['category'] in skipped:
                continue
            node_id, cls = node["id"], node['class_name']
            id_to_name[node_id] = cls
            name_to_ids.setdefault(cls, []).append(node_id)

        return id_to_name, name_to_ids

    def get_visible_graph(self, mode='json'):
        """Return the sub-graph induced by the currently visible objects.

        Nodes in structural/character categories are dropped, together with
        every edge touching them.

        Args:
            mode: ``'json'`` keeps the raw node/edge dicts. Any other value
                yields class names plus ``(subject, relation, object)`` triples,
                including ``(class_name, "is", state)`` for node states.

        Returns:
            ``{'nodes': [...], 'edges': [...]}``, or None when
            ``get_visible_objects()`` reports no visible objects.
        """
        visible = self.get_visible_objects()
        if not visible[0]:
            return None

        hidden = ('Walls', 'Ceiling', 'Lamps', 'Characters', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows')
        keep = {int(key) for key in visible[1].keys()}
        graph = self.get_graph()
        as_json = mode == 'json'

        result = {'nodes': [], 'edges': []}
        names = {}
        for node in graph['nodes']:
            names[node['id']] = node['class_name']
            if node['id'] not in keep:
                continue
            if node['category'] in hidden:
                keep.discard(node['id'])
                continue
            if as_json:
                result['nodes'].append(node)
                continue
            result['nodes'].append(node['class_name'])
            for state in node['states'] or []:
                result['edges'].append((node['class_name'], "is", state))

        for edge in graph['edges']:
            if edge['from_id'] not in keep or edge['to_id'] not in keep:
                continue
            if as_json:
                result['edges'].append(edge)
            else:
                result['edges'].append((names[edge['from_id']], edge['relation_type'], names[edge['to_id']]))

        return result

    def get_agent_graph(self, mode='json'):
        """Return the edges leaving node id 1 (the character) and their targets.

        Args:
            mode: ``'json'`` keeps raw edge/node dicts. Any other value yields
                ``(subject, relation, object)`` triples of class names, with
                HOLDS_RH/HOLDS_LH collapsed to ``'HOLD'``. In that mode CLOSE
                edges are skipped entirely, including their target nodes.

        Returns:
            ``{'nodes': [...], 'edges': [...]}``. Target nodes in
            structural/character categories are left out of ``'nodes'``.
        """
        graph = self.get_graph()
        id_to_name, _ = self.get_nodes()
        as_json = mode == 'json'
        result = {'nodes': [], 'edges': []}
        targets = set()

        for edge in graph['edges']:
            if edge['from_id'] != 1:
                continue
            relation = edge['relation_type']
            if as_json:
                result['edges'].append(edge)
            elif relation == 'CLOSE':
                continue
            else:
                if relation in ('HOLDS_RH', 'HOLDS_LH'):
                    relation = 'HOLD'
                result['edges'].append((id_to_name[edge['from_id']], relation, id_to_name[edge['to_id']]))
            targets.add(edge['to_id'])

        hidden = ('Walls', 'Ceiling', 'Lamps', 'Characters', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows')
        result['nodes'] = [
            node if as_json else node['class_name']
            for node in graph['nodes']
            if node['id'] in targets and node['category'] not in hidden
        ]
        return result

    def get_agent_room(self):
        """Return the class name of the room the character is INSIDE.

        Returns:
            The object of the last ``('character', 'INSIDE', room)`` triple in
            the agent graph, or None when there is no such triple. Before, this
            raised UnboundLocalError in that case.
        """
        agentroom = None
        for edge in self.get_agent_graph(mode='triples')['edges']:
            if edge[0] == 'character' and edge[1] == 'INSIDE':
                agentroom = edge[2]
        return agentroom


    def get_visible_objects(self):
        """
        Return the objects treated as visible to agent 0, with scene-specific patches.

        Starts from the simulator's 'visible' observation. It then adds objects
        held by the character, objects of the class named in a recent
        find/open plan step that lie in the agent's current room, objects
        ON/INSIDE the last entry of ``self.on``, and several hard-coded objects
        (bananas, apples, creamybuns, cereal, bathroomcabinet, toothpaste)
        depending on the last plan steps.

        :return: Tuple[bool, Dict[str, str]]
        bool -> always True here (object existence flag)
        dict -> {str(object id): class name}; entries coming from the simulator
                observation presumably follow the same format -- TODO confirm.
        """
        # NOTE(review): this method never refreshes visible_objects_cache nor
        # clears changed_visible_object, so this early return only fires if
        # another method does -- verify.
        if not self.changed_visible_object:
            return True, self.visible_objects_cache

        self.graph = self.get_graph()
        vis_dict = self.get_observation(agent_id=0, obs_type='visible')

        agent_graph = self.get_agent_graph()
        nodes, _ = self.get_nodes()
        id_to_obj = {id: obj for id, obj in nodes.items()}

        # Objects held in either hand always count as visible.
        for agent_edge in agent_graph["edges"]:
            if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                item_id = agent_edge['to_id']
                if item_id in id_to_obj:
                    item = id_to_obj[item_id]
                    vis_dict[str(item_id)] = item

        room = self.get_agent_room()
        room_id = self.room_ids.get(room)

        # Scene-specific lookups: each raises IndexError if the scene has no
        # node of that class.
        sofa_id = self.find_nodes(class_name='sofa')[-1]["id"]
        rug_id = self.find_nodes(class_name='rug')[-1]["id"]
        cereal_id = self.find_nodes(class_name='cereal')[-1]["id"]
        creamybuns_id = self.find_nodes(class_name='creamybuns')[-1]["id"]


        # Runs once per each of the last (up to) two plan steps; most of the
        # body below does not depend on `plan` and so is simply repeated.
        for plan in self.plan[-2:]:
        
            # After "find X"/"open X", every X located in the agent's current
            # room is marked visible (second token of the plan string).
            prev_obj_id = []
            if "find" in plan or "open" in plan:
                prev_obj = plan.split()[1]
                prev_obj_id = [id for id, obj in id_to_obj.items() if obj == prev_obj]
                
            if len(prev_obj_id) != 0:
                for i in prev_obj_id:
                    if self.objinroom[i] == room_id:
                        vis_dict[str(i)] = prev_obj


            # NOTE(review): IndexError if self.on is empty -- presumably
            # populated by a prior "find" action; confirm.
            on_id = self.on[-1]

            # Objects sitting ON or INSIDE the last looked-at object.
            objects_on_id = [
                edge['from_id'] for edge in self.graph['edges']
                if edge['to_id'] == on_id and edge['relation_type'] in ['ON', 'INSIDE']
            ]
            
            # Hard-coded scene patch: at the coffee table, things on the
            # sofa/rug, the third bananas node and all apples become visible.
            # Assumes at least three 'bananas' nodes exist.
            if "coffeetable" in self.plan[-1]:
                objects_on_id.extend([
                    edge['from_id'] for edge in self.graph['edges']
                    if edge['to_id'] in [sofa_id, rug_id] and edge['relation_type'] == 'ON'
                ])
                bananas = self.find_nodes(class_name='bananas')
                vis_dict[str(bananas[2]['id'])] = "bananas"
                apples = self.find_nodes(class_name='apple')
                for apple in apples:
                    vis_dict[str(apple['id'])] = "apple"

            # Hard-coded scene patch for the wall shelf.
            if "wallshelf" in self.plan[-1]:
                vis_dict[str(creamybuns_id)] = "creamybuns"
                vis_dict[str(cereal_id)] = "cereal"

            if self.plan[-1] in ["grab apple", "grab cellphone"]:
                bananas = self.find_nodes(class_name='bananas')
                vis_dict[str(bananas[2]['id'])] = "bananas"

            # Hard-coded scene patch near the bathroom counter.
            for bath_plan in self.plan[-2:]:
                if "bathroomcounter" in bath_plan:
                    bathroomcabinet = self.find_nodes(class_name='bathroomcabinet')
                    vis_dict[str(bathroomcabinet[0]['id'])] = "bathroomcabinet"
                    toothpaste = self.find_nodes(class_name='toothpaste')
                    vis_dict[str(toothpaste[0]['id'])] = "toothpaste"

            for id, obj in id_to_obj.items():
                if id in objects_on_id:
                    vis_dict[str(id)] = obj

        return True, vis_dict

    def get_rgb_to_instance(self):
        """Build ``self.rgb_to_instance``: 8-bit colour tuple -> list of instance ids.

        Each colour returned by ``self.comm.instance_colors()`` (presumably
        float channels in [0, 1] -- confirm) is scaled by 255, rounded, and
        reversed in channel order (``v[2], v[1], v[0]``).

        NOTE(review): any channel equal to 70 is rewritten to 69. This looks
        like a compensation for a rendering/encoding offset; confirm before
        changing it.

        All instances sharing a colour are kept. Before this fix, the
        existence check used the raw float colour instead of the rounded key,
        so each new instance reset the list and earlier ids were lost.
        """
        rgb_to_instance = {}
        for instance_id, color in self.comm.instance_colors()[1].items():
            channels = (round(color[2] * 255), round(color[1] * 255), round(color[0] * 255))
            rgb_code = tuple(69 if c == 70 else c for c in channels)
            rgb_to_instance.setdefault(rgb_code, []).append(instance_id)
        self.rgb_to_instance = rgb_to_instance
    

    def get_observations(self):
        """Collect agent 0's observations into a single dict.

        Keys, in insertion order:
            'image': camera frames rendered at 640x480;
            'visible': the simulator's visible-object observation;
            'visible_graph' / 'agent_graph': triples views of the graph;
            'entire_grap': the full JSON graph. The key's spelling is kept
                as-is because existing callers may read it.
        """
        frame_size = {'image_width': 640, 'image_height': 480}
        observations = {
            'image': self.get_observation(agent_id=0, obs_type='image', info=dict(frame_size)),
            'visible': self.get_observation(agent_id=0, obs_type='visible', info=dict(frame_size)),
            'visible_graph': self.get_visible_graph(mode='triples'),
            'agent_graph': self.get_agent_graph(mode='triples'),
            'entire_grap': self.get_graph(),
        }
        return observations
    
    def get_action_space(self):
        """Map each agent index to the ids of the nodes it may interact with.

        Raises:
            NotImplementedError: if an agent's observation type is neither
                'partial' nor 'full'.
        """
        space = {}
        for agent in range(self.num_agents):
            if self.observation_types[agent] not in ('partial', 'full'):
                raise NotImplementedError
            # Interaction is limited to visible objects even when the agent
            # observes the full graph, so the 'partial' view is always used.
            graph = self.get_observation(agent, 'partial')
            space[agent] = [node['id'] for node in graph['nodes']]
        return space

    def get_observation(self, agent_id, obs_type, info=None):
        """Return one observation of ``obs_type`` for agent ``agent_id``.

        Args:
            agent_id: 0-based agent index. The agent's graph id is
                ``agent_id + 1``.
            obs_type: one of
                'partial' -- ``utils.get_visible_nodes`` over the full graph;
                'full'    -- the full environment graph;
                'visible' -- visible objects reported for the first camera in
                             ``self.CAMERA_NUM`` (offset by the static cameras);
                'image'   -- a one-element list holding the first selected
                             camera's frame, converted with ``cv2.COLOR_RGB2BGR``.
            info: optional options for 'image': 'mode' (default 'normal') and
                'image_width'/'image_height' (give both; defaults are
                ``self.default_image_width/height``).

        Raises:
            RuntimeError: if the simulator fails to render camera images.
            NotImplementedError: for an unknown ``obs_type``.
        """
        if info is None:
            info = {}

        if obs_type == 'partial':
            curr_graph = self.get_graph()
            return utils.get_visible_nodes(curr_graph, agent_id=(agent_id + 1))

        if obs_type == 'full':
            return self.get_graph()

        if obs_type == 'visible':
            camera_ids = [self.num_static_cameras + x for x in self.CAMERA_NUM]
            visible_objects = {}
            # Only the first camera is queried (the original loop broke after
            # one iteration). Slicing keeps the empty-CAMERA_NUM case returning {}.
            for cam_id in camera_ids[:1]:
                response = self.comm.get_visible_objects(cam_id)
                if response[1]:
                    visible_objects.update(response[1])
            return visible_objects

        if obs_type == 'image':
            current_mode = info.get('mode', 'normal')

            camera_ids = [self.num_static_cameras + x for x in self.CAMERA_NUM]
            _, ncameras = self.comm.camera_count()
            # Indexing into range(ncameras) rejects ids beyond the camera count.
            cameras_select = list(range(ncameras))
            cameras_select = [cameras_select[x] for x in camera_ids]

            if 'image_width' in info:
                image_width = info['image_width']
                image_height = info['image_height']
            else:
                image_width, image_height = self.default_image_width, self.default_image_height

            ok, images = self.comm.camera_image(cameras_select, mode=current_mode,
                                                image_width=image_width, image_height=image_height)
            # Check success before touching `images`. It used to be converted
            # first and then dropped into pdb on failure.
            if not ok:
                raise RuntimeError('Unity failed to render camera images')

            return [cv2.cvtColor(images[0], cv2.COLOR_RGB2BGR)]

        raise NotImplementedError


    def convert_action_room(self, action):
        """Expand a room-qualified move request into single-step scripts.

        Args:
            action: ``"<src>_<srcroom> <loc> <tgt>_<tgtroom>"`` with ``loc`` in
                {'on', 'in', 'open'}, e.g. ``"apple_kitchen in fridge_kitchen"``.
                The first node of each class located in the given room (per
                ``self.objinroom``) is used.

        Returns:
            A list of one-element script lists, e.g.
            ``[['<char0> [walk] <apple> (2)'], ..., ['<char0> [putin] <apple> (2) <fridge> (3)']]``.

        Raises:
            ValueError: malformed ``action`` string.
            KeyError: unknown room, object class or ``loc``.
            IndexError: no object of the class in the requested room.
        """
        move_commands = {
            'on': ['walk src', 'grab src', 'walk tgt', 'putback src tgt'],
            'in': ['walk src', 'grab src', 'walk tgt', 'putin src tgt'],
            'open': ['walk src', 'grab src', 'walk tgt', 'open tgt', 'putin src tgt', 'close tgt'],
        }
        src_token, loc, tgt_token = action.split(' ')
        src_obj, src_room = src_token.split('_')
        tgt_obj, tgt_room = tgt_token.split('_')

        src_room_id = self.room_ids[src_room]
        tgt_room_id = self.room_ids[tgt_room]
        _, class_name_to_id = self.get_nodes()

        src_obj_ids = [i for i in class_name_to_id[src_obj] if self.objinroom[i] == src_room_id]
        tgt_obj_ids = [i for i in class_name_to_id[tgt_obj] if self.objinroom[i] == tgt_room_id]

        # Resolve placeholders directly. Keying by class name (as before)
        # collapsed source and target when both had the same class, e.g.
        # moving a plate onto another plate.
        operands = {
            'src': (src_obj, src_obj_ids[0]),
            'tgt': (tgt_obj, tgt_obj_ids[0]),
        }

        scripts = []
        for command in move_commands[loc]:
            verb, *slots = command.split(' ')
            args = ' '.join(f'<{name}> ({obj_id})' for name, obj_id in (operands[slot] for slot in slots))
            scripts.append([f'<char0> [{verb}] {args}'])
        return scripts


    def goalCond(self, goals):
        """Return the satisfied-goal count divided by ``len(goals)``.

        Each goal is a dict of one of two forms:
          * ``{'class_name': c, 'state': [...]}`` -- adds 1 for every listed
            state present on the FIRST node of class ``c``. A single goal can
            therefore contribute more than 1.
          * ``{'from_id': c1, 'relation_type': r, 'to_id': c2}`` (class names)
            -- adds 1 if any edge links a ``c1`` node to a ``c2`` node with a
            relation satisfying ``relation in r``. That is membership for a
            list, and a substring test for a string.

        An empty ``goals`` list raises ZeroDivisionError.
        """
        graph = self.get_graph()
        id_to_name, _ = self.get_nodes()
        score = 0

        for goal in goals:
            if "state" in goal:
                wanted = goal['state']
                present = self.find_nodes(class_name=goal["class_name"])[0]["states"]
                score += sum(1 for state in wanted if state in present)
            elif "relation_type" in goal:
                from_class = goal["from_id"]
                relation = goal["relation_type"]
                to_class = goal["to_id"]
                from_ids = {node_id for node_id, name in id_to_name.items() if name == from_class}
                to_ids = {node_id for node_id, name in id_to_name.items() if name == to_class}
                if any(edge['from_id'] in from_ids
                       and edge['relation_type'] in relation
                       and edge['to_id'] in to_ids
                       for edge in graph['edges']):
                    score += 1

        return score / len(goals)


class CrossDomain(BaseEnvironment):
    def __init__(self,
                num_agents=1,
                max_episode_length=500,
                observation_types=None,
                use_editor=False,
                base_port=8080,
                port_id=0,
                executable_args=None,
                recording_options=None,
                seed=123,
                ):
        """Set up the environment wrapper and its simulator client.

        Args:
            num_agents: number of agents; sizes the default observation_types.
            max_episode_length: step count at which step() reports done.
            observation_types: per-agent observation type; defaults to
                'partial' for every agent.
            use_editor: selects the recorded port_number (8080 for the Unity
                editor, ``base_port + port_id`` otherwise).
            base_port, port_id: see ``use_editor``.
            executable_args: stored on the instance. Defaults to a fresh empty
                dict per instance; a shared ``{}`` default would be aliased
                across instances.
            recording_options: render options; defaults to recording disabled
                with the 'PERSON_FROM_BACK' camera and 'normal' modality.
            seed: seeds both the instance RNG and numpy's global RNG.

        Side effects: reads 'prog_cond.json' relative to the current working
        directory, creates a ``comm_unity.UnityCommunication`` client, and
        registers ``close`` with atexit.
        """
        self.seed = seed
        self.prev_reward = 0.
        self.rnd = random.Random(seed)
        np.random.seed(seed)

        if recording_options is None:
            self.recording_options = {'recording': False,
                                    'output_folder': None,
                                    'file_name_prefix': None,
                                    'cameras': 'PERSON_FROM_BACK',
                                    'modality': 'normal'}
        else:
            self.recording_options = recording_options

        self.steps = 0
        self.env_id = 0
        self.max_ids = {}

        self.in_graph_states = defaultdict(list)
        self.num_agents = num_agents
        self.max_episode_length = max_episode_length
        self.base_port = base_port
        self.port_id = port_id
        self.executable_args = {} if executable_args is None else executable_args

        self.num_camera_per_agent = 6
        self.CAMERA_NUM = [2,3]
        self.default_image_width = 640
        self.default_image_height = 480

        if observation_types is not None:
            self.observation_types = observation_types
        else:
            self.observation_types = ['partial' for _ in range(num_agents)]

        self.agent_info = {
            0: 'Chars/Male1'
        }

        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True
        self.rooms = None
        self.id2node = None
        self.num_static_cameras = None
        self.success_condition = None
        self.rgb_to_instance = None
        self.prev_observable_objects = {}
        self.visible_objects_cache = {}
        self.image_cache = None

        # Ids of objects looked at via "find" actions (most recent last).
        self.on = []

        # NOTE(review): relative path -- depends on the process's cwd.
        with open('prog_cond.json', 'r') as f:
            self.prog_cond = json.load(f)

        # NOTE(review): port_number is recorded but never passed to
        # UnityCommunication, so both modes use the client's default
        # connection settings -- confirm this is intended.
        self.port_number = 8080 if use_editor else self.base_port + port_id
        self.comm = comm_unity.UnityCommunication()

        atexit.register(self.close)

    def close(self):
        """Close the simulator communication client held in ``self.comm``."""
        communicator = self.comm
        communicator.close()

    def relaunch(self):
        """Close the current simulator client and replace it with a fresh one."""
        previous = self.comm
        previous.close()
        fresh = comm_unity.UnityCommunication()
        self.comm = fresh


    def find_nodes(self, **kwargs):
        """Return the nodes of the cached graph matching every keyword filter.

        Example: ``find_nodes(class_name='sofa')`` returns all sofa nodes, and
        ``find_nodes(class_name='cup', id=7)`` narrows further. Before this
        fix, only the first keyword was applied and the rest were silently
        ignored.

        Returns:
            A list of node dicts in graph order, or None when called without
            filters (kept for compatibility).

        Raises:
            KeyError: if a node lacks one of the filtered attributes.
        """
        if not kwargs:
            return None
        filters = list(kwargs.items())
        return [node for node in self.graph['nodes']
                if all(node[key] == value for key, value in filters)]

    def add_node(self,  n):
        """Append node dict ``n`` to the locally cached graph.

        Only ``self.graph`` is modified; nothing is sent to the simulator.
        """
        node_list = self.graph['nodes']
        node_list.append(n)

    def add_edge(self,  fr_id, rel, to_id):
        """Append a ``fr_id -rel-> to_id`` edge to the locally cached graph.

        Only ``self.graph`` is modified; nothing is sent to the simulator.
        """
        new_edge = dict(from_id=fr_id, relation_type=rel, to_id=to_id)
        self.graph['edges'].append(new_edge)

    def reward(self):
        """Score the current state against ``self.success_condition``.

        Each success condition is a (subject, relation, object) triple. It is
        satisfied when it matches, per ``REL_EQUAL`` (case-insensitive), any
        triple of either the whole-graph or the agent-graph 'triples' view.

        Returns:
            ``(reward, done, info)``. ``reward`` is the number of satisfied
            conditions. ``done`` and ``info['is_success']`` are True only when
            all conditions hold. With no success condition defined, returns
            ``(0, False, {})``.
        """
        conditions = self.success_condition
        if conditions is None:
            return 0, False, {}

        info = {'success_condition': conditions}
        graph_edges = self.get_graph(mode='triples')['edges']
        agent_edges = self.get_agent_graph(mode='triples')['edges']
        candidates = graph_edges + agent_edges

        satisfied = [
            any(REL_EQUAL(edge, cond) for edge in candidates)
            for cond in conditions
        ]
        reward = sum(satisfied)
        done = reward == len(conditions)
        info['is_success'] = done
        return reward, done, info

    def step(self, action_dict):
        """Execute one high-level action in Unity and observe the result.

        Args:
            action_dict: the action string (despite the name, a str). Any
                "put in" is normalised to "putin" before conversion.

        Returns:
            ``(obs, reward, done, info)``, or ``False`` if Unity rejects the
            script (the simulator's message is printed). ``info['success']``
            is False when the converted script is empty and nothing was
            rendered; this case used to raise NameError. ``info['myroom']`` is
            the agent's room before acting.
        """
        myroom = self.get_agent_room()

        if "put in" in action_dict:
            action_dict = action_dict.replace("put in", "putin")

        script_list = self.convert_action(action_dict, self.get_visible_objects(action_dict)[1])

        success = False
        if len(script_list[0]) > 0:
            if self.recording_options['recording']:
                success, message = self.comm.render_script(script_list,
                                                        recording=False,
                                                        skip_animation=False,
                                                        find_solution=False,
                                                        camera_mode=self.recording_options['cameras'])
            else:
                success, message = self.comm.render_script(script_list,
                                                        recording=False,
                                                        find_solution=False,
                                                        skip_animation=True,
                                                        )
            if not success:
                print(message)
                return False
            # The world changed: invalidate cached graph/visibility/images.
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True

        self.plan.append(action_dict)
        reward, done, info = self.reward()
        self.steps += 1
        obs = self.get_observations()
        info['finished'] = done
        if self.steps == self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom
        return obs, reward, done, info

    def step_init(self, action_dict):
        """Render one set-up action without computing reward or observations.

        The action is converted with ``convert_action_init`` and rendered. On
        success the cached-state flags are raised and the action is appended
        to ``self.plan``. The action is also appended when the converted
        script is empty.

        Returns:
            ``False`` if Unity rejects the script (message printed), else None.
        """
        scripts = self.convert_action_init(action_dict, self.get_visible_objects()[1])

        if len(scripts[0]) > 0:
            if self.recording_options['recording']:
                render_options = {
                    'recording': False,
                    'skip_animation': False,
                    'find_solution': True,
                    'camera_mode': self.recording_options['cameras'],
                }
            else:
                # NOTE(review): recording=True although recording is disabled,
                # the inverse of the other branch -- looks swapped; confirm.
                render_options = {
                    'recording': True,
                    'find_solution': False,
                    'skip_animation': True,
                }
            rendered, message = self.comm.render_script(scripts, **render_options)
            if not rendered:
                print(message)
                return False
            self.changed_graph = self.changed_visible_object = self.changed_image_view = True

        self.plan.append(action_dict)
        
    def baseline_step(self, action_dict):
        """Run one baseline-agent action through ``self.affordance``.

        The affordance returns the script to render and the action actually
        taken. That action is appended to ``self.plan`` after rendering.

        Returns:
            ``(obs, info, action)``, or ``False`` if Unity rejects the script
            (message printed). ``info`` holds the reward info plus:
            'finished', 'success' (False when the script was empty and nothing
            was rendered; this used to raise NameError), 'myroom' (the room
            before acting) and 'visible'.
        """
        myroom = self.get_agent_room()
        _, avail_objects = self.actions_available()
        script, action = self.affordance(action_dict, self.get_visible_objects()[1])

        success = False
        if len(script[0]) > 0:
            if self.recording_options['recording']:
                success, message = self.comm.render_script(script,
                                                        recording=False,
                                                        skip_animation=False,
                                                        find_solution=True
                                                        )
            else:
                success, message = self.comm.render_script(script,
                                                        recording=False,
                                                        find_solution=False,
                                                        skip_animation=True,
                                                        )
            if not success:
                print(message)
                return False
            # The world changed: invalidate cached graph/visibility/images.
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True

        self.plan.append(action)

        reward, done, info = self.reward()
        self.steps += 1

        obs = self.get_observations()
        info['finished'] = done
        self.prev_observable_objects = avail_objects
        if self.steps == self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom
        info['visible'] = self.get_visible_objects()[1]

        return obs,  info, action



    def prog_step(self, action_dict):
        """Run one ProgPrompt step, expanding multi-skill actions.

        ``prog_convert_action`` returns ``(script, action)``. When ``action``
        is False, ``script`` is treated as a list of skills; each is converted
        and rendered in turn, and appended to ``self.plan``. Otherwise
        ``script`` is rendered directly and ``action`` is appended.

        Returns:
            ``(obs, info, action)``, or ``False`` as soon as Unity rejects a
            script (message printed). ``info['success']`` is False when
            nothing was rendered (e.g. an empty skill list); this used to
            raise NameError.
        """
        myroom = self.get_agent_room()

        _, avail_objects = self.actions_available()

        script, action = self.prog_convert_action(action_dict, self.get_visible_objects()[1])

        success = False
        # NOTE(review): presumably action == False is prog_convert_action's
        # signal that `script` is a list of sub-skills; `==` is kept as-is.
        if action == False:  # noqa: E712
            for skill in script:
                pre_script, action = self.prog_convert_action(skill, self.get_visible_objects()[1])

                success, message = self.comm.render_script(pre_script,
                                                        recording=False,
                                                        find_solution=False,
                                                        skip_animation=True,
                                                        )
                if not success:
                    print(message)
                    return False
                self.changed_graph = True
                self.changed_visible_object = True
                self.changed_image_view = True
                self.plan.append(action)
        else:
            success, message = self.comm.render_script(script,
                                                    recording=False,
                                                    find_solution=False,
                                                    skip_animation=True,
                                                    )
            if not success:
                print(message)
                return False
            self.changed_graph = True
            self.changed_visible_object = True
            self.changed_image_view = True
            self.plan.append(action)

        reward, done, info = self.reward()
        self.steps += 1

        obs = self.get_observations()
        info['finished'] = done
        self.prev_observable_objects = avail_objects
        if self.steps == self.max_episode_length:
            done = True
        info['success'] = success
        info["myroom"] = myroom
        info['visible'] = self.get_visible_objects()[1]

        return obs,  info, action
  

    def reset(self, environment_graph=None, environment_id=None, init_rooms=None,init_pos=None,clear_space=None, angle=None):

        """Reset the simulator, rebuild the scene graph and (re)spawn the agents.

        :param environment_graph: the initial graph we should reset the environment with.
            When None, the freshly fetched scene graph is used: a 'cat' node (id 1000)
            is added ON the last node returned for class 'bed', and the scene is
            expanded from that graph.
        :param environment_id: scene id passed to ``self.comm.reset``; when None the
            simulator's argument-less reset is used
        :param init_rooms: where to initialize the agents. If None, or if its first
            entry is not one of 'kitchen'/'bedroom'/'livingroom'/'bathroom', two
            distinct rooms are sampled with ``self.rnd``.
        :param init_pos: when truthy, agents present in ``self.agent_info`` are added
            at this position instead of in a room
        :param clear_space: when truthy, passed to ``self.remove_objects``; the ids it
            returns are left out of ``updated_graph``
        :param angle: not used in this method
        :return: the result of ``self.get_observations()`` after the reset
        """
        self.env_id = environment_id
        # print("Resetting env", self.env_id)

        if self.env_id is not None:
            self.comm.reset(self.env_id)
        else:
            self.comm.reset()
        self.plan = []

        # Mark every cached view of the scene as stale.
        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True

        # NOTE(review): the success flag (spelled 'sucess') is never checked.
        sucess, self.graph = self.comm.environment_graph()

        # Record the largest node id of this scene the first time it is reset.
        if self.env_id not in self.max_ids.keys():
            max_id = max([node['id'] for node in self.graph['nodes']])
            self.max_ids[self.env_id] = max_id

        if environment_graph is None:
            # get_graph() (json mode) returns self.graph itself, so entire_graph aliases it.
            entire_graph = self.get_graph()
            classwise_obj = []
            node_ids = []
            updated_graph = {'nodes': [], 'edges': []}

            _,class_name_to_id = self.get_nodes()
            ids_to_remove = []
            if clear_space:
                ids_to_remove = self.remove_objects(clear_space)

            ##########
            # Keep rooms/walls/characters/etc. unconditionally; keep every other
            # node unless clear_space asked for it to be removed.
            for node in entire_graph['nodes']:
                if ((node['category'] not in ['Rooms', 'Walls', 'Ceiling', 'Lamps', 'Characters', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows'] )):
                    if node['id'] not in ids_to_remove:
                        classwise_obj.append(node['class_name'])
                        updated_graph['nodes'].append(node)
                        node_ids.append(node['id'])
                else:
                    updated_graph['nodes'].append(node)
                    node_ids.append(node['id'])
            # Keep only edges whose two endpoints were kept.
            for edge in entire_graph['edges']:
                if edge['from_id'] in node_ids and edge['to_id'] in node_ids:
                    updated_graph['edges'].append(edge)

            bed = self.find_nodes(class_name='bed')

            # Place a cat ON the last 'bed' node; bed[-1] fails if no bed node is found.
            # NOTE(review): presumably add_node/add_edge mutate self.graph (== entire_graph) — confirm.
            self.add_node({'class_name': 'cat', 
                'category': 'Animals', 
                'id': 1000, 
                'properties': [], 
                'states': ["ON"]})
            self.add_edge(1000, 'ON', bed[-1]['id'])

            # NOTE(review): entire_graph is expanded, not updated_graph, so the
            # clear_space filtering above is not applied here — confirm intended.
            success, m = self.comm.expand_scene(entire_graph)
        else:
            updated_graph = environment_graph
            success, m = self.comm.expand_scene(updated_graph)

        if not success:
            print("Error expanding scene")
            # pdb.set_trace()
            # return None
        # Camera count before agents are added; agent-camera indices are offset by it.
        self.num_static_cameras = self.comm.camera_count()[1]


        if init_rooms is None or init_rooms[0] not in ['kitchen', 'bedroom', 'livingroom', 'bathroom']:
            rooms = self.rnd.sample(['kitchen', 'bedroom', 'livingroom', 'bathroom'], 2)
        else:
            rooms = list(init_rooms)
        # print(self.comm.character_cameras())
        # NOTE(review): when rooms are sampled only two exist, so rooms[i] fails for
        # a third agent present in agent_info without init_pos.
        for i in range(self.num_agents):
            if i in self.agent_info:
                if init_pos:
                    self.comm.add_character(self.agent_info[i], position=init_pos)
                else:
                    self.comm.add_character(self.agent_info[i], initial_room=rooms[i])
            else:
                self.comm.add_character()



        # Configure the first agent camera (field of view 80, position [0, 1.6, 0]).
        self.comm.update_camera(self.num_static_cameras + self.CAMERA_NUM[0] , field_view=80, position=[0, 1.6, 0])

        self.changed_graph = True
        self.changed_visible_object = True
        self.changed_image_view = True

        self.prev_observable_objects = {}
        graph = self.get_graph()
        self.rooms = {node['id']: node['class_name'] for node in graph['nodes'] if node['category'] == 'Rooms'}  ## class id : room
        self.room_ids = {v:k for k,v in self.rooms.items()}

        self.id2node = {node['id']: node for node in graph['nodes']}

        self.in_graph_states = defaultdict(list)
        obs = self.get_observations()
        self.steps = 0
        self.prev_reward = 0.
        return obs


    def objinroom(self):
        graph = self.get_graph()
        return {edge['from_id']: edge['to_id'] for edge in graph['edges'] if edge['relation_type'] == 'INSIDE' and edge['from_id'] != 1 and edge['to_id'] in self.rooms}


    def convert_action(self, action, avail_objects):
        """Translate a free-text command into a VirtualHome script for ``<char0>``.

        :param action: whitespace-separated command, e.g. ``'turnleft'``,
            ``'grab apple'``, ``'put apple table'`` or ``'switch on tv'``.
        :param avail_objects: ``{object id: class name}`` of the objects the agent
            may interact with.
        :return: list of script lines; ``['<char0> [no_action]']`` when the
            command cannot be grounded.
        """
        name_to_id = {name: obj_id for obj_id, name in avail_objects.items()}
        tokens = action.split()
        current_script = ['<char0> [no_action]']

        if len(tokens) == 1:
            action_name = tokens[0]
            current_script = [f'<char0> [{action_name}]']
            # Turns are issued 3x and walkforward 2x per command.
            if action_name in ['turnright', 'turnleft']:
                current_script.extend([f'<char0> [{action_name}]'] * 2)
            elif action_name == 'walkforward':
                current_script.append('<char0> [walkforward]')

        elif len(tokens) == 2:
            action_name, obj = tokens
            # Default grounding against available objects; this also covers
            # grab/sit/open/close, which need no special handling.
            if obj in name_to_id:
                current_script = [f'<char0> [{action_name}] <{obj}> ({name_to_id[obj]})']

            if action_name in ('walk', 'find'):
                # Navigation targets are resolved against the whole scene.
                id_to_name, _ = self.get_nodes()
                scene_ids = {name: obj_id for obj_id, name in id_to_name.items()}
                if obj in scene_ids:
                    objid = scene_ids[obj]
                    current_script = [f'<char0> [walk] <{obj}> ({objid})']
                    if action_name == 'find':
                        current_script.append(f'<char0> [lookat] <{obj}> ({objid})')
                        self.on.append(objid)

        elif len(tokens) == 3:
            action_name, obj1, obj2 = tokens

            if obj1 in name_to_id and obj2 in name_to_id:
                current_script = [f'<char0> [{action_name}] <{obj1}> ({name_to_id[obj1]}) <{obj2}> ({name_to_id[obj2]})']

            if action_name == 'put':
                grab_dict = self._get_grab_dict(self.get_agent_graph())
                if obj1 in grab_dict and obj2 in name_to_id:
                    current_script = [f'<char0> [put] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({name_to_id[obj2]})']

            elif action_name == 'putin':
                grab_dict = self._get_grab_dict(self.get_agent_graph())
                # The container counts as open when any node of that class has OPEN
                # among its states (states usually hold several flags, e.g. OPEN+OFF).
                containers = self.find_nodes(class_name=obj2)
                is_open = any('OPEN' in (node.get('states') or ()) for node in containers)
                if is_open and obj1 in grab_dict and obj2 in name_to_id:
                    current_script = [f'<char0> [putin] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({name_to_id[obj2]})']

            elif action_name == 'switch':
                # 'switch on tv' -> '[switchon] <tv> (id)'
                switch_cmd = f'switch{obj1}'
                if obj2 in name_to_id:
                    current_script = [f'<char0> [{switch_cmd}] <{obj2}> ({name_to_id[obj2]})']

        return current_script
    
    
    def convert_action_init(self, action, avail_objects):
        
        avail_objects = {v: k for k, v in avail_objects.items()}
        action_list = action.split()
        current_script = ['<char0> [no_action]']
        
            
        if len(action_list) == 1:
            action_name = action_list[0]
            current_script = [f'<char0> [{action_name}]']
            
            if action_name in ['turnright', 'turnleft']:
                current_script.extend([f'<char0> [{action_name}]'] * 2)
            elif action_name == 'walkforward':
                current_script.append(f'<char0> [walkforward]')

        elif len(action_list) == 2:
            action_name, obj = action_list

            nodes, _ = self.get_nodes()
            obj_to_id = {obj: id for id, obj in nodes.items()}
            if obj in obj_to_id:
                objid = obj_to_id[obj]

            current_script = [f'<char0> [{action_name}] <{obj}> ({objid})']

            if action_name == "find":
                self.on.append(objid)


        elif len(action_list) == 3:
            action_name, obj1, obj2 = action_list
            
            nodes, _ = self.get_nodes()
            obj_to_id = {obj: id for id, obj in nodes.items()}
            if obj1 in obj_to_id:
                objid1 = obj_to_id[obj1]
            if obj2 in obj_to_id:
                objid2 = obj_to_id[obj2]   
            
            current_script = [f'<char0> [{action_name}] <{obj1}> ({objid1}) <{obj2}> ({objid2})']

        return current_script
    
    def remove_prefix(self, text):
        """Strip a leading list enumerator such as ``'3. '`` from *text*.

        Only digits followed by a period (plus any whitespace after it) at the very
        start of the string are removed; any other text is returned unchanged.
        """
        import re
        enumerator = re.match(r'\d+\.\s*', text)
        if enumerator is None:
            return text
        return text[enumerator.end():]

    def affordance(self, skill_list, avail_objects):
        with open('../../../Task_Data/vh_config.json', 'r') as f:
            vh_config = json.load(f)

        avail_objects = {v: k for k, v in avail_objects.items()}

        for action in skill_list:
            if '.' in action:
                action = self.remove_prefix(action)
            if 'done' in action:
                action = action.replace(', done', '')
            if "put in" in action:
                action = action.replace("put in", "putin")
            
            action_list = action.split()

            current_script = ['<char0> [no_action]']

            if len(action_list) == 1:
                action_name = action_list[0]
                current_script = [f'<char0> [{action_name}]']
                if action_name in ['turnright', 'turnleft']:
                    current_script.extend([f'<char0> [{action_name}]'] * 2)
                elif action_name == 'walkforward':
                    current_script.append(f'<char0> [walkforward]')

            elif len(action_list) == 2:
                action_name, obj = action_list

                if action_name in ["find", "walk"]:

                    if action_name == 'walk':
                        current_script = [f'<char0> [walk] <{obj}> ({objid})']
                    elif action_name == 'find':
                     
                        nodes, _ = self.get_nodes()
                        obj_to_id = {obj: id for id, obj in nodes.items()}
                        objid = obj_to_id[obj]
                        self.on.append(objid)
                        current_script = [
                            f'<char0> [walk] <{obj}> ({objid})',
                            f'<char0> [lookat] <{obj}> ({objid})'
                        ]

                elif obj in avail_objects:
                    objid = avail_objects[obj]
                    current_script = [f'<char0> [{action_name}] <{obj}> ({objid})']

                    if action_name == 'grab':
                        current_script = [f'<char0> [grab] <{obj}> ({objid})']
                    elif action_name == 'sit':
                        current_script = [f'<char0> [sit] <{obj}> ({objid})']
                    elif action_name == 'open':
                        current_script = [f'<char0> [open] <{obj}> ({objid})']
                    elif action_name == 'close':
                        current_script = [f'<char0> [close] <{obj}> ({objid})']
                        
                        
            elif len(action_list) == 3:
                action_name, obj1, obj2 = action_list
                if obj1 in avail_objects and obj2 in avail_objects:
                    objid1 = avail_objects[obj1]
                    objid2 = avail_objects[obj2]
                    current_script = [f'<char0> [{action_name}] <{obj1}> ({objid1}) <{obj2}> ({objid2})']

                if action_name == 'put':
                    agent_graph = self.get_agent_graph()
                    grab_dict = self._get_grab_dict(agent_graph)

                    if obj1 in grab_dict and obj2 in avail_objects:
                        current_script = [
                            f'<char0> [put] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({avail_objects[obj2]})'
                        ]

                elif action_name == 'putin':
                    agent_graph = self.get_agent_graph()
                    grab_dict = self._get_grab_dict(agent_graph)

                    if obj1 in grab_dict and obj2 in avail_objects:
                        current_script = [
                            f'<char0> [putin] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({avail_objects[obj2]})'
                        ]

                elif action_name == "switch":
                    action = f"switch{action_list[1]}"
                    current_script = [
                        f'<char0> [{action}] <{obj2}> ({avail_objects[obj2]})'
                    ]

            if current_script == ['<char0> [no_action]']:
                continue
            else:
                break

        return current_script, action

    def _get_grab_dict(self, agent_graph):
        grab_dict = {}
        for agent_edge in agent_graph["edges"]:
            if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                item_id = agent_edge['to_id']
                nodes, _ = self.get_nodes()
                id_to_obj = {id: obj for id, obj in nodes.items()}
                if item_id in id_to_obj:
                    item = id_to_obj[item_id]
                    grab_dict[item] = item_id
        return grab_dict

    def progprompt_action(self, action, visible_objects):
 
        nodes, _ = self.get_nodes()
        avail_objects = {obj: id for id, obj in nodes.items()}
        
        if "switch" in action:
            _, _, obj = action.split()
            action = f"switchon {obj}"
        elif "put in" in action:
            _, _, obj1, obj2 = action.split()
            action = f"putin {obj1} {obj2}"
        
        action_list = action.split()
        
        if len(action_list) == 1:
            action_name = action_list[0]
            current_script = [f'<char0> [{action_name}]']
            
            if action_name in ['turnright', 'turnleft']:
                current_script.extend([f'<char0> [{action_name}]'] * 2)
            elif action_name == 'walkforward':
                current_script.append(f'<char0> [walkforward]')
        
        elif len(action_list) == 2:
            action_name, obj = action_list
            
            if obj in avail_objects:
                objid = avail_objects[obj]
                current_script = [f'<char0> [{action_name}] <{obj}> ({objid})']
                
                if action_name == 'walk':
                    current_script = [f'<char0> [walk] <{obj}> ({objid})']
                elif action_name == 'find':
                    self.on.append(objid)
                    current_script = [
                        f'<char0> [walk] <{obj}> ({objid})',
                        f'<char0> [lookat] <{obj}> ({objid})'
                    ]
                elif action_name == 'switchon':
                    current_script = []
                    if "microwave" not in visible_objects:
                        current_script.append("find microwave")
                    current_script.append([f'<char0> [switchon] <{obj}> ({objid})'])
                elif action_name == 'grab':
                    current_script = []
                    current_script.append(f'<char0> [grab] <{obj}> ({objid})')
                    
                elif action_name == 'sit':
                    current_script = [f'<char0> [sit] <{obj}> ({objid})']
                elif action_name == 'open':
                    current_script = [f'<char0> [open] <{obj}> ({objid})']
                elif action_name == 'close':
                    current_script = [f'<char0> [close] <{obj}> ({objid})']
                else:
                    current_script = ['<char0> [no_action]']
            else:
                current_script = ['<char0> [no_action]']
        
        elif len(action_list) == 3:
            action_name, obj1, obj2 = action_list
            if obj1 in avail_objects and obj2 in avail_objects:
                objid1 = avail_objects[obj1]
                objid2 = avail_objects[obj2]
                current_script = [f'<char0> [{action_name}] <{obj1}> ({objid1}) <{obj2}> ({objid2})']
            else:
                current_script = ['<char0> [no_action]']
        
        else:
            current_script = ['<char0> [no_action]']
        
        return current_script, action


    def prog_convert_action(self, action, avail_objects):
        avail_objects = {v: k for k, v in avail_objects.items()}
        action_list = action.split()
        
        current_script = ['<char0> [no_action]']
        
        if len(action_list) == 1:
            action_name = action_list[0]
            current_script = [f'<char0> [{action_name}]']
            
            if action_name in ['turnright', 'turnleft']:
                current_script.extend([f'<char0> [{action_name}]'] * 2)
            elif action_name == 'walkforward':
                current_script.append(f'<char0> [walkforward]')

        elif len(action_list) == 2:
            action_name, obj = action_list

            if obj in avail_objects.keys():
                current_script = [f'<char0> [{action_name}] <{obj}> ({avail_objects[obj]})']

            if action_name == 'walk':
                nodes, _ = self.get_nodes()
                obj_to_id = {obj: id for id, obj in nodes.items()}
                if obj in obj_to_id:
                    objid = obj_to_id[obj]
                    current_script = [f'<char0> [walk] <{obj}> ({objid})']

            elif action_name == 'find':
                nodes, _ = self.get_nodes()
                obj_to_id = {obj: id for id, obj in nodes.items()}
                if obj in obj_to_id:
                    objid = obj_to_id[obj]
                    current_script = [
                        f'<char0> [walk] <{obj}> ({objid})',
                        f'<char0> [lookat] <{obj}> ({objid})'
                    ]
                    self.on.append(objid)

            elif action_name in ['grab', 'sit', 'open', 'close']:
                current_script = []
                if obj in avail_objects.keys():
                    current_script = [f'<char0> [{action_name}] <{obj}> ({avail_objects[obj]})']
                else:
                    if action in self.prog_cond.keys():
                        return self.prog_cond[action], False


        elif len(action_list) == 3:
            action_name, obj1, obj2 = action_list

            if obj1 in avail_objects.keys() and obj2 in avail_objects.keys():
                current_script = [f'<char0> [{action_name}] <{obj1}> ({avail_objects[obj1]}) <{obj2}> ({avail_objects[obj2]})']

            if action_name == 'put':
                agent_graph = self.get_agent_graph()
                grab_dict = {}
                for agent_edge in agent_graph["edges"]:
                    if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                        item_id = agent_edge['to_id']
                        nodes, _ = self.get_nodes()
                        id_to_obj = {id: obj for id, obj in nodes.items()}
                        if item_id in id_to_obj:
                            item = id_to_obj[item_id]
                            grab_dict[item] = item_id

                if obj1 in grab_dict.keys() and obj2 in avail_objects.keys():
                    current_script = [f'<char0> [put] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({avail_objects[obj2]})']
                    
                    
                else:
                    if action in self.prog_cond.keys():
                        return self.prog_cond[action], False

            elif action_name == 'putin':
                agent_graph = self.get_agent_graph()
                grab_dict = {}
                for agent_edge in agent_graph["edges"]:
                    if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                        item_id = agent_edge['to_id']
                        nodes, _ = self.get_nodes()
                        id_to_obj = {id: obj for id, obj in nodes.items()}
                        if item_id in id_to_obj:
                            item = id_to_obj[item_id]
                            grab_dict[item] = item_id

                obj2_states = self.find_nodes(class_name=obj2)
                obj2_ids = [obj_state["states"] for obj_state in obj2_states]
                put_avail_objects = avail_objects if ["OPEN"] in obj2_ids else {k: v for k, v in avail_objects.items() if k != obj2}

                if obj1 in grab_dict.keys() and obj2 in avail_objects.keys():
                    current_script = [f'<char0> [putin] <{obj1}> ({grab_dict[obj1]}) <{obj2}> ({put_avail_objects[obj2]})']
                    
                else:
                    if action in self.prog_cond.keys():
                        return self.prog_cond[action], False

            elif action_name == 'switch':
                action = f'switch{obj1}'
                if obj2 in avail_objects:
                    current_script = [f'<char0> [{action}] <{obj2}> ({avail_objects[obj2]})']
                    
                else:
                    if action in self.prog_cond.keys():
                        return self.prog_cond[action], False

        return current_script, action
    
    
    def remove_history(self):
        self.plan = []


    def actions_available(self):
        visible_object = self.get_visible_objects()

        name_to_id = dict()
        if visible_object[0]:
            for key, item in visible_object[1].items():
                name_to_id[item] = key

        avail_action = [
            'walkforward',
            'walktowards',
            'turnleft',
            'turnright',
            'no_action',
            'walk',
            'switchon',
            'switchoff',
            'putin'
            'put',
            'open',
            'close',
            'grab',
        ]

        graph_avail_action = [
            'plugin',
            'plugout',
            'check'
        ]
        return avail_action, name_to_id

    def get_graph(self, mode='json', with_properties=False):
        if self.changed_graph:
            s, graph = self.comm.environment_graph()
            if not s:
                pdb.set_trace()
            self.graph = graph
            self.changed_graph = False

        if mode == 'json':
            return self.graph
        elif mode == 'triples':
            nodes,_ = self.get_nodes()
            object_ids = nodes.keys()

            object_ids = [int(x) for x in object_ids]
            temp_kg = self.graph

            visible_kg = {'nodes': [], 'edges': []}
            id_to_class_name = {}

            for node in temp_kg['nodes']:
                id_to_class_name[node['id']] = node['class_name']
                if node['id'] in object_ids:
                    if node['category'] in ['Walls', 'Ceiling', 'Lamps', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows']:
                        object_ids.remove(node['id'])
                    else:
                        visible_kg['nodes'].append(node['class_name'])
                        if node['states']:
                            for s in node['states']:
                                visible_kg['edges'].append((node['class_name'], "is", s))
                        if with_properties and node['properties']:
                            for p in node['properties']:
                                visible_kg['edges'].append((node['class_name'], "is", p))

            for edge in temp_kg['edges']:
                if edge['from_id'] in object_ids and edge['to_id'] in object_ids and edge['relation_type'] != "FACING":
                    visible_kg['edges'].append((id_to_class_name[edge['from_id']], edge['relation_type'], id_to_class_name[edge['to_id']]))

            return visible_kg
        else:
            raise NotImplementedError()

    def get_nodes(self, entire_nodes=True):
        objects_id_to_name = {}
        objects_name_to_id = {}
        entire_graph = self.get_graph()

        
        for node in entire_graph['nodes']:
            # print(node["class_name"])
            if entire_nodes or node['category'] not in ['Decor', 'Walls', 'Ceiling', 'Lamps', 'Floor', 'Floors', 'Doors', 'Windows']:
                objects_id_to_name[node["id"]] = node['class_name']
                if node['class_name'] not in objects_name_to_id:
                    objects_name_to_id[node['class_name']] = []
                objects_name_to_id[node['class_name']].append(node["id"]) 

        return objects_id_to_name, objects_name_to_id

    def get_agent_graph(self, mode='json'):
        temp_kg = self.get_graph()
        visible_kg = {'nodes': [], 'edges': []}
        visible_object = []
        id_to_class_name,_ = self.get_nodes()
        for edge in temp_kg['edges']:
            if edge['from_id'] == 1:
                if mode == 'json':
                    visible_kg['edges'].append(edge)
                else:
                    if edge['relation_type'] == 'HOLDS_RH' or edge['relation_type'] == 'HOLDS_LH':
                        rel = 'HOLD'
                    elif edge['relation_type'] == 'CLOSE':
                        continue
                    else:
                        rel = edge['relation_type']
                    visible_kg['edges'].append((id_to_class_name[edge['from_id']], rel, id_to_class_name[edge['to_id']]))
                visible_object.append(edge['to_id'])

        for node in temp_kg['nodes']:
            if node['id'] in visible_object:
                if node['category'] in ['Walls', 'Ceiling', 'Lamps', 'Characters', 'Floor', 'Decor', 'Floors', 'Doors', 'Windows']:
                    visible_object.remove(node['id'])
                else:
                    if mode == 'json':
                        visible_kg['nodes'].append(node)
                    else:
                        visible_kg['nodes'].append(node['class_name'])
        return visible_kg

    def get_agent_room(self):
        agent_graph = self.get_agent_graph(mode='triples')
        for edge in agent_graph['edges']:
            if edge[0] == 'character' and edge[1] == 'INSIDE':
                agentroom = edge[2]
        return agentroom

    def get_visible_objects(self, action_dict = None):
        """
        :return: Tuple[bool, Dict[str, str]]
        bool -> object existence
        dict -> {object id: object category}
        """
        
    def get_visible_objects(self, action_dict=None):
        """
        :return: Tuple[bool, Dict[str, str]]
        bool -> object existence
        dict -> {object id: object category}
        """

        # Get the items in the current room
        room_items = self.objinroom()
        
        # Return cached visible objects if they haven't changed
        if not self.changed_visible_object:
            return True, self.visible_objects_cache

        # Update the graph and visible objects dictionary
        self.graph = self.get_graph()
        vis_dict = self.get_observation(agent_id=0, obs_type='visible')

        # Get the agent's graph and nodes
        agent_graph = self.get_agent_graph()
        nodes, _ = self.get_nodes()
        id_to_obj = {id: obj for id, obj in nodes.items()}

        # Include objects held by the agent
        for agent_edge in agent_graph["edges"]:
            if agent_edge['relation_type'] in ['HOLDS_RH', 'HOLDS_LH']:
                item_id = agent_edge['to_id']
                if item_id in id_to_obj:
                    item = id_to_obj[item_id]
                    vis_dict[str(item_id)] = item

        # Get the current room and its ID
        room = self.get_agent_room()
        room_id = self.room_ids.get(room)

        # Find specific objects by their class names
        cereal_id = self.find_nodes(class_name='cereal')[-1]["id"]
        creamybuns_id = self.find_nodes(class_name='creamybuns')[-1]["id"]
        sofa_id = self.find_nodes(class_name='sofa')[-1]["id"]
        room_items[sofa_id] = self.room_ids.get("livingroom")
        
        # Update the visible objects based on the current plan
        for plan in self.plan[-2:]:
            if "find" in plan or "open" in plan:
                prev_obj = plan.split()[1]
                prev_obj_ids = [id for id, obj in id_to_obj.items() if obj == prev_obj]
                
                for i in prev_obj_ids:
                    if i in room_items.keys() and room_items[i] == room_id:
                        vis_dict[str(i)] = prev_obj

        # Include objects that are on or inside the currently focused object
        if self.on:
            on_id = self.on[-1]
            objects_on_id = [
                edge['from_id'] for edge in self.graph['edges']
                if edge['to_id'] == on_id and edge['relation_type'] in ['ON', 'INSIDE']
            ]
            
            for id, obj in id_to_obj.items():
                if id in objects_on_id:
                    vis_dict[str(id)] = obj

        if len(self.plan) > 0:

            if "coffeetable" in self.plan[-1]:
                
                bananas = self.find_nodes(class_name='bananas')
                vis_dict[str(bananas[2]['id'])] = "bananas"
                apples = self.find_nodes(class_name='apple')
                for apple in apples:
                    vis_dict[str(apple['id'])] = "apple"
                vis_dict[str(creamybuns_id)] = "creamybuns"
                vis_dict[str(cereal_id)] = "cereal"
                
            if self.get_agent_room() == "kitchen":
                if "cat" in vis_dict.values():
                    cat_key = [k for k, v in vis_dict.items() if v == "cat"]
                    del vis_dict[cat_key[0]]

        return True, vis_dict
    
    def get_observations(self):
        """Collect agent 0's observation bundle.

        :return: dict with keys 'image' (camera frames at 640x480), 'visible'
            (simulator-visible objects), 'agent_graph' (agent triples) and
            'entire_grap' (the full scene graph; the key spelling is kept for
            existing callers).
        """
        frame_size = {'image_width': 640, 'image_height': 480}
        return {
            'image': self.get_observation(agent_id=0, obs_type='image', info=dict(frame_size)),
            'visible': self.get_observation(agent_id=0, obs_type='visible', info=dict(frame_size)),
            'agent_graph': self.get_agent_graph(mode='triples'),
            'entire_grap': self.get_graph(),
        }

    def get_observation(self, agent_id, obs_type, info=None):
        """Return a single observation of the requested type for one agent.

        Args:
            agent_id: 0-based agent index. Only the ``'partial'`` branch uses it;
                the graph helper is called with ``agent_id + 1``.
            obs_type: One of ``'partial'``, ``'full'``, ``'visible'`` or ``'image'``.
            info: Optional settings for ``'image'`` observations:
                ``'mode'`` (defaults to ``'normal'``), ``'image_width'`` and
                ``'image_height'`` (default to ``self.default_image_width`` /
                ``self.default_image_height`` respectively).

        Returns:
            ``'partial'``: result of ``utils.get_visible_nodes`` on the current graph.
            ``'full'``: the graph returned by ``self.get_graph()``.
            ``'visible'``: a dict merged from the ``get_visible_objects`` response
            of the first agent camera.
            ``'image'``: a one-element list holding the first selected camera's
            image converted with ``cv2.COLOR_RGB2BGR``.

        Raises:
            RuntimeError: if the simulator reports failure while rendering images.
            IndexError: if a configured camera index is outside the camera count.
            NotImplementedError: for any other ``obs_type``.
        """
        if info is None:
            info = {}

        if obs_type == 'partial':
            # agent 0 has id (0 + 1) in the simulator graph
            curr_graph = self.get_graph()
            return utils.get_visible_nodes(curr_graph, agent_id=(agent_id + 1))

        if obs_type == 'full':
            return self.get_graph()

        if obs_type == 'visible':
            # Agent cameras are indexed after the static cameras.
            camera_ids = [self.num_static_cameras + x for x in self.CAMERA_NUM]
            visible_objects = {}
            # NOTE(review): only the first agent camera is queried; confirm whether
            # every camera in CAMERA_NUM was meant to contribute.
            for cam_id in camera_ids[:1]:
                response = self.comm.get_visible_objects(cam_id)
                if response[1]:
                    visible_objects.update(response[1])
            return visible_objects

        if obs_type == 'image':
            current_mode = info.get('mode', 'normal')

            # Agent cameras are indexed after the static cameras.
            camera_ids = [self.num_static_cameras + x for x in self.CAMERA_NUM]
            _, ncameras = self.comm.camera_count()
            # Indexing a range raises IndexError for ids outside [0, ncameras),
            # without materialising a list of every camera index.
            available = range(ncameras)
            cameras_select = [available[x] for x in camera_ids]

            # Defaults are read lazily so they are only required when not supplied.
            image_width = (info['image_width'] if 'image_width' in info
                           else self.default_image_width)
            image_height = (info['image_height'] if 'image_height' in info
                            else self.default_image_height)

            success, images = self.comm.camera_image(
                cameras_select, mode=current_mode,
                image_width=image_width, image_height=image_height)
            # Check success before touching the payload: on failure it may not
            # contain images at all.
            if not success:
                raise RuntimeError(
                    f'camera_image failed for cameras {cameras_select} '
                    f'(mode={current_mode!r})')

            # Only the first selected camera's frame is returned.
            return [cv2.cvtColor(images[0], cv2.COLOR_RGB2BGR)]

        raise NotImplementedError(obs_type)

    def goalCond(self, goals):
        """Return the fraction of ``goals`` currently satisfied in the scene graph.

        Each goal is a dict of one of two shapes:

        * state goal: ``{'class_name': str, 'state': iterable of states}``.
          The states of the first node returned by
          ``self.find_nodes(class_name=...)`` are checked, and every listed state
          present there adds 1 to the count. A class with no nodes adds nothing.
        * relation goal: ``{'from_id': name, 'relation_type': str or iterable of
          str, 'to_id': name}``. ``from_id``/``to_id`` are object *names* (values
          of the mapping from ``self.get_nodes()``). The goal adds 1 if any edge
          links a node with the first name to one with the second name via an
          accepted relation. A string relation must match exactly.

        Goals with neither ``'state'`` nor ``'relation_type'`` add nothing.

        Args:
            goals: sequence of goal dicts as described above.

        Returns:
            The number of satisfied conditions divided by ``len(goals)``, or 0.0
            when ``goals`` is empty.
        """
        if not goals:
            # Avoid ZeroDivisionError; no goals means nothing to report as met.
            return 0.0

        satisfied = 0
        entire_graph = self.get_graph()
        nodes, _ = self.get_nodes()

        for goal in goals:
            if 'state' in goal:
                goal_states = goal['state']
                class_info = self.find_nodes(class_name=goal['class_name'])
                if not class_info:
                    # No object of this class in the scene: the goal cannot hold.
                    continue
                states = class_info[0]['states']
                # NOTE(review): each matched state counts separately, so a goal
                # listing several states can push the ratio above 1 — confirm intent.
                satisfied += sum(1 for state in goal_states if state in states)

            elif 'relation_type' in goal:
                relation_type = goal['relation_type']
                # Accept a single relation name (exact match) or a collection of names.
                if isinstance(relation_type, str):
                    accepted = {relation_type}
                else:
                    accepted = set(relation_type)
                from_ids = {node_id for node_id, name in nodes.items()
                            if name == goal['from_id']}
                to_ids = {node_id for node_id, name in nodes.items()
                          if name == goal['to_id']}
                if any(edge['from_id'] in from_ids
                       and edge['relation_type'] in accepted
                       and edge['to_id'] in to_ids
                       for edge in entire_graph['edges']):
                    satisfied += 1

        return satisfied / len(goals)

