''' Batched Room-to-Room navigation environment '''

import sys
sys.path.append('buildpy36')
import MatterSim
import csv
import numpy as np
import math
import base64
import utils
import json
import os
import random
import networkx as nx
from param import args

from utils import load_datasets, load_nav_graphs, Tokenizer, load_text_features, load_datasets_visual

# Raise the csv module's per-field size cap to sys.maxsize so very large fields parse.
# NOTE(review): no csv reader is used in the visible code; presumably this supports
# feature-file loading elsewhere -- confirm before removing.
csv.field_size_limit(sys.maxsize)


class EnvBatch():
    ''' A simple wrapper for a batch of MatterSim environments,
        using discretized viewpoints and pretrained features '''

    def __init__(self, feature_store=None, batch_size=100):
        """
        1. Register the pretrained image features (if any).
        2. Create and initialise ``batch_size`` simulators.

        :param feature_store: dict mapping "<scanId>_<viewpointId>" to that
            viewpoint's feature array, or a falsy value (None / empty dict)
            to run without image features.
        :param batch_size: number of simulators to create.
        :raises TypeError: if ``feature_store`` is truthy but not a dict.

        Without features, ``self.features`` is None, ``self.featurized_scans``
        is an empty set and ``self.feature_size`` is left undefined.
        """
        # Camera settings do not depend on whether features were supplied.
        self.image_w = 640
        self.image_h = 480
        self.vfov = 60

        if feature_store:
            if not isinstance(feature_store, dict):
                # Previously this path left self.features unset and failed
                # later with an opaque AttributeError.
                raise TypeError(
                    'feature_store must be a dict of preloaded features, got %s'
                    % type(feature_store).__name__)
            self.features = feature_store
            self.feature_size = next(iter(self.features.values())).shape[-1]
            print('The feature size is %d' % self.feature_size)
            # Scan id is the part of the key before the first underscore.
            self.featurized_scans = set(key.split("_")[0] for key in self.features)
        else:
            print('Image features not provided')
            self.features = None
            # No features means no featurized scans; the old code called
            # None.keys() here and crashed.
            self.featurized_scans = set()

        self.sims = []
        for _ in range(batch_size):
            sim = MatterSim.Simulator()
            sim.setRenderingEnabled(False)
            sim.setDiscretizedViewingAngles(True)   # Set increment/decrement to 30 degree. (otherwise by radians)
            sim.setCameraResolution(self.image_w, self.image_h)
            sim.setCameraVFOV(math.radians(self.vfov))
            sim.init()
            self.sims.append(sim)

    def _make_id(self, scanId, viewpointId):
        ''' Build the feature-store key "<scanId>_<viewpointId>". '''
        return scanId + '_' + viewpointId

    def newEpisodes(self, scanIds, viewpointIds, headings):
        ''' Start a new episode (elevation 0) on the i-th simulator for the i-th
            (scanId, viewpointId, heading) triple. '''
        for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
            self.sims[i].newEpisode(scanId, viewpointId, heading, 0)

    def getStates(self):
        """
        Get list of states augmented with precomputed image features. rgb field will be empty.
        Agent's current view [0-35] (set only when viewing angles are discretized)
            [0-11] looking down, [12-23] looking at horizon, [24-35] looking up
        :return: one (feature, sim_state) tuple per simulator; feature is None
                 when no features were loaded.
        """
        feature_states = []
        for sim in self.sims:
            state = sim.getState()
            if self.features:
                long_id = self._make_id(state.scanId, state.location.viewpointId)
                feature_states.append((self.features[long_id], state))
            else:
                feature_states.append((None, state))
        return feature_states

    def makeActions(self, actions):
        ''' Take an action using the full state dependent action interface (with batched input).
            Every action element should be an (index, heading, elevation) tuple. '''
        for i, (index, heading, elevation) in enumerate(actions):
            self.sims[i].makeAction(index, heading, elevation)

    def getFeatures(self, id):
        ''' Return the stored feature for the key "<scanId>_<viewpointId>". '''
        return self.features[id]

class R2RBatch():
    ''' Implements the Room to Room navigation task, using discretized viewpoints and pretrained features '''

    def __init__(self, feature_store, batch_size=100, seed=10, splits=['train'], tokenizer=None,
                 name=None, objects=None):
        """
        Create the simulator batch, load the requested splits and tokenize every instruction.

        :param feature_store: forwarded to EnvBatch; when truthy, self.feature_size is copied from it.
        :param batch_size: number of simulators / default minibatch size.
        :param seed: seed passed to random.seed() (global RNG) before the data is shuffled.
        :param splits: split names, each loaded via load_datasets / load_datasets_visual.
        :param tokenizer: callable invoked as tok(text, padding=..., truncation=..., max_length=...)
            whose result is indexed with 'input_ids' and 'attention_mask'
            (presumably a HuggingFace tokenizer -- confirm against caller).
        :param name: environment name; defaults to the first split, or "FAKE" if splits is empty.
        :param objects: indexed as objects[scan][viewpoint] when args.objects_constraints is set.
        """
        self.env = EnvBatch(feature_store=feature_store, batch_size=batch_size)
        if feature_store:
            self.feature_size = self.env.feature_size
        self.objects = objects
        self.data = []
        # NOTE: self.tok is left undefined when no tokenizer is passed, yet the loading
        # loops below call self.tok(...) unconditionally.
        if tokenizer:
            self.tok = tokenizer
        scans = []
        for split in splits:
            # The 'visual' training mode uses a dedicated loader.
            if args.train == 'visual':
                load_func = load_datasets_visual
            else:
                load_func = load_datasets
            for item in load_func([split]):
                # Split multiple instructions into separate entries
                if item['scan'] not in self.env.featurized_scans:  # For fast training
                    continue
                if args.dataset == 'RxR':
                    # RxR: one instruction per item; copy it and normalise the field names.
                    new_item = dict(item)
                    new_item['instr_id'] = item['instruction_id']
                    new_item['instructions'] = item['instruction']
                    # Keep only the part of the language tag before the first '-'.
                    new_item['language'] = item['language'].split('-')[0]
                    # Unless running multilingual, drop items in other languages.
                    if args.language == 'multi':
                        pass
                    elif args.language != new_item['language']:
                        continue
                    if args.truncate == 'first':
                        # Tokenize without truncation; over-long sequences keep the first
                        # token plus the LAST maxInput-1 tokens (the beginning is dropped).
                        encoding_ = self.tok(item['instruction'], padding='max_length', truncation=False,
                                             max_length=args.maxInput)
                        if len(encoding_['input_ids']) > args.maxInput:
                            input_ids = [encoding_['input_ids'][0]]
                            input_ids.extend(encoding_['input_ids'][-args.maxInput+1:])
                            assert len(input_ids) == args.maxInput
                            seq_mask = [encoding_['attention_mask'][0]]
                            seq_mask.extend(encoding_['attention_mask'][-args.maxInput+1:])
                            assert len(seq_mask) == args.maxInput
                            new_item['instr_encoding'] = input_ids
                            new_item['seq_mask'] = seq_mask
                            new_item['seq_length'] = args.maxInput
                        else:
                            # encoding = self.tok(item['instruction'], padding='max_length', truncation=True, max_length=args.maxInput)
                            encoding = encoding_
                            new_item['instr_encoding'] = encoding['input_ids']
                            new_item['seq_mask'] = encoding['attention_mask']
                            new_item['seq_length'] = sum(encoding['attention_mask'])
                    elif args.truncate == 'last':
                        # Let the tokenizer truncate to maxInput (padded to max_length).
                        encoding = self.tok(item['instruction'], padding='max_length', truncation=True, max_length=args.maxInput)
                        new_item['instr_encoding'] = encoding['input_ids']
                        new_item['seq_mask'] = encoding['attention_mask']
                        new_item['seq_length'] = sum(encoding['attention_mask'])
                    elif args.truncate == 'none':
                        # No padding; truncation only at a fixed limit of 512.
                        encoding = self.tok(item['instruction'], truncation=True, max_length=512)
                        new_item['instr_encoding'] = encoding['input_ids']
                        new_item['seq_mask'] = encoding['attention_mask']
                        new_item['seq_length'] = sum(encoding['attention_mask'])
                    # NOTE: any other args.truncate value leaves the item without an encoding.

                    # This checks the raw item (not new_item, which always has 'instructions').
                    if 'instructions' in item:
                        if args.con_lang == 'multi':
                            new_item['multi_instructions'] = item['instructions']
                        else:
                            # Keep only the paired instructions in the same language as this item.
                            new_item['multi_instructions'] = []
                            for l, ins in enumerate(item['instructions']):
                                if item['languages'][l].split('-')[0] == new_item['language']:
                                    new_item['multi_instructions'].append(ins)
                            # Skip the item entirely when no same-language pair exists.
                            if len(new_item['multi_instructions']) == 0:
                                continue

                    if args.train == 'visual':
                        if item['pair_path'] == item['path']:       # ignore data
                            continue
                        new_item['pair_scan'] = item['pair_scan']
                        new_item['pair_path'] = item['pair_path']
                        assert len(item['pair_path']) == len(item['path'])
                        # Pre-fetch the stored features along the paired path ...
                        pair_feature = []
                        for view in item['pair_path']:
                            long_id = self.env._make_id(item['pair_scan'], view)
                            pair_feature.append(self.env.getFeatures(long_id))
                        new_item['pair_feature'] = pair_feature
                        # ... and along the item's own path.
                        feature = []
                        for view in item['path']:
                            long_id = self.env._make_id(item['scan'], view)
                            feature.append(self.env.getFeatures(long_id))
                        new_item['feature'] = feature

                        if args.objects_constraints:
                            # Per-viewpoint object entries for both paths, looked up in self.objects.
                            pair_objects = []
                            for view in item['pair_path']:
                                # print(item['pair_scan'])
                                # print(view)
                                # print(self.objects[item['pair_scan']].keys())
                                # print(self.objects[item['pair_scan']])
                                pair_objects.append(self.objects[item['pair_scan']][view])
                            new_item['pair_objects'] = pair_objects
                            objects = []
                            for view in item['path']:
                                objects.append(self.objects[item['scan']][view])
                            new_item['objects'] = objects

                    # Default path id when the dataset does not provide one.
                    if 'path_id' not in item:
                        new_item['path_id'] = 0

                    self.data.append(new_item)
                    scans.append(item['scan'])

                else:
                    # Non-RxR data: emit one entry per instruction of the path, tagged as English.
                    for j, instr in enumerate(item['instructions']):
                        new_item = dict(item)
                        new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                        new_item['instructions'] = instr
                        new_item['language'] = 'en'
                        encoding = self.tok(instr, padding='max_length', truncation=True,
                                            max_length=args.maxInput)
                        new_item['instr_encoding'] = encoding['input_ids']
                        new_item['seq_mask'] = encoding['attention_mask']
                        new_item['seq_length'] = sum(encoding['attention_mask'])
                        # All instructions of the same path serve as the paired set.
                        new_item['multi_instructions'] = item['instructions']


                        self.data.append(new_item)
                        scans.append(item['scan'])
                # else:
                #     new_item['instr_encoding'], new_item['instr_features'], new_item['seq_length'] = load_text_features(split, new_item['instr_id'], new_item['language'])

        if name is None:
            self.name = splits[0] if len(splits) > 0 else "FAKE"
        else:
            self.name = name

        self.scans = set(scans)
        self.splits = splits
        self.seed = seed
        # Seeds the module-level RNG (affects every other user of `random`), then shuffles in place.
        random.seed(self.seed)
        random.shuffle(self.data)

        self.ix = 0     # index of the next item to serve from self.data
        self.batch_size = batch_size
        self._load_nav_graphs()

        self.angle_feature = utils.get_all_point_angle_feature()
        # Separate simulator used by make_candidate to sweep the views of a viewpoint.
        self.sim = utils.new_simulator()
        # Cache used by make_candidate, keyed by "<scanId>_<viewpointId>".
        self.buffered_state_dict = {}

        # It means that the fake data is equals to data in the supervised setup
        self.fake_data = self.data
        print('R2RBatch loaded with %d instructions, using splits: %s' % (len(self.data), ",".join(splits)))

    def size(self):
        ''' Number of instruction entries currently loaded. '''
        n_entries = len(self.data)
        return n_entries

    def _load_nav_graphs(self):
        """
        Load the connectivity graph of every scan in self.scans and precompute
        all-pairs shortest paths, used for shortest-path reasoning.

        Populates:
          self.graphs:    {scan_id: graph}
          self.paths:     {scan_id: {view_id_x: {view_id_y: [path]}}}
          self.distances: {scan_id: {view_id_x: {view_id_y: distance}}}
        :return: None
        """
        print('Loading navigation graphs for %d scans' % len(self.scans))
        self.graphs = load_nav_graphs(self.scans)
        # Dijkstra node sequences between every pair of viewpoints, per scan.
        self.paths = {
            scan_id: dict(nx.all_pairs_dijkstra_path(graph))
            for scan_id, graph in self.graphs.items()
        }
        # Matching path lengths, same nesting as self.paths.
        self.distances = {
            scan_id: dict(nx.all_pairs_dijkstra_path_length(graph))
            for scan_id, graph in self.graphs.items()
        }

    def _next_minibatch(self, tile_one=False, batch_size=None, **kwargs):
        """
        Select the next minibatch and store it in 'self.batch'.

        :param tile_one: if True, repeat the single item at self.ix batch_size times.
        :param batch_size: size of the batch; defaults to self.batch_size.
        :param kwargs: accepted and ignored (lets callers forward extra options).
        :return: None

        When the data runs out it is reshuffled in place (global `random` state)
        and the batch is topped up from the start. The wrap-around repeats until
        the batch is full, so a dataset smaller than batch_size still yields a
        full batch (with repeated items) instead of a short one.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if tile_one:
            batch = [self.data[self.ix]] * batch_size
            self.ix += 1
            if self.ix >= len(self.data):
                random.shuffle(self.data)
                self.ix -= len(self.data)
        else:
            batch = self.data[self.ix: self.ix + batch_size]
            if len(batch) < batch_size:
                # Epoch boundary: reshuffle and take the missing items from the
                # front; loop because one pass may not be enough for tiny datasets.
                while len(batch) < batch_size:
                    random.shuffle(self.data)
                    self.ix = min(batch_size - len(batch), len(self.data))
                    batch += self.data[:self.ix]
                    if not self.data:
                        # Nothing to draw from; return the (empty) batch rather than spin.
                        break
            else:
                self.ix += batch_size
        self.batch = batch

    def reset_epoch(self, shuffle=False):
        ''' Rewind the data cursor to the start of the epoch, optionally reshuffling
            the data in place first. Primarily for testing.
            A new episode still requires a call to reset(). '''
        self.ix = 0
        if not shuffle:
            return
        random.shuffle(self.data)

    def _shortest_path_action(self, state, goalViewpointId):
        ''' Return the viewpoint to move to next along the precomputed shortest path
            to the goal (supervised training target); the goal itself means "stop". '''
        here = state.location.viewpointId
        if here != goalViewpointId:
            # Element 0 of the stored path is the current viewpoint; element 1 is the next hop.
            return self.paths[state.scanId][here][goalViewpointId][1]
        return goalViewpointId      # Already at the goal

    def make_candidate(self, feature, scanId, viewpointId, viewId):
        """
        Build the navigable-candidate list for one viewpoint.

        On the first request for a (scanId, viewpointId) pair, self.sim is swept
        through all 36 discretized views; each neighbouring viewpoint is
        represented by the view in which it has the smallest angular offset.
        The heading-independent fields are cached in self.buffered_state_dict,
        and later requests rebuild the candidates from the cache.

        :param feature: per-view visual features, indexed by view index.
        :param scanId: scan of the viewpoint.
        :param viewpointId: viewpoint to expand.
        :param viewId: agent's current view index; headings are made relative to it.
        :return: list of candidate dicts (freshly computed entries also carry
                 'distance' and 'normalized_heading'; cached rebuilds do not).
        """
        def _angular_offset(nav_loc):
            return np.sqrt(nav_loc.rel_heading ** 2 + nav_loc.rel_elevation ** 2)

        agent_heading = (viewId % 12) * math.radians(30)
        cache_key = "%s_%s" % (scanId, viewpointId)

        if cache_key in self.buffered_state_dict:
            # Cached: only the heading-relative parts have to be recomputed.
            rebuilt = []
            for cached in self.buffered_state_dict[cache_key]:
                entry = cached.copy()
                view_ix = entry['pointId']
                absolute_heading = entry['normalized_heading']
                view_feat = feature[view_ix]
                entry['heading'] = absolute_heading - agent_heading
                angle_feat = utils.angle_feature(entry['heading'], entry['elevation'])
                entry['feature'] = np.concatenate((view_feat, angle_feat), -1)
                entry.pop('normalized_heading')
                rebuilt.append(entry)
            return rebuilt

        # Not cached: visit the 36 views in index order.
        nearest = {}
        for view_ix in range(36):
            if view_ix == 0:
                # Start at heading 0 with the camera tilted 30 degrees down.
                self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30))
            elif view_ix % 12 == 0:
                # Completed a full turn: step heading and raise the elevation.
                self.sim.makeAction(0, 1.0, 1.0)
            else:
                self.sim.makeAction(0, 1.0, 0)

            state = self.sim.getState()
            assert state.viewIndex == view_ix

            # Heading and elevation of the centre of this view.
            view_heading = state.heading - agent_heading
            view_elevation = state.elevation

            view_feat = feature[view_ix]

            # Entry 0 of navigableLocations is skipped; the rest are neighbours.
            for nav_idx, loc in enumerate(state.navigableLocations[1:], start=1):
                offset = _angular_offset(loc)

                # Heading and elevation of the neighbour itself.
                loc_heading = view_heading + loc.rel_heading
                loc_elevation = view_elevation + loc.rel_elevation
                angle_feat = utils.angle_feature(loc_heading, loc_elevation)

                # A neighbour seen from several views keeps the angularly closest one.
                known = nearest.get(loc.viewpointId)
                if known is not None and not offset < known['distance']:
                    continue
                nearest[loc.viewpointId] = {
                    'heading': loc_heading,
                    'elevation': loc_elevation,
                    "normalized_heading": state.heading + loc.rel_heading,
                    'scanId': scanId,
                    'viewpointId': loc.viewpointId,  # Next viewpoint id
                    'pointId': view_ix,
                    'distance': offset,
                    'idx': nav_idx,
                    'feature': np.concatenate((view_feat, angle_feat), -1)
                }

        candidate = list(nearest.values())
        cached_keys = ['normalized_heading', 'elevation', 'scanId', 'viewpointId',
                       'pointId', 'idx']
        self.buffered_state_dict[cache_key] = [
            {key: cand[key] for key in cached_keys}
            for cand in candidate
        ]
        return candidate

    def _get_obs(self):
        """
        Build one observation dict per simulator from the current simulator
        states and the matching entries of self.batch.

        Every observation carries the instruction id, scan, viewpoint, view index,
        heading, elevation, instruction text, language and 'distance' (shortest-path
        distance from the current viewpoint to the end of the item's path).
        Depending on args.train:
          * 'mbert':  nothing else.
          * 'visual': precomputed path features ('feature', 'pair_feature'),
                      candidates, navigable locations, teacher action and path;
                      with args.objects_constraints also 'objects'/'pair_objects'.
          * other:    per-view features concatenated with the angle features,
                      plus candidates, navigable locations, teacher action and path.
        Tokenization fields present on the item are copied through, and with
        args.contrastive a randomly chosen paired instruction is tokenized and attached.

        :return: list of observation dicts, one per simulator.
        """
        obs = []
        for i, (feature, state) in enumerate(self.env.getStates()):
            item = self.batch[i]
            base_view_id = state.viewIndex

            # Fields shared by every training mode.
            ob = {
                'instr_id': item['instr_id'],
                'scan': state.scanId,
                'viewpoint': state.location.viewpointId,
                'viewIndex': state.viewIndex,
                'heading': state.heading,
                'elevation': state.elevation,
            }

            if args.train == 'mbert':
                ob['instructions'] = item['instructions']
                ob['language'] = item['language']
            else:
                # Full features
                candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)

                if args.train == 'visual':
                    # Path-level features were prefetched when the item was loaded.
                    ob['feature'] = item['feature']
                    ob['pair_feature'] = item['pair_feature']
                else:
                    # (visual_feature, angle_feature) for views
                    ob['feature'] = np.concatenate((feature, self.angle_feature[base_view_id]), -1)

                ob['candidate'] = candidate
                ob['navigableLocations'] = state.navigableLocations
                ob['instructions'] = item['instructions']
                ob['teacher'] = self._shortest_path_action(state, item['path'][-1])
                ob['language'] = item['language']
                ob['path'] = item['path']

                if args.train == 'visual' and args.objects_constraints:
                    if args.random_objects:
                        # Pick the same 10 object slots for every viewpoint of both paths.
                        indexes = list(range(len(item['objects'][0])))
                        samples = np.random.choice(indexes, 10, replace=False)
                        objects = []
                        pair_objects = []
                        for j in range(len(item['objects'])):
                            objects.append([item['objects'][j][k] for k in samples])
                            pair_objects.append([item['pair_objects'][j][k] for k in samples])
                        # Store the SAMPLED lists; the previous code built them and
                        # then stored the full lists, making random_objects a no-op.
                        ob['objects'] = objects
                        ob['pair_objects'] = pair_objects
                    else:
                        ob['objects'] = item['objects']
                        ob['pair_objects'] = item['pair_objects']

            obs.append(ob)

            # Copy through whichever optional per-item fields are available.
            for key in ('instr_encoding', 'instr_features', 'seq_length', 'seq_mask', 'path_id'):
                if key in item:
                    ob[key] = item[key]

            # A2C reward. The negative distance between the state and the final state
            ob['distance'] = self.distances[state.scanId][state.location.viewpointId][item['path'][-1]]

            # paired instruction for contrastive loss
            if args.contrastive:
                index = random.randint(0, len(item['multi_instructions']) - 1)
                ins = item['multi_instructions'][index]
                ob['paired_ins'] = ins
                encoding = self.tok(ins, padding='max_length', truncation=True, max_length=args.maxInput)
                ob['paired_encoding'] = encoding['input_ids']
                ob['paired_mask'] = encoding['attention_mask']
                ob['paired_length'] = sum(encoding['attention_mask'])

        return obs

    def reset(self, batch=None, inject=False, **kwargs):
        ''' Start new episodes for a minibatch and return the first observations.

            batch=None    -> draw the next minibatch (kwargs go to _next_minibatch).
            inject=True   -> draw the next minibatch, then overwrite its first
                             len(batch) entries with the given batch.
            inject=False  -> use the given batch as-is. '''
        if batch is not None and not inject:
            self.batch = batch
        else:
            self._next_minibatch(**kwargs)
            if batch is not None:
                self.batch[:len(batch)] = batch

        # Every episode starts at the first viewpoint of its path with the item's heading.
        episodes = self.batch
        scanIds = [entry['scan'] for entry in episodes]
        viewpointIds = [entry['path'][0] for entry in episodes]
        headings = [entry['heading'] for entry in episodes]
        self.env.newEpisodes(scanIds, viewpointIds, headings)
        return self._get_obs()

    def step(self, actions):
        ''' Apply one batched action (same interface as EnvBatch.makeActions)
            and return the resulting observations. '''
        simulators = self.env
        simulators.makeActions(actions)
        observations = self._get_obs()
        return observations

    def get_statistics(self):
        ''' Return dataset averages: 'length' is the mean number of tokens per
            instruction (per self.tok.split_sentence), 'path' the mean shortest-path
            distance between the first and last viewpoint of each path. '''
        total_tokens = sum(
            len(self.tok.split_sentence(entry['instructions']))
            for entry in self.data
        )
        total_distance = sum(
            self.distances[entry['scan']][entry['path'][0]][entry['path'][-1]]
            for entry in self.data
        )
        count = len(self.data)
        return {
            'length': total_tokens / count,
            'path': total_distance / count,
        }


