# from pattern.en import wordnet
# from math import inf

from .re_patterns import *

import os
import json
import numpy as np
import random
random.seed(42)
from collections import defaultdict, OrderedDict
from itertools import chain, combinations, product


def invert_dictionary(orig):
    """Reverse a one-to-many mapping.

    Each element obtained by iterating ``orig[key]`` becomes a key of the
    result and is mapped to the list of original keys it was found under.
    Keys are appended in the iteration order of ``orig``; a key is repeated
    when the same element occurs several times under it.

    :param orig: mapping whose values are iterables of hashable elements.
    :return: a plain ``dict`` from element to list of original keys.
    """
    inverted = {}
    for source_key, members in orig.items():
        for member in members:
            inverted.setdefault(member, []).append(source_key)
    return inverted


# Directory two levels above this file; every vocabulary file is read from ``<ROOT_DIR>/vocabs``.
ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..')


def _load_vocab(filename):
    """Parse the JSON file ``<ROOT_DIR>/vocabs/<filename>`` and close it afterwards."""
    with open(os.path.join(ROOT_DIR, 'vocabs', filename)) as vocab_file:
        return json.load(vocab_file)


def _load_vocab_sets(filename):
    """Load a JSON object from the vocab directory and turn each of its values into a ``set``."""
    return {key: set(values) for key, values in _load_vocab(filename).items()}


# Object names that the sampling helpers in this module never propose as candidates.
abstract_objects = {'air', 'sky', 'paint', 'ceiling', 'floor', 'tag', 'decoration', 'ornament', 'character', 'wall',
                    'room', 'ground', 'kitchen', 'dirt', 'grass', 'frame', 'road', 'field', 'building'}

# --- attribute vocabularies -------------------------------------------------
attribute_vocabulary = set(_load_vocab('train_sceneGraphs_attributes_vocab.json'))
train_attribute_antonyms = _load_vocab('train_attribute_antonyms.json')
train_category_attributes = _load_vocab('attribute_categories.json')
# attribute -> list of categories it is listed under
train_attribute_categories = invert_dictionary(train_category_attributes)
attr_category_2_obj_attr_counts = _load_vocab('group2obj_attribute_counts.json')
attr_dist_and_default = _load_vocab('train_attr_dist.json')
# Attributes missing from the stored distribution fall back to the file's 'default_prob' entry.
attribute_distribution = defaultdict(lambda: attr_dist_and_default['default_prob'])
attribute_distribution.update(attr_dist_and_default['distribution'])

# --- object vocabularies ----------------------------------------------------
train_object_counts = defaultdict(int)
train_object_counts.update(_load_vocab('train_object_counts.json'))
object_distribution = _load_vocab('train_obj_dist.json')
object_attribute_distributions = _load_vocab('train_obj_attr_pair_dists.json')
object_relation_object_distributions = _load_vocab('train_obj_rel_obj_triple_dists.json')
object_object_dist_dictionary = _load_vocab('train_obj_obj_dist_add_one_smoothed.json')
object_object_distributions = object_object_dist_dictionary['distributions']
# The default factory must be ``float`` itself (yielding 0.0). The previous ``lambda: float``
# returned the *type* ``float`` for unknown keys, which cannot be used as a probability.
object_default_probability = defaultdict(float, object_object_dist_dictionary['default_probs'])
object_vocabulary = set(_load_vocab('train_sceneGraphs_objects_custom_vocab.json'))
category_objects = _load_vocab('category_objects.json')
# object -> list of categories it is listed under
object_categories = invert_dictionary(category_objects)

# --- ontology (hypernym / hyponym) tables -----------------------------------
object_hypernyms = _load_vocab_sets('train_object_vocab_hypernyms_extended.json')
object_hyponyms = _load_vocab_sets('train_object_vocab_hyponyms_extended.json')
minimum_hypernyms = _load_vocab_sets('train_objects_minimum_hypernyms.json')
minimum_hyponyms = _load_vocab_sets('train_objects_minimum_hyponyms.json')
minimum_ontology_vocabulary = set(minimum_hypernyms) | set(minimum_hyponyms)
# Values are kept exactly as loaded from JSON (not converted to sets).
object_inferential_hypernyms = _load_vocab('train_object_vocab_hypernyms_inferential.json')

# --- colour and relation tables ---------------------------------------------
train_obj_colors = _load_vocab('train_obj_color_pair_dists.json')
# SECURITY NOTE(review): the inner keys of this file are Python expressions stored as strings
# and are turned back into objects with ``eval``. That executes arbitrary code if the JSON file
# is not trusted; consider ``ast.literal_eval`` once the key format is confirmed.
train_object_object_relation_pairs = {
    k: {eval(kp): vp for kp, vp in v.items()}
    for k, v in _load_vocab('train_object_object_relation_pairs.json').items()
}


def powerset(x_iterable, max_depth=None):
    """List the subsets of the distinct elements of ``x_iterable``.

    Subsets are emitted by increasing size, from the empty tuple up to
    ``max_depth`` elements (all sizes when ``max_depth`` is ``None``).
    Each subset is returned as a tuple with its elements sorted.

    :param x_iterable: iterable of hashable, mutually orderable elements.
    :param max_depth: largest subset size to generate, or ``None`` for no limit.
    :return: list of sorted tuples.
    """
    unique = set(x_iterable)
    depth = len(unique) if max_depth is None else max_depth
    subsets = []
    for size in range(depth + 1):
        for combo in combinations(unique, size):
            subsets.append(tuple(sorted(combo)))
    return subsets


def get_attribute_map(obj_set):
    """Index object ids by ``(name, attribute subset)``.

    For every object, each subset of its attributes (as produced by
    ``powerset``, i.e. a sorted tuple) is combined with the object's name to
    form a key; the object's id is added to the set stored under that key.

    :param obj_set: mapping of object id to a dict with 'name' and 'attributes'.
    :return: ``defaultdict(set)`` from ``(name, attrs_tuple)`` to object ids.
    """
    subset_to_ids = defaultdict(set)
    for obj_id, obj_data in obj_set.items():
        obj_name = obj_data['name']
        for attr_subset in powerset(set(obj_data['attributes'])):
            # key looks like (obj_name, ('green', 'shiny', ...))
            subset_to_ids[(obj_name, attr_subset)].add(obj_id)
    # TODO: Do we need a proper dict here instead of the defaultdict?
    return subset_to_ids


def get_obj_ids(obj_dict, obj_name, obj_attrs):
    """Collect the ids of objects called ``obj_name`` that carry all of ``obj_attrs``.

    :param obj_dict: mapping of object id to a dict with 'name' and 'attributes'.
    :param obj_name: name an object must have to match.
    :param obj_attrs: set of attributes that must all be present on the object.
    :return: set of matching object ids.
    """
    return {
        obj_id
        for obj_id, obj_data in obj_dict.items()
        if obj_name == obj_data['name'] and obj_attrs.issubset(set(obj_data['attributes']))
    }


def get_count(obj_dict, obj_name, obj_attrs):
    """Count the objects called ``obj_name`` that carry all of ``obj_attrs``.

    :param obj_dict: mapping of object id to a dict with 'name' and 'attributes'.
    :param obj_name: name an object must have to be counted.
    :param obj_attrs: set of attributes that must all be present on the object.
    :return: number of matching objects.
    """
    return sum(
        1
        for obj_data in obj_dict.values()
        if obj_name == obj_data['name'] and obj_attrs.issubset(set(obj_data['attributes']))
    )


def get_attr_categories(attr):
    """Return the categories recorded for ``attr``, or an empty list if it is unknown."""
    return train_attribute_categories.get(attr, [])


def get_category_attributes_map(attrs):
    """Group ``attrs`` by the categories reported by ``get_attr_categories``.

    Attributes without any category do not appear in the result; an attribute
    with several categories is listed under each of them.

    :param attrs: iterable of attribute names.
    :return: plain ``dict`` from category to the list of given attributes in it.
    """
    grouped = {}
    for attr in attrs:
        for category in get_attr_categories(attr):
            grouped.setdefault(category, []).append(attr)
    return grouped


def get_object_categories(obj_name):
    """Return the categories recorded for ``obj_name``, or an empty list if it is unknown."""
    if obj_name not in object_categories:
        return []
    return object_categories[obj_name]


def create_obj_category_options(category, obj, k):
    """Build options for ``obj`` within an object ``category``.

    Delegates to ``_create_options`` with no attribute, using the
    ``category_objects`` table for candidates and ``train_object_counts`` as
    sampling weights.
    """
    return _create_options(
        category,
        obj,
        obj_attr=None,
        k=k,
        category_choices=category_objects,
        counts=train_object_counts,
    )


def create_attr_category_options(category, obj_attr, obj, k):
    """Build options for the attribute ``obj_attr`` of ``obj`` within an attribute ``category``.

    Sampling weights are the counts stored for ``obj`` under ``category`` in
    ``attr_category_2_obj_attr_counts``; when ``obj`` has no entry there, an
    empty ``defaultdict(int)`` is passed instead. Candidates come from
    ``train_category_attributes``.
    """
    per_object_counts = attr_category_2_obj_attr_counts[category]
    if obj in per_object_counts:
        cat_obj_counts = per_object_counts[obj]
    else:
        cat_obj_counts = defaultdict(int)
    return _create_options(
        category,
        obj,
        obj_attr,
        k=k,
        category_choices=train_category_attributes,
        counts=cat_obj_counts,
    )


def weighted_shuffle(items, weights):
    """Reorder ``items`` by ascending ``random.random() * weight``.

    One random number is drawn per item, in index order. Because the sort is
    ascending, items with larger weights tend to land later in the result and
    items with weight 0 keep their relative order at the front.

    :param items: sequence to reorder.
    :param weights: sequence with one weight per item (indexed like ``items``).
    :return: new list containing the items in their shuffled order.
    """
    # Alternative key that was considered: random.random() ** (1.0 / weights[i])
    sort_keys = [random.random() * weights[idx] for idx in range(len(items))]
    order = sorted(range(len(items)), key=sort_keys.__getitem__)
    return [items[idx] for idx in order]


def smoothed_weights(values, counts, plus=1):
    """Turn counts into normalised weights with additive smoothing.

    Each value gets ``counts.get(value, 0) + plus``; the results are divided
    by their sum so they add up to 1.

    :param values: iterable of keys to look up in ``counts``.
    :param counts: mapping providing ``.get(key, default)``.
    :param plus: constant added to every count before normalising.
    :return: list of weights, one per value (empty for empty input).
    """
    raw = [counts.get(value, 0) + plus for value in values]
    total = sum(raw) if raw else 1
    return [weight / total for weight in raw]


def _create_options(category, obj, obj_attr, k, category_choices, counts):
    """Build a shuffled list of options made of a base choice plus sampled alternatives.

    The base choice is ``obj_attr`` when given, otherwise ``obj``. Alternatives
    are drawn from the keys of ``category_choices[category]``; every accepted
    choice also "eliminates" the entries listed under it in
    ``category_choices[category][choice]``, so those are never offered
    alongside it.

    :param category: key into ``category_choices``.
    :param obj: object name; must not be ``None``.
    :param obj_attr: attribute to use as the base choice, or ``None`` to use ``obj``.
    :param k: the loop stops as soon as this many choices have been collected.
    :param category_choices: mapping ``category -> {choice: iterable of choices to eliminate}``.
    :param counts: mapping with ``.get``; supplies sampling weights (missing entries weigh 1).
    :return: list of choices (including the base choice) in random order.
    :raises KeyError: if ``category`` or the base choice is missing from ``category_choices``.
    """
    assert obj is not None
    if obj_attr is None:
        base_choice = obj
    else:
        base_choice = obj_attr
    choices = {base_choice, }
    # Entries listed under the base choice must not be offered next to it.
    eliminated = set(category_choices[category][base_choice])
    # Remaining candidates: every other key of the category that is not eliminated.
    unseen = set(category_choices[category]) - choices
    unseen -= eliminated
    unseen = list(unseen)
    # NOTE(review): the guard hard-codes 3, so at most 3 choices are returned even when k > 3,
    # and because the k check runs after adding, one alternative is added even when k <= 1.
    # Confirm whether `len(choices) < k` was intended.
    while len(choices) < 3 and len(unseen) > 0:
        # Weighted draw of a single candidate; candidates absent from `counts` get weight 1.
        candidate = random.choices(unseen, weights=list(map(lambda x: counts.get(x, 1), unseen)))[0]
        # A drawn candidate is never considered again, whether or not it is accepted.
        unseen = list(set(unseen) - {candidate, })
        # Skip candidates that an earlier accepted choice has eliminated.
        if candidate in choices or candidate in eliminated:
            continue
        choices.add(candidate)
        # The accepted candidate eliminates the entries listed under it as well.
        eliminated |= set(category_choices[category][candidate])
        if len(choices) >= k:
            break
    choices = list(choices)
    # Shuffle so the base choice does not sit at a predictable position.
    random.shuffle(choices)
    return choices


def get_hypernyms(word):
    """Return the entry stored for ``word`` in ``object_hypernyms``, or ``{word}`` if there is none."""
    return object_hypernyms[word] if word in object_hypernyms else {word}


def get_hyponyms(word):
    """Return the entry stored for ``word`` in ``object_hyponyms``, or ``{word}`` if there is none."""
    if word not in object_hyponyms:
        return {word}
    return object_hyponyms[word]


def get_inferential_hypernyms(word):
    """Return the entry stored for ``word`` in ``object_inferential_hypernyms``.

    Unknown words yield the singleton set ``{word}``; known words yield the
    stored value unchanged.
    """
    try:
        return object_inferential_hypernyms[word]
    except KeyError:
        return {word}


def get_random_count(exclude):
    """Pick a random integer from 1 to 15 that is not in ``exclude``.

    :param exclude: a ``set``, ``frozenset`` or ``list`` of values to avoid;
        any other type is treated as a single value to avoid.
    :return: the chosen integer.
    :raises IndexError: if every number from 1 to 15 is excluded.
    """
    if type(exclude) not in (set, frozenset, list):
        exclude = frozenset((exclude,))
    candidates = [number for number in range(1, 16) if number not in exclude]
    return random.choice(candidates)


def get_relations_map(obj_set):
    """Index object ids by ``(name, relation-name subset)``.

    For every object, each subset of the names of its relations (as produced
    by ``powerset``, i.e. a sorted tuple) is combined with the object's name
    to form a key; the object's id is added to the set stored under that key.

    :param obj_set: mapping of object id to a dict with 'name' and 'relations',
        where each relation is a dict read via ``.get('name')``.
    :return: plain ``dict`` from ``(name, rels_tuple)`` to a set of object ids.
    """
    subset_to_ids = defaultdict(set)
    for obj_id, obj_data in obj_set.items():
        obj_name = obj_data['name']
        relation_names = {relation.get('name') for relation in obj_data['relations']}
        for rel_subset in powerset(relation_names):
            # key looks like (obj_name, ('to the left of', ...))
            subset_to_ids[(obj_name, rel_subset)].add(obj_id)
    return dict(subset_to_ids)


def get_relations(obj_id, objects):
    """Describe the relations of one object as ``'<relation name> <target object name>'`` strings.

    :param obj_id: id of the subject object in ``objects``.
    :param objects: mapping of object id to a dict with 'name' and 'relations';
        each relation is a dict with 'name' and 'object' (the target's id).
    :return: list of strings, one per relation, in the stored order.
    """
    # Resolve the target name first, then read the relation's own name.
    resolved = (
        (objects[relation['object']]['name'], relation)
        for relation in objects[obj_id]['relations']
    )
    return [relation['name'] + ' ' + target_name for target_name, relation in resolved]


def param_grid(token_dict):
    """Yield every combination of option values for each grid in ``token_dict``.

    Despite its name, ``token_dict`` is iterated as a collection of mappings.
    For each mapping the keys are sorted and the Cartesian product of their
    option iterables is walked; each combination is yielded as a dict from
    key to chosen option. An empty mapping yields a single empty dict.

    :param token_dict: iterable of mappings ``key -> iterable of options``.
    :return: generator of dicts.
    """
    for grid in token_dict:
        entries = sorted(grid.items())
        if entries:
            names = [name for name, _ in entries]
            option_lists = [options for _, options in entries]
            for combo in product(*option_lists):
                yield dict(zip(names, combo))
        else:
            yield {}


def get_assignments(token_dict, max_depth=1):
    """Enumerate all assignments of one option per token.

    For every token whose name starts with 'attrs', its option list is first
    replaced by the space-joined subsets (up to ``max_depth`` elements, the
    empty subset becoming ``''``) of the original options. Note that this
    replacement is written back into ``token_dict`` itself.

    :param token_dict: mapping ``token -> iterable of options``; modified in place.
    :param max_depth: largest subset size used for 'attrs' tokens.
    :return: list of dicts, one per combination of options.
    """
    for token in token_dict:
        if not token.startswith('attrs'):
            continue
        token_dict[token] = [' '.join(subset) for subset in powerset(token_dict[token], max_depth)]
    return list(param_grid([token_dict]))


def get_color(obj):
    """Return the attributes of ``obj`` that are keys of ``train_colors``.

    NOTE(review): ``train_colors`` is not defined in the visible part of this
    module; presumably it is provided by the star import at the top of the
    file. Confirm.

    :param obj: dict with an 'attributes' iterable.
    :return: frozenset of the matching attributes.
    """
    own_attributes = set(obj['attributes'])
    return frozenset(own_attributes.intersection(train_colors.keys()))


def get_random_color(obj_name, exclude=None):
    """Sample one color for ``obj_name``, avoiding the colors in ``exclude``.

    The object-specific table in ``train_obj_colors`` is used when it exists
    and has at least two entries; otherwise the global ``train_colors`` table
    is used. The stored values act as unnormalised sampling weights.

    :param obj_name: object name to look up in ``train_obj_colors``.
    :param exclude: ``set``/``frozenset`` of colors that must not be returned.
    :return: one-element tuple holding the sampled color.
    :raises ValueError: if no color remains after applying ``exclude``.
    """
    if exclude is None:
        exclude = frozenset()
    assert type(exclude) in (frozenset, set), AssertionError(f'exclude argument passed not frozenset.')

    if obj_name not in train_obj_colors:
        palette = train_colors
    else:
        palette = train_obj_colors[obj_name]
    if len(palette) < 2:
        # Too few object-specific colors to choose from: use the global table.
        palette = train_colors
    allowed = {name: palette[name] for name in palette if name not in exclude}

    # Probabilistic choice; a ranked alternative would be list(allowed.keys())[0].
    colors, weights = zip(*allowed.items())
    probs = np.divide(weights, sum(weights))
    chosen = np.random.choice(colors, size=1, p=probs)[0]
    return (chosen, )


def get_random_attrs(obj_name):
    """Sample random attributes, preferring the distribution known for ``obj_name``.

    * ``obj_name`` given and present in ``object_attribute_distributions``:
      one attribute is drawn with numpy using the stored probabilities.
    * ``obj_name`` given but unknown: one attribute is drawn uniformly from
      ``attribute_vocabulary``.
    * ``obj_name`` falsy: 1, 2 or 3 distinct attributes (with probabilities
      0.7 / 0.2 / 0.1) are drawn uniformly from ``attribute_vocabulary``.

    :param obj_name: object name, or a falsy value for "no particular object".
    :return: tuple of sampled attributes with falsy entries (e.g. ``''``)
        removed, so it may be empty.
    """
    if obj_name:
        if obj_name in object_attribute_distributions:
            attr_names, attr_probs = zip(*object_attribute_distributions[obj_name].items())
            attrs = tuple(np.random.choice(attr_names, 1, p=attr_probs))
        else:
            # random.sample() requires a sequence: passing a set is deprecated since
            # Python 3.9 and raises TypeError from 3.11 on. tuple() is the conversion
            # older versions applied implicitly, so results are unchanged there.
            attrs = random.sample(tuple(attribute_vocabulary), k=1)
    else:
        # Cast the numpy integer to a plain int before handing it to the stdlib.
        num_attrs = int(np.random.choice([1, 2, 3], p=[0.7, 0.2, 0.1]))
        attrs = random.sample(tuple(attribute_vocabulary), k=num_attrs)
    # Drop falsy entries: if attrs is ('', ), the empty tuple is returned.
    return tuple(filter(None, attrs))


def get_attribute_with_prob(attributes):
    """Pick one attribute, weighted by ``attribute_distribution``.

    With zero or one attribute the input is returned unchanged as a tuple.
    Otherwise each attribute's value in ``attribute_distribution`` is used as
    an unnormalised weight for a single numpy draw.

    :param attributes: sized collection of attribute names.
    :return: tuple with the chosen attribute (empty if the chosen entry is falsy).
    """
    if len(attributes) > 1:
        weights = [attribute_distribution[attr] for attr in attributes]
        probs = np.array(weights) / sum(weights)
        drawn = np.random.choice(tuple(attributes), 1, p=probs)
        return tuple(attr for attr in drawn if attr)
    return tuple(attributes)


def get_random_obj_rel_pairs_list(exclude_object_names, rel_obj, subset):
    """List the candidate entries recorded for ``rel_obj`` together with sampling probabilities.

    Candidates are the keys of ``train_object_object_relation_pairs[rel_obj]``.
    A candidate is dropped when its first element is an inferential hypernym
    or a hyponym of an excluded object, is one of ``abstract_objects``, or is
    not in ``subset``. The stored values of the survivors are normalised to
    sum to 1.

    :param exclude_object_names: iterable of object names to steer clear of.
    :param rel_obj: key into ``train_object_object_relation_pairs``.
    :param subset: allowed first elements, or ``None`` for all keys of ``train_object_counts``.
    :return: ``(pairs, probabilities)``; two empty lists when nothing qualifies.
    """
    if subset is None:
        subset = set(train_object_counts.keys())
    blocked = {name for excluded in exclude_object_names for name in get_inferential_hypernyms(excluded)}
    blocked |= {name for excluded in exclude_object_names for name in get_hyponyms(excluded)}
    blocked |= abstract_objects
    if rel_obj not in train_object_object_relation_pairs:
        return list(), list()
    candidates = [
        (pair, weight)
        for pair, weight in train_object_object_relation_pairs[rel_obj].items()
        if pair[0] not in blocked and pair[0] in subset
    ]
    if not candidates:
        return list(), list()
    obj_rel_pairs, weights = zip(*candidates)
    total = sum(weights)
    return obj_rel_pairs, [weight / total for weight in weights]


def get_random_obj_names_list(exclude_object_names, from_subset=None):
    """List candidate object names with sampling probabilities, given objects already in use.

    Candidates are the keys found in the ``object_object_distributions``
    entries of the excluded objects, restricted to ``from_subset`` and with
    inferential hypernyms / hyponyms of the excluded objects and
    ``abstract_objects`` removed. Each candidate is scored by summing, over
    the excluded objects, the log of the value stored for that object in the
    candidate's own distribution, or the log of the candidate's default
    probability when there is no such value. Scores are exponentiated and
    normalised to sum to 1.

    :param exclude_object_names: re-iterable collection of object names already in use.
    :param from_subset: allowed candidate names, or ``None`` for every key of
        ``object_object_distributions``.
    :return: ``(names, probabilities)``; two empty lists when there is no candidate.
    """
    if from_subset is None:
        from_subset = set(object_object_distributions.keys())
    log_scores = defaultdict(float)
    related = {s for l in [get_inferential_hypernyms(obj) for obj in exclude_object_names] for s in l}
    related |= {s for l in [get_hyponyms(obj) for obj in exclude_object_names] for s in l}
    known_excluded = set(exclude_object_names).intersection(object_object_distributions.keys())
    co_objs = {s for l in [object_object_distributions[obj_name] for obj_name in known_excluded] for s in l}
    co_objs = co_objs.intersection(from_subset).difference(related).difference(abstract_objects)
    for co_obj in co_objs:
        # A prior term, np.log(object_distribution[co_obj]), is deliberately not added here.
        fallback = np.log(object_default_probability[co_obj])
        cond_probs = object_object_distributions[co_obj]
        for obj_name in exclude_object_names:
            log_scores[co_obj] += np.log(cond_probs[obj_name]) if obj_name in cond_probs else fallback
    if not log_scores:
        return list(), list()
    obj_names, log_probs = zip(*log_scores.items())
    obj_probs = np.exp(log_probs)
    return obj_names, np.divide(obj_probs, sum(obj_probs))


def get_random_object_rel(exclude_object_names, k, rel_object, subset=None):
    """Sample entries from ``get_random_obj_rel_pairs_list`` according to their probabilities.

    :param exclude_object_names: object names to steer clear of.
    :param k: when truthy, number of distinct entries to draw (capped at the
        number available); when falsy, a single entry is drawn.
    :param rel_object: object whose recorded pairs are sampled.
    :param subset: optional restriction forwarded to the listing helper.
    :return: a list of entries when ``k`` is truthy, a single entry when it is
        falsy, or ``None`` when there is nothing to sample from.
    """
    obj_rel_pairs, probs = get_random_obj_rel_pairs_list(exclude_object_names, rel_object, subset)
    n_pairs = len(obj_rel_pairs)
    if n_pairs == 0:
        return None
    if not k:
        return obj_rel_pairs[np.random.choice(n_pairs, p=probs)]
    # Draw indices without replacement, then map them back to the entries.
    picked = np.random.choice(n_pairs, p=probs, replace=False, size=min(k, n_pairs))
    return [obj_rel_pairs[idx] for idx in list(picked)]


def get_random_obj_name(exclude_object_names, k=None, from_subset=None):
    """Sample object names from ``get_random_obj_names_list`` according to their probabilities.

    :param exclude_object_names: object names already in use.
    :param k: when truthy, number of distinct names to draw (capped at the
        number available); when falsy, a single name is drawn.
    :param from_subset: optional restriction forwarded to the listing helper.
    :return: with truthy ``k`` a list of names (empty if there are no
        candidates); otherwise a single name, or ``None`` if there are no
        candidates.
    """
    obj_names, obj_probs = get_random_obj_names_list(exclude_object_names, from_subset)
    if len(obj_names) == 0:
        return list() if k else None
    if not k:
        return np.random.choice(obj_names, p=obj_probs)
    sample_size = min(k, len(obj_names))
    return list(np.random.choice(obj_names, p=obj_probs, replace=False, size=sample_size))


# def get_random_obj_name(exclude=None):
#     if not exclude:
#         exclude = set()
#     elif type(exclude) == str:
#         exclude = {exclude, }
#     obj_names, obj_probs = zip(*{k: object_distribution[k] for k in
#                                  set(object_distribution.keys()).difference(exclude)}.items())
#     obj_probs = np.divide(obj_probs, sum(obj_probs))
#     obj = np.random.choice(obj_names, p=obj_probs)
#     return obj

## From checklist/text_generation.py #########################
# def all_possible_hypernyms(word, pos=None, depth=None):
#     ret = []
#     for syn in all_synsets(word, pos=pos):
#         ret.extend([y for x in syn.hypernyms(recursive=True, depth=depth) for y in x.senses])
#     return clean_senses(ret)
#
# def all_synsets(word, pos=None):
#     map = {
#         'NOUN': wordnet.NOUN,
#         'VERB': wordnet.VERB,
#         'ADJ': wordnet.ADJECTIVE,
#         'ADV': wordnet.ADVERB
#         }
#     if pos is None:
#         pos_list = [wordnet.VERB, wordnet.ADJECTIVE, wordnet.NOUN, wordnet.ADVERB]
#     else:
#         pos_list = [map[pos]]
#     ret = []
#     for pos in pos_list:
#         ret.extend(wordnet.synsets(word, pos=pos))
#     return ret
#
#
# def clean_senses(synsets):
#     return [x for x in set(synsets) if '_' not in x]
#
#
# def all_possible_synonyms(word, pos=None):
#     ret = []
#     for syn in all_synsets(word, pos=pos):
#         # if syn.synonyms[0] != word:
#         #     continue
#         ret.extend(syn.senses)
#     return clean_senses(ret)
