import re
from . import utils


def terminal(is_terminal):
    """
    Build a decorator that tags a handler function with ``terminal = is_terminal``.

    The decorated function object itself is returned (no wrapper is created),
    so it is callable exactly as before and only gains the new attribute.
    """
    def mark(handler):
        setattr(handler, 'terminal', is_terminal)
        return handler
    return mark


class ConstraintHandler:
    """
    Evaluate textual constraints such as ``min_width(o, 20)`` against a scene.

    ``objects`` maps object ids to dicts. The methods below read the keys
    ``'w'``, ``'h'``, ``'name'``, ``'attributes'`` and ``'relations'``; each
    relation is read as a dict with ``'name'`` and ``'object'`` keys, where
    ``'object'`` is another id in ``objects``. Any method can be named in a
    constraint string passed to :meth:`check_constraints`, and its result is
    interpreted as a boolean.
    """

    # "func_name(arg1, arg2, ...)": the name must start with a letter and be
    # at least two characters long.
    _CONSTRAINT_RE = re.compile(r'^([a-zA-Z]\w+)\((.*)\)$')

    def __init__(self, objects):
        self.objects = objects

    def _get_bbox_sizes(self, obj_id=None):
        """Return ``(width, height)`` read from ``objects[obj_id]['w'/'h']``."""
        obj = self.objects[obj_id]
        return (obj['w'], obj['h'])

    def _matches(self, obj_dict, name, attrs):
        """
        True if ``name`` is in ``utils.get_hypernyms(obj_dict['name'])`` and
        ``attrs`` (a set) is a subset of the object's attributes.
        """
        return (name in utils.get_hypernyms(obj_dict['name'])
                and attrs.issubset(set(obj_dict['attributes'])))

    def check_constraints(self, constraints, assignment):
        """
        Return True if every constraint string holds under ``assignment``.

        Each constraint has the form ``method(arg, ...)``. Arguments are split
        on commas and whitespace-stripped. An argument that is a key of
        ``assignment`` is replaced by its value, the literal ``None`` becomes
        ``None``, and anything else is passed through as a string. Evaluation
        stops at the first failing constraint.

        Raises:
            ValueError: if a constraint is not of the form ``name(args)``.
            AttributeError: if the named method does not exist.
        """
        for constraint in constraints:
            match = self._CONSTRAINT_RE.match(constraint)
            if match is None:
                # Previously surfaced as an opaque AttributeError on None.groups().
                raise ValueError('Malformed constraint: {!r}'.format(constraint))
            func, argstr = match.groups()
            args = [a.strip() for a in argstr.split(',')]
            func_args = [
                assignment[a] if a in assignment else (None if a == 'None' else a)
                for a in args
            ]
            if not getattr(self, func)(*func_args):
                return False
        return True

    def count(self, obj, attrs):
        """Number of objects matching ``obj`` (via hypernyms) that carry all ``attrs``."""
        attrs = set() if attrs is None else set(attrs)
        return sum(1 for obj_i in self.objects.values() if self._matches(obj_i, obj, attrs))

    def min_count(self, obj, attrs, min_count):
        """True if at least ``int(min_count)`` objects match."""
        return self.count(obj, attrs) >= int(min_count)

    def max_count(self, obj, attrs, min_count):
        """
        True if at most ``int(min_count)`` objects match.

        NOTE: the parameter is an upper bound despite its name; the name is
        kept so existing keyword callers keep working.
        """
        return self.count(obj, attrs) <= int(min_count)

    def min_width(self, obj_id, minimum):
        """True if the object's ``'w'`` is at least ``int(minimum)``."""
        width, _ = self._get_bbox_sizes(obj_id)
        return width >= int(minimum)

    def min_height(self, obj_id, minimum):
        """True if the object's ``'h'`` is at least ``int(minimum)``."""
        _, height = self._get_bbox_sizes(obj_id)
        return height >= int(minimum)

    def unique(self, obj, attrs=None):
        """True if exactly one object matches."""
        return self.count(obj, attrs) == 1

    def not_unique(self, obj, attrs=None):
        """True if more than one object matches."""
        return self.count(obj, attrs) > 1

    def unique_rel(self, obj1, attrs1, rel, obj2, attrs2):
        """
        True if exactly one (subject, relation) pair exists where the subject
        matches ``obj1``/``attrs1``, the relation is named ``rel`` and its
        target matches ``obj2``/``attrs2``.
        """
        attrs1 = set() if attrs1 is None else set(attrs1)
        attrs2 = set() if attrs2 is None else set(attrs2)
        detected = 0
        for subject in self.objects.values():
            if not self._matches(subject, obj1, attrs1):
                continue
            for relation in subject['relations']:
                if relation['name'] != rel:
                    continue
                if self._matches(self.objects[relation['object']], obj2, attrs2):
                    detected += 1
                    if detected >= 2:
                        return False  # already ambiguous; no need to keep scanning
        return detected == 1

    def restrict_category_object(self, category, restricted_category, obj, *args):
        """When ``category == restricted_category``, require ``obj`` to be one of ``args``."""
        if category == restricted_category:
            return obj in args
        return True

    def mutex_attr_category(self, attrs, category):
        """True if no attribute in ``attrs`` belongs to attribute ``category``."""
        return not any(category in utils.get_attr_categories(attr) for attr in attrs)

    def member_of_attr_category(self, attr, category):
        """True if ``attr`` is in ``utils.train_category_attributes[category]``."""
        return attr in utils.train_category_attributes[category]

    def all_member_of_attr_category(self, attrs, category):
        """True if every attribute in ``attrs`` is in the category's attribute set."""
        return set(attrs).issubset(utils.train_category_attributes[category])

    def not_member_of_attr_category(self, attr, category):
        """True if ``attr`` is not in ``utils.train_category_attributes[category]``."""
        return attr not in utils.train_category_attributes[category]

    def mutex_obj_category(self, attrs, category):
        """True if no element of ``attrs`` belongs to object ``category``."""
        # A duplicate definition checking attribute categories used to precede
        # this one; it was shadowed (dead code) and has been removed.
        return not any(category in utils.get_object_categories(attr) for attr in attrs)

    def member_of_obj_category(self, attr, category):
        """True if ``attr`` is in ``utils.category_objects[category]``."""
        return attr in utils.category_objects[category]

    def all_member_of_obj_category(self, attrs, category):
        """True if every element of ``attrs`` is in ``utils.category_objects[category]``."""
        return set(attrs).issubset(utils.category_objects[category])

    def not_member_of_obj_category(self, attr, category):
        """True if ``attr`` is not in ``utils.category_objects[category]``."""
        return attr not in utils.category_objects[category]

    def member_of(self, key, *args):
        """True if ``key`` equals one of ``args``."""
        return key in args

    def not_member_of(self, key, *args):
        """True if ``key`` equals none of ``args``."""
        return not self.member_of(key, *args)

    def max_length(self, attrs, max_len):
        """True if ``len(attrs) <= int(max_len)``."""
        return len(attrs) <= int(max_len)

    def min_length(self, seq, min_len):
        """True if ``len(seq) >= int(min_len)``."""
        return len(seq) >= int(min_len)

    def eq_length(self, attrs, length):
        """True if ``len(attrs) == int(length)``."""
        return len(attrs) == int(length)

    def exclude_color(self, attrs):
        """True if no attribute in ``attrs`` is in ``utils.train_colors``."""
        return len(set(attrs).intersection(utils.train_colors)) == 0

    def exclude_attrs(self, attrs, *args):
        """True if ``attrs`` (``None`` treated as empty) shares nothing with ``args``."""
        if attrs is None:
            attrs = set()
        return set(attrs).intersection(args) == set()

    def exclude_ambiguous(self, obj):
        """True if ``obj`` is not in ``utils.abstract_objects``."""
        return obj not in utils.abstract_objects

    def not_equal(self, val1, val2):
        """True if ``val1 != val2``."""
        return val1 != val2

    def equal(self, val1, val2):
        """True if ``val1 == val2``."""
        return val1 == val2

    def symmetric_exclude_with_hypernyms(self, obj1, obj2):
        """True if neither object is in the other's ``utils.get_inferential_hypernyms``."""
        return (obj1 not in utils.get_inferential_hypernyms(obj2)
                and obj2 not in utils.get_inferential_hypernyms(obj1))