from collections import defaultdict
from random import choice
from . import utils


def terminal(is_terminal):
    """
    Decorator factory that tags a handler function with a ``terminal`` flag.

    The decorated function is returned unchanged apart from a new ``terminal``
    attribute set to ``is_terminal``. That attribute records whether the handler
    may be terminal or not.

    :param is_terminal: value stored on the function as ``func.terminal``.
    :returns: a decorator that sets the flag and returns the same function object.
    """
    def mark(func):
        setattr(func, 'terminal', is_terminal)
        return func

    return mark


class Handler:
    """
    Interpreter for program trees evaluated against a scene.

    Each node of a program tree names one of this class's methods. ``get_answer``
    resolves the node's arguments and dispatches to that method by name.

    A scene is a dict whose ``'objects'`` entry maps object ids to dicts with
    ``'name'``, ``'attributes'`` and ``'relations'`` keys. Each relation is a dict
    with a ``'name'`` and the id of the related ``'object'``.
    """

    def __init__(self):
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._internal_dict = dict()

    @staticmethod
    def _as_attr_set(attributes):
        """Return ``attributes`` as a set; ``None`` or empty input yields the empty set."""
        # NOTE(review): a bare string would be split into characters here; callers
        # presumably pass a list/tuple/set of attribute names — confirm.
        return set(attributes) if attributes else set()

    @staticmethod
    def _matches(obj, obj_name, attributes):
        """
        Return True if ``obj_name`` is a hypernym of ``obj['name']`` and every
        attribute in the set ``attributes`` appears in ``obj['attributes']``.
        """
        return (obj_name in utils.get_hypernyms(obj['name'])
                and attributes.issubset(obj['attributes']))

    def _iter_rel_matches(self, scene, obj1_name, attributes1, relation, obj2_name, attributes2):
        """
        Yield the id of each object matching ``(obj1_name, attributes1)``.

        An id is yielded once per outgoing ``relation`` edge whose target
        matches ``(obj2_name, attributes2)``. ``None`` attributes mean
        "no attribute constraint".
        """
        objects = scene['objects']
        attributes1 = self._as_attr_set(attributes1)
        attributes2 = self._as_attr_set(attributes2)
        for obj_id, obj in objects.items():
            if not self._matches(obj, obj1_name, attributes1):
                continue
            for rel in obj['relations']:
                if rel['name'] == relation and self._matches(objects[rel['object']], obj2_name, attributes2):
                    yield obj_id

    def get_answer(self, scene, program_tree, token_assignment, verbose=False):
        """
        Evaluate a program tree recursively and return the result of its root operation.

        ``program_tree`` is a tuple/list ``(func_name, arg1, arg2, ...)``. Each
        argument is resolved as follows:

        - a nested tuple/list is evaluated recursively;
        - the literal string ``'scene'`` becomes ``scene``;
        - any other string (or ``None``) is looked up in ``token_assignment``,
          and becomes ``None`` when absent.

        The method of this handler named ``func_name`` is then called with the
        resolved arguments.

        :param verbose: when True, print every call (nested ones included) before it is made.
        :raises AssertionError: if a tree node or argument has an unsupported type.
        """
        assert isinstance(program_tree, (tuple, list)), TypeError(
            f'Received type {type(program_tree)} instead of tuple.')

        root, args = program_tree[0], program_tree[1:]
        program_args = []
        for arg in args:
            if isinstance(arg, (tuple, list)):
                # Propagate `verbose` so nested evaluations are traced too.
                val = self.get_answer(scene, arg, token_assignment, verbose=verbose)
            elif arg == 'scene':
                val = scene
            else:
                assert isinstance(arg, str) or arg is None, TypeError(
                    f'Argument type not tuple or string. Received: {type(arg)}')
                # Unassigned tokens evaluate to None rather than raising.
                val = token_assignment[arg] if arg in token_assignment else None
            program_args.append(val)
        if verbose:
            print(f'Evaluating {root}(*{program_args}).')

        # NOTE(review): dispatch is by attribute name, so `root` can reach any
        # attribute of the handler; program trees are presumably trusted input.
        return getattr(self, root)(*program_args)

    def get_category_attrs(self, scene, category, obj_id):
        """Return the attributes of object ``obj_id`` whose categories include ``category``."""
        return [attr for attr in scene['objects'][obj_id]['attributes']
                if category in utils.get_attr_categories(attr)]

    def get_obj_name(self, scene, obj_id):
        """Return the name of object ``obj_id``."""
        return scene['objects'][obj_id]['name']

    def select(self, sequence):
        """Return a random element of ``sequence``, or ``None`` if it is empty."""
        if len(sequence) == 0:
            return None
        return choice(sequence)

    @terminal(True)
    def equal(self, val1, val2):
        """Return ``val1 == val2``."""
        return val1 == val2

    @terminal(True)
    def member_of(self, element, in_set):
        """Return whether ``element`` is contained in ``in_set``."""
        return element in in_set

    @terminal(True)
    def exists(self, scene, obj_name, attributes=None):
        """
        Return True if some object in the scene has ``obj_name`` among its hypernyms
        and carries all of ``attributes`` (``None`` means no attribute constraint).
        """
        attributes = self._as_attr_set(attributes)
        return any(self._matches(obj, obj_name, attributes) for obj in scene['objects'].values())

    @terminal(True)
    def logical_and(self, bool1, bool2):
        """Return ``bool1 and bool2``; both arguments must be ``bool``."""
        assert isinstance(bool1, bool) and isinstance(bool2, bool), AssertionError(
            'Arguments must be boolean values.')
        return bool1 and bool2

    @terminal(True)
    def logical_or(self, bool1, bool2):
        """Return ``bool1 or bool2``; both arguments must be ``bool``."""
        assert isinstance(bool1, bool) and isinstance(bool2, bool), AssertionError(
            'Arguments must be boolean values.')
        return bool1 or bool2

    @terminal(True)
    def logical_not(self, val):
        """Return ``not val``; ``val`` must be ``bool``."""
        assert isinstance(val, bool), AssertionError('Argument must be boolean value.')
        return not val

    @terminal(True)
    def rel_exists(self, scene, obj1_name, attributes1, relation, obj2_name, attributes2):
        """
        Return True if some object matching ``(obj1_name, attributes1)`` has a
        ``relation`` edge to an object matching ``(obj2_name, attributes2)``.

        ``None`` attributes mean "no attribute constraint", as in ``filter_rel``.
        """
        # any() stops at the first match, so the search short-circuits.
        return any(True for _ in self._iter_rel_matches(
            scene, obj1_name, attributes1, relation, obj2_name, attributes2))

    @terminal(True)
    def convert_bool(self, val):
        """Map ``True`` to ``'yes'`` and ``False`` to ``'no'`` (other values raise KeyError)."""
        return {True: 'yes', False: 'no'}[val]

    @terminal(True)
    def count(self, scene, obj_name, attributes):
        """
        Return the number (int) of objects named ``obj_name`` (by hypernym)
        that carry all of ``attributes`` (``None`` means no attribute constraint).
        """
        attributes = self._as_attr_set(attributes)
        return sum(1 for obj in scene['objects'].values() if self._matches(obj, obj_name, attributes))

    @terminal(True)
    def color(self, scene, obj_name, attributes):
        """Return the object's single color, or ``None`` if it has zero or several colors."""
        # NOTE(review): assumes utils.get_color returns a collection of color names,
        # not a single string — confirm.
        colors = list(self.colors(scene, obj_name, attributes))
        return colors[0] if len(colors) == 1 else None

    @terminal(False)
    def colors(self, scene, obj_name, attributes):
        """
        Return all colors of the object identified by ``(obj_name, attributes)``.

        Returns an empty frozenset when the pair does not identify exactly one object.
        """
        attr_map = utils.get_attribute_map(scene['objects'])
        # NOTE(review): `attributes` is used as part of a dict key, so it presumably
        # must be hashable (e.g. a frozenset) to match the map's keys — confirm.
        objs = attr_map[(obj_name, attributes)]
        if len(objs) != 1:
            return frozenset()
        obj_id = next(iter(objs))
        return utils.get_color(scene['objects'][obj_id])

    @terminal(False)
    def filter_rel(self, scene, obj1_name, attributes1, relation, obj2_name, attributes2):
        """
        Return ids of objects matching ``(obj1_name, attributes1)`` that have a
        ``relation`` edge to an object matching ``(obj2_name, attributes2)``.

        An id appears once per matching relation, so duplicates are possible.
        ``None`` attributes mean "no attribute constraint".
        """
        # NOTE(review): duplicates are kept for compatibility; confirm whether
        # callers (e.g. `select`) expect a duplicate-free list.
        return list(self._iter_rel_matches(
            scene, obj1_name, attributes1, relation, obj2_name, attributes2))

    @terminal(False)
    def filter(self, scene, obj_name, attributes):
        """
        Return ids of objects named ``obj_name`` (by hypernym) that carry all of
        ``attributes`` (``None`` means no attribute constraint).
        """
        attributes = self._as_attr_set(attributes)
        return [obj_id for obj_id, obj in scene['objects'].items()
                if self._matches(obj, obj_name, attributes)]

    @terminal(False)
    def related_objs(self, scene, obj_name, rel_name):
        """
        Return the objects that the unique ``(obj_name, rel_name)`` object relates to.

        The result is a scene-like dict ``{'objects': {id: obj, ...}}`` containing
        the targets of that object's ``rel_name`` relations.

        :raises ValueError: if ``(obj_name, rel_name)`` does not identify exactly one object.
        """
        rel_map = utils.get_relations_map(scene['objects'])
        objs = rel_map[(obj_name, frozenset([rel_name]))]
        if len(objs) != 1:
            raise ValueError('FIXME: Constraints enforced for related_objs.')
        obj_id = next(iter(objs))
        rel_objs_ids = [rel['object'] for rel in scene['objects'][obj_id]['relations']
                        if rel['name'] == rel_name]
        return {'objects': {idx: scene['objects'][idx] for idx in rel_objs_ids}}
