"""
Inquiry types:
    Existence,
    Numeracy,
    Spatial Reasoning,
    Absolute spatial reasoning (with respect to the image dimensions, e.g. is there a house in the bottom left corner?),
    Compositionality,
    Taxonomy,
    Vocabulary,

"""
import re
from copy import deepcopy
from math import inf
from functools import reduce
from random import sample, choices, choice
from collections import defaultdict
from itertools import combinations, product
from graphlib import TopologicalSorter, CycleError
from lemminflect import getAllInflections, getAllInflectionsOOV, getLemma
from .program_handlers import Handler
from .constraint_handler import ConstraintHandler
from . import utils


class Scene:
    def __init__(self, scene, include_ids, img_id):
        """
        Wraps a scene dictionary so that tokens can be assigned to its objects.

        params:
            scene: scene dictionary; its 'objects' entry maps object ids to object dictionaries
                (each object dictionary is read for its 'name' here).
            include_ids: when truthy, token assignments also carry '<key>_id' entries holding object ids.
            img_id: identifier of the image the scene belongs to (stored as-is).
        """
        self.img_id = img_id
        self.include_ids = include_ids
        self.scene = scene
        objects = scene['objects']
        self.objects = objects
        # Set of every object name that occurs in the scene.
        self.object_names = {entry['name'] for entry in objects.values()}
        # Used by assign_tokens to filter candidate assignments against constraints.
        self.constraint_handler = ConstraintHandler(objects)
    def get_relation_ids(self, obj):
        """
        Lazily walks the relations of one object and yields their target object ids.

        params:
            obj: **object_id** (a key of self.objects), not the object dictionary itself.
        returns:
            generator over the 'object' entry of every relation listed under obj, in list order.
        """
        yield from (relation['object'] for relation in self.objects[obj]['relations'])

    def _get_key_list(self, token_set):
        """
        Builds the assignment keys for a token set in a deterministic (sorted) order.

        params:
            token_set: mapping of token id -> iterable of token names.
        returns:
            sorted list of '<token><token_id>' keys. When self.include_ids is truthy, every key that
            matches utils.OBJECT_TOKEN additionally contributes a '<key>_id' entry.
        """
        object_key_pattern = re.compile(utils.OBJECT_TOKEN)
        keys = [
            f'{token}{token_id}'
            for token_id, tokens in token_set.items()
            for token in tokens
        ]
        if self.include_ids:
            # Materialise the id keys first so we never extend the list while iterating over it.
            id_keys = [f'{key}_id' for key in keys if object_key_pattern.match(key)]
            keys.extend(id_keys)
        keys.sort()
        return keys

    def _build_assignment_tree(self, valid_object_ids, token_idx_order, current_assignment, parents_graph):
        """
        Generates a tree where each path represents a possible assignment of token_ids to object_ids.
        [token_id_1] - (object_id_1) - [token_id_2] - (object_id_2)
                     |               + [token_id_2] - (object_id_3)
                     + (object_id_4) - ...
        The ordering of token_ids provided by token_idx_order ensures that we'll always have an assignment in
        current_assignment for any token that depends on a parent's assignment. parents_graph keeps track of parent
        dependencies when we need to restrict the possible object_ids for a token (when it is the subject in a
        relationship with a parent's token).

        params:
            valid_object_ids: set of object ids that are still free to be assigned.
            token_idx_order: non-empty sequence of token ids, in the order they must be assigned.
            current_assignment: dict token_id -> object_id built so far; it is mutated in place.
            parents_graph: mapping token_id -> collection of parent token ids whose assigned objects
                must each have a relation pointing at the object chosen for token_id.
        returns:
            generator of dicts (deep copies of current_assignment), one per complete path. Every
            candidate object is visited exactly once per level, in random order.
        """
        head, rest_token_idx_order = token_idx_order[0], token_idx_order[1:]
        if head in parents_graph:
            # Keep only the objects that every (already assigned) parent object relates to.
            candidates = reduce(lambda ids, parent_obj: ids & set(self.get_relation_ids(parent_obj)),
                                map(current_assignment.get, parents_graph[head]),
                                valid_object_ids)
        else:
            candidates = valid_object_ids
        # Shuffle WITHOUT replacement: random.choices draws with replacement, which repeats some
        # candidates and silently skips others, so the enumeration would be incomplete.
        for object_id in sample(list(candidates), k=len(candidates)):
            current_assignment[head] = object_id
            if rest_token_idx_order:
                # An object may be used by one token only, so remove it for the deeper levels.
                yield from self._build_assignment_tree(valid_object_ids - {object_id},
                                                       rest_token_idx_order,
                                                       current_assignment,
                                                       parents_graph)
            else:
                # Copy, because current_assignment keeps being overwritten while the tree is walked.
                yield deepcopy(current_assignment)

    def assign_tokens(self, token_set, constraints, token_idx_order, parents_graph, max_num=inf):
        """
        Returns a list of the (valid) token assignments (dictionaries) found for this scene.

        params:
            token_set: mapping of token id -> set of token names to assign for that id.
            constraints: passed unchanged to self.constraint_handler.check_constraints.
            token_idx_order: order in which token ids are assigned; None means sorted token ids.
            parents_graph: parent dependencies between token ids (see _build_assignment_tree).
            max_num: accepted for interface compatibility; it is not used by this method.
        returns:
            list of dicts mapping every key of self._get_key_list(token_set) to one value, keeping
            only the dicts accepted by the constraint handler. Duplicates are removed.
        """
        if token_idx_order is None:
            token_idx_order = sorted(token_set.keys())
        # Objects whose name is listed in utils.abstract_objects are never assigned to a token.
        concrete_ids = {
            obj_id for obj_id, obj in self.objects.items()
            if obj['name'] not in utils.abstract_objects
        }
        num_tokens = len(token_idx_order)
        # A fixed key order lets us store assignments as tuples, so a set can absorb duplicates.
        key_list = self._get_key_list(token_set)
        unique_rows = set()
        for id_assignment in self._build_assignment_tree(concrete_ids, token_idx_order, dict(), parents_graph):
            distinct_objects = len(set(id_assignment.values()))
            if not (len(id_assignment) == distinct_objects == num_tokens):
                continue  # incomplete assignment, or one object used by several tokens
            candidates = dict()
            for token_id, tokens in token_set.items():
                object_id = id_assignment[token_id]
                # Sorted order guarantees e.g. 'category' is filled in before 'category-options'.
                for token in sorted(tokens):
                    candidates.update(
                        self.get_assignment_val(token, id_assignment, object_id, f'{token}{token_id}', candidates)
                    )
            # One row per combination of candidate values, taken in key_list order.
            unique_rows.update(product(*[candidates[key] for key in key_list]))
            # TODO: would accepting only those assignment values that have no duplicates reduce question ambiguity?
        accepted = []
        for row in unique_rows:
            named_row = dict(zip(key_list, row))  # regenerate the dict form of the assignment
            if self.constraint_handler.check_constraints(constraints, named_row):
                accepted.append(named_row)
        return accepted

    def get_assignment_val(self, token, token_id_assignment, obj_id, key, current_assignment, singular=False):
        """
        Computes the candidate values of one token for the object assigned to it.

        params:
            token: token name ('obj', 'attrs', 'color', '!color', 'hypernym', 'obj-category',
                'obj-category-options', 'category', 'category-options'); relation tokens are recognised
                through key instead.
            token_id_assignment: dict token id -> object id for the assignment being expanded.
            obj_id: id (key of self.objects) of the object assigned to this token.
            key: assignment key of the token ('<token><token_id>').
            current_assignment: candidate values computed so far; read by the '*-options' tokens.
            singular: for 'attrs' only, pre-filter attributes with utils.get_attribute_with_prob.
        returns:
            dict with the candidate values stored under key, plus an object-id entry for 'obj' and
            'hypernym' tokens when self.include_ids is truthy.
        raises:
            ValueError: if the token is not recognised.
        """
        obj = self.objects[obj_id]
        result = dict()
        if token == 'obj':
            candidates = [obj['name']]
            if self.include_ids:
                result[f'{key}_id'] = [obj_id]  # track object id
        elif token == 'attrs':
            attributes = utils.get_attribute_with_prob(obj['attributes']) if singular else obj['attributes']
            candidates = utils.powerset(attributes)
        elif (relation_match := re.search(utils.RELATION_TOKEN, key)) is not None:
            # Relation tokens are identified by their key, the token name itself is ignored.
            rel_object_idx, base_object_idx = relation_match.groups()
            rel_obj_id, base_object_id = token_id_assignment[rel_object_idx], token_id_assignment[base_object_idx]
            # Names of all relations of the base object that point at the related object.
            candidates = {
                relation['name']
                for relation in self.objects[base_object_id]['relations']
                if relation.get('object') == rel_obj_id
            }
        elif token == 'color':
            candidates = utils.get_color(obj)  # Not wrapping instantiates the product of assignment values
        elif token == '!color':
            candidates = utils.get_random_color(obj, exclude=utils.get_color(obj))
        elif token == 'hypernym':
            if obj['name'] in utils.minimum_hypernyms:
                candidates = utils.minimum_hypernyms[obj['name']]
            else:
                candidates = []
            obj_num = re.match(utils.TOKEN, key).groups()[3]
            if self.include_ids:
                result[f'obj{obj_num}_id'] = [obj_id]  # track object id
        elif token == 'obj-category':
            candidates = utils.get_object_categories(obj['name'])
        elif token == 'obj-category-options':
            obj_num = re.match(utils.TOKEN, key).groups()[3]
            candidates = list()
            for category in current_assignment[f'obj-category{obj_num}']:
                k = choice([2, 3])
                options = tuple(utils.create_obj_category_options(category, obj['name'], k=k))
                if len(options) == k:  # keep only option tuples of the requested size
                    candidates.append(options)
        elif token == 'category':
            # we don't want there to be more than one correct answer for attribute choices
            candidates = [
                category
                for category, attrs in utils.get_category_attributes_map(obj['attributes']).items()
                if len(attrs) == 1
            ]
        elif token == 'category-options':  # since category-options > category, category will have been assigned.
            obj_num = re.match(utils.TOKEN, key).groups()[3]
            categories = current_assignment[f'category{obj_num}']
            cat2attrs = utils.get_category_attributes_map(obj['attributes'])
            candidates = list()
            for category in categories:
                assert len(cat2attrs[category]) == 1
                attr = cat2attrs[category][0]
                k = choice([2, 3])
                options = tuple(utils.create_attr_category_options(category, attr, obj['name'], k=k))
                if len(options) == k:  # keep only option tuples of the requested size
                    candidates.append(options)
        else:
            raise ValueError(f'Token name unknown: {token}')
        result[key] = candidates
        return result


class QGenerator:
    text_token_pattern = re.compile(utils.TEXT_TOKEN)
    modifier_pattern = re.compile(utils.TEXT_MODIFIERS)

    def __init__(self, question_type=None):
        """
        params:
            question_type: stored as-is; expand_text_template_multi checks it for '_by_attr_'.
        """
        self._plural = {}  # cache for plural words used in self._is_plural
        self.question_type = question_type
        self.handler = Handler()
        # Maps a text modifier name to the method that computes its replacement text.
        self.modifier_dict = {
            'DET': self.get_determiner,
            'Is': self.inflect_is,
        }

    def generate_assignments(self, scene, tokens, constraints, img_id, max_num=inf, include_ids=True):
        """
        Parses the question tokens and returns the valid token assignments for a scene.

        params:
            scene: scene dictionary, handed to Scene.
            tokens: iterable of token strings; 'scene' is skipped, every other token must match
                utils.TOKEN with four groups (prefix, related object number, token type, object number).
            constraints: passed through to Scene.assign_tokens.
            img_id: image identifier, handed to Scene.
            max_num: passed through to Scene.assign_tokens.
            include_ids: whether the assignments should also carry object ids.
        returns:
            whatever Scene.assign_tokens returns for the parsed tokens.
        raises:
            ValueError: if a token does not match utils.TOKEN, or if the relation tokens reference
                each other cyclically.
        """
        token_sets = defaultdict(set)
        token_graph = dict()  # keep track of object relations to assign objects in efficient ordering
        parents = defaultdict(set)
        for token in tokens:
            if token == 'scene':
                continue
            match = re.match(utils.TOKEN, token)
            # re.match returns None for a non-matching token; check explicitly instead of relying on
            # an assert (stripped under -O) that would only be reached after an AttributeError.
            if match is None or len(match.groups()) != 4:
                raise ValueError(f'Malformed input token: {token!r}')
            rand, rel_object_number, token_type, object_number = match.groups()
            token_graph.setdefault(object_number, set())
            token_sets[object_number].add(rand + rel_object_number + token_type)
            if rel_object_number:
                token_graph[object_number].add(rel_object_number)
                parents[rel_object_number].add(object_number)
        try:
            # sort to resolve relation dependencies
            token_idx_order = list(reversed(list(TopologicalSorter(token_graph).static_order())))
        except CycleError as err:
            raise ValueError('Received cyclical relation tokens. Cannot perform a token assignment. '
                             'Please check the question family tokens to resolve cyclical references.') from err
        # Every object number that takes part in the ordering must own at least one token.
        assert set(token_idx_order) == set(token_sets.keys())
        scene_wrapper = Scene(scene, include_ids, img_id)
        return scene_wrapper.assign_tokens(token_sets, constraints, token_idx_order,
                                           parents_graph=parents, max_num=max_num)

    @staticmethod
    def expand_text_template(template, assignment):
        """
        Replaces every text token of the template with its assigned value.

        params:
            template: text containing tokens that match utils.TEXT_TOKEN.
            assignment: dict token name -> value. frozenset/tuple/set/list values are joined with
                spaces, any other non-string value is converted with str().
        returns:
            the expanded text with runs of whitespace collapsed to single spaces.
        """
        expanded = template
        for token, token_name, _obj_num in re.compile(utils.TEXT_TOKEN).findall(template):
            value = assignment[token_name]
            value_type = type(value)
            if value_type in (frozenset, tuple, set, list):  # If value is a set of attributes.
                value = ' '.join(value)
            elif value_type != str:
                value = str(value)  # for some numeric assignments
            expanded = expanded.replace(token, value)
        return ' '.join(expanded.split())

    def get_determiner(self, word):
        """
        Picks the determiner to put in front of word.

        returns:
            'the' for ground/floor/ceiling/sky, 'any' for plural words (per self._is_plural) and for
            clothes/clothing/people/food, 'an' if the word starts with a lowercase vowel, else 'a'.
        """
        if word in {'ground', 'floor', 'ceiling', 'sky'}:
            return 'the'
        if self._is_plural(word) or word in {'clothes', 'clothing', 'people', 'food'}:
            return 'any'
        starts_with_vowel = word[0] in ('a', 'e', 'i', 'o', 'u')
        return 'an' if starts_with_vowel else 'a'

    @staticmethod
    def _get_inflections(word):
        """
        Looks up the noun inflections of word.

        The word is lemmatised as a noun first; if the lookup for the lemma does not provide both
        the 'NN' and 'NNS' entries, the out-of-vocabulary inflections of the word itself are
        returned instead.
        """
        noun_lemma = getLemma(word, upos='NOUN')[0]
        found = getAllInflections(noun_lemma, upos='NOUN')
        if {'NN', 'NNS'}.issubset(set(found.keys())):
            return found
        return getAllInflectionsOOV(word, upos='NOUN')

    def _is_plural(self, word):
        """
        Tells whether word is a plural noun; results are cached in self._plural.

        A word counts as plural only when it is not among its own 'NN' inflections but is among
        its 'NNS' inflections.
        """
        if word in self._plural:
            return self._plural[word]
        inflections = self._get_inflections(word)
        # The singular form wins: 'NNS' is only consulted when the word is not a known 'NN' form.
        plural = word not in inflections['NN'] and word in inflections['NNS']
        self._plural[word] = plural
        return plural

    def inflect_is(self, word):
        """Returns 'are' when word is plural (per self._is_plural), otherwise 'is'."""
        return 'are' if self._is_plural(word) else 'is'

    def _add_determiners(self, values):
        """
        Renders a list of options as text, each option with its indefinite article.

        Options are separated by ',' and the last one (when there is more than one) by ', or'.
        Plural words (per self._is_plural) get no article, words starting with a lowercase vowel
        get 'an', all others 'a'. Every option is preceded by a space, so the result starts with one.
        """
        pieces = []
        for position, word in enumerate(values):
            if position > 0:
                is_last = position == (len(values) - 1)
                pieces.append(', or' if is_last else ',')
            if self._is_plural(word):
                pieces.append(' ' + word)
            elif word[0] in ('a', 'e', 'i', 'o', 'u'):
                pieces.append(f' an {word}')
            else:
                pieces.append(f' a {word}')
        return ''.join(pieces)

    def _enumerate_attributes(self, attrs):
        """
        Joins attributes into an enumeration: 'a, b and c'.

        With at most one attribute the attributes are simply joined with spaces.
        """
        if len(attrs) > 1:
            return ', '.join(attrs[:-1]) + ' and ' + attrs[-1]
        return ' '.join(attrs)

    def expand_text_template_multi(self, template, assignment):
        """
        Expands a text template: substitutes the text tokens and resolves the text modifiers.

        params:
            template: text containing tokens (self.text_token_pattern) and modifiers
                (self.modifier_pattern, written as '[<obj_num>:<modifier>]').
            assignment: dict token name -> assigned value.
        returns:
            the expanded text with whitespace collapsed and passed through str.capitalize().

        Value rendering:
            - '...obj-category-options' collections are listed with determiners (self._add_determiners);
            - other '...-options' collections are rendered as 'a, b or c';
            - 'attrs...' values with more than one element are enumerated ('a, b and c') when the
              question type contains '_by_attr_';
            - remaining collections are joined with spaces, other non-strings go through str().
        Underscores in the rendered value are replaced by spaces.
        """
        modifier_groups = defaultdict(set)
        for obj_num, modifier in self.modifier_pattern.findall(template):
            modifier_groups[obj_num].add(modifier)

        # question_type defaults to None; `'_by_attr_' in None` would raise TypeError, so treat a
        # missing question type as one that contains no markers.
        question_type = self.question_type or ''

        text = template
        for token, token_name, obj_num in self.text_token_pattern.findall(template):
            value = assignment[token_name]
            if 'obj-category-options' in token_name and type(value) in [tuple, list, set, frozenset]:
                value = self._add_determiners(value)
            elif '-options' in token_name and type(value) in [tuple, list, set, frozenset]:
                value = ', '.join(value[:-1]) + ' or ' + value[-1]
            if token_name.startswith('attrs') and '_by_attr_' in question_type and len(value) > 1:
                value = self._enumerate_attributes(value)
            elif type(value) in [frozenset, tuple, set, list]:  # If value is a set of attributes.
                value = ' '.join(value)
            elif type(value) != str:
                value = str(value)  # for some numeric assignments
            value = ' '.join(value.split('_'))
            text = text.replace(token, value)
            if value:
                # Modifiers of this object number are resolved from the first non-empty value
                # rendered for it; once replaced, the modifier key no longer occurs in the text.
                for modifier in modifier_groups[obj_num]:
                    key = f'[{obj_num}:{modifier}]'
                    mod_value = self.modifier_dict[modifier](value)
                    text = text.replace(key, mod_value)

        return ' '.join(text.split()).capitalize()
