import argparse
import json
import logging
import random
import re
import traceback
from collections import defaultdict, Counter
from copy import deepcopy
from math import inf
from pathlib import Path
from textwrap import wrap

from tqdm import tqdm

import script_utils
from engine import utils as engine_utils
from generate_precursor_data import QuestionFamily, generate_precursor_data


def contains_valid_ontology_assignment(question):
    """Tell whether any object name assigned in `question` also appears in
    `engine_utils.minimum_ontology_vocabulary`."""
    names = get_object_names_from_assignment(question['assignment'])
    overlap = names.intersection(engine_utils.minimum_ontology_vocabulary)
    return bool(overlap)


def contains_hyponym_assignment(question):
    """Tell whether any object name assigned in `question` is a hyponym, i.e.
    one of the keys of `engine_utils.minimum_hypernyms`."""
    names = get_object_names_from_assignment(question['assignment'])
    # keys of minimum_hypernyms are hyponyms
    overlap = names.intersection(engine_utils.minimum_hypernyms)
    return bool(overlap)


def contains_hypernym_assignment(question):
    """Tell whether any object name assigned in `question` is a hypernym, i.e.
    one of the keys of `engine_utils.minimum_hyponyms`."""
    names = get_object_names_from_assignment(question['assignment'])
    # keys of minimum_hyponyms are hypernyms
    overlap = names.intersection(engine_utils.minimum_hyponyms)
    return bool(overlap)


def get_valid_objects(scene, max_area=0.25, min_width=32, min_height=32):
    """
    Collect the ids of the scene objects whose bounding box is acceptable.

    An object is kept when its box area (`h * w`) is at most `max_area` times the
    scene area, its width lies in `[min_width, scene width]` and its height lies in
    `[min_height, scene height]`. Ids are returned as a list, in the iteration
    order of `scene['objects']`.
    """
    scene_w, scene_h = scene['width'], scene['height']
    area_limit = max_area * (scene_w * scene_h)

    def _is_acceptable(box):
        if not (box['h'] * box['w'] <= area_limit):
            return False
        if not (min_width <= box['w'] <= scene_w):
            return False
        return min_height <= box['h'] <= scene_h

    return [key for key, box in scene['objects'].items() if _is_acceptable(box)]


def get_object_id_tokens(assignment):
    """Return the set of assignment keys that match the `engine_utils.TOKEN_ID`
    pattern and start with ``'obj'``."""
    id_pattern = re.compile(engine_utils.TOKEN_ID)
    return {key for key in assignment.keys() if id_pattern.match(key) and key.startswith('obj')}


def get_object_token_ids(assignment):
    """Return the set of values stored under the object-id tokens of `assignment`
    (the keys selected by `get_object_id_tokens`)."""
    collected = set()
    for token in get_object_id_tokens(assignment):
        collected.add(assignment[token])
    return collected


def contains_none_assignments(assignment):
    """Return True when at least one object-id token of `assignment` maps to None."""
    for token in get_object_id_tokens(assignment):
        if assignment.get(token) is None:
            return True
    return False


def get_object_names_from_assignment(assignment):
    """Return the set of values stored under the assignment keys that match the
    `engine_utils.TOKEN` pattern and start with ``'obj'``."""
    token_pattern = re.compile(engine_utils.TOKEN)
    name_keys = {key for key in assignment.keys() if token_pattern.match(key) and key.startswith('obj')}
    return {assignment[key] for key in name_keys}


def get_phony_assignment(assignment, scene, object_ids, token_group_numbers):
    """
    Fill missing object ids of an assignment with unrelated ("phony") scene objects.

    For every group number in `token_group_numbers` whose ``obj<N>_id`` entry is None,
    an id is taken from the valid objects of `scene` (see `get_valid_objects`) that
    are not in `object_ids`, and ``obj<N>_id_is_phony`` is set to True.

    Returns a patched deep copy of `assignment`, or None when nothing was replaced.
    """
    patched = deepcopy(assignment)
    spare_ids = set(get_valid_objects(scene)) - set(object_ids)
    replaced_any = False
    for group in token_group_numbers:
        id_key = f'obj{group}_id'
        if patched[id_key] is not None or not spare_ids:
            continue
        patched[id_key] = spare_ids.pop()
        patched[f'{id_key}_is_phony'] = True
        replaced_any = True
    return patched if replaced_any else None


def get_hypernymical_assignment(assignment, generator, token_group_numbers):
    """
    Replace object names by one of their hypernyms.

    Hypernymical assignments preserve answers for verification questions only when the
    answer is positive, i.e. if there is a brown puppy in the photo ==> there is a
    brown animal.

    For each group number, ``obj<N>`` is replaced by a random entry of
    `engine_utils.minimum_hypernyms[word]` that is not already an object name in the
    assignment. When the generator reports the original word as plural, an ``'NNS'``
    inflection of the replacement that belongs to `engine_utils.object_vocabulary` is
    used if one exists.

    Returns a modified deep copy, or None when no group could be replaced.
    """
    result = deepcopy(assignment)
    changed = False
    for group in token_group_numbers:
        slot = f'obj{group}'
        current = result[slot]
        if current not in engine_utils.minimum_hypernyms:
            continue
        candidates = set(engine_utils.minimum_hypernyms[current])
        candidates -= get_object_names_from_assignment(result)
        if not candidates:
            continue
        replacement = random.choice(list(candidates))
        if generator.generator._is_plural(current):
            plural_forms = set(generator.generator._get_inflections(replacement)['NNS'])
            plural_forms &= engine_utils.object_vocabulary
            if plural_forms:
                replacement = random.choice(list(plural_forms))
        result[slot] = replacement
        changed = True
    if not changed:
        return None
    return result


def get_hyponymical_assignment(assignment, generator, token_group_numbers):
    """
    Replace object names by one of their hyponyms.

    Hyponymical assignments preserve answers for verification questions only when the
    answer is negative, i.e. if there is no white food in the photo ==> there is no
    white cake.

    For each group number, ``obj<N>`` is replaced by a random entry of
    `engine_utils.minimum_hyponyms[word]` that is not already an object name in the
    assignment. When the generator reports the original word as plural, an ``'NNS'``
    inflection of the replacement that belongs to `engine_utils.object_vocabulary` is
    used if one exists.

    Returns a modified deep copy, or None when no group could be replaced.
    """
    result = deepcopy(assignment)
    changed = False
    for group in token_group_numbers:
        slot = f'obj{group}'
        current = result[slot]
        if current not in engine_utils.minimum_hyponyms:
            continue
        candidates = set(engine_utils.minimum_hyponyms[current])
        candidates -= get_object_names_from_assignment(result)
        if not candidates:
            continue
        replacement = random.choice(list(candidates))
        if generator.generator._is_plural(current):
            plural_forms = set(generator.generator._get_inflections(replacement)['NNS'])
            plural_forms &= engine_utils.object_vocabulary
            if plural_forms:
                replacement = random.choice(list(plural_forms))
        result[slot] = replacement
        changed = True
    if not changed:
        return None
    return result


def get_negatively_sampled_assignment(assignment, scene_objects, token_group_numbers):
    """
    Build a negative sample by replacing the objects of the given token groups.

    For each group number N, a new object name is drawn through `engine_utils`:
    when the assignment has a key ending in ``rel<N>``, the draw is conditioned on the
    related object (`get_random_object_rel`), otherwise `get_random_obj_name` is used.
    If ``attrs<N>`` exists it is replaced by `engine_utils.get_random_attrs` of the new
    object, and if ``obj<N>_id`` exists it is reset to None.

    NOTE: `scene_objects` is mutated in place: every drawn name is added to it before
    the next group is processed.

    Returns a modified deep copy of `assignment`, or None when a draw yields nothing
    (either None or an empty result).
    """
    new_tokens = dict()
    for number in token_group_numbers:
        relations = [key for key in assignment.keys() if key.endswith(f'rel{number}')]
        if relations:
            related_group = re.match(engine_utils.RELATION_TOKEN, relations[0]).group(1)
            rel_object = assignment[f'obj{related_group}']
            object_rel_pairs = engine_utils.get_random_object_rel(scene_objects, k=1, rel_object=rel_object)
            # Treat None and an empty result alike instead of failing on indexing.
            if not object_rel_pairs:
                return None
            context_object, _ = object_rel_pairs[0]
        else:
            context_objects = engine_utils.get_random_obj_name(scene_objects, 1)
            if not context_objects:
                return None
            context_object = context_objects[0]
        if f'attrs{number}' in assignment:
            new_tokens[f'attrs{number}'] = engine_utils.get_random_attrs(context_object)
        new_tokens[f'obj{number}'] = context_object
        scene_objects.add(context_object)
        if f'obj{number}_id' in assignment:
            new_tokens[f'obj{number}_id'] = None
    new_assignment = deepcopy(assignment)
    new_assignment.update(new_tokens)
    return new_assignment


def get_negatively_sampled_assignment_from_subset(assignment, scene_objects, token_group_numbers, subset):
    """
    Build a negative sample whose replacement objects are drawn with a `subset` argument.

    For each group number N, a new object name is requested from
    `engine_utils.get_random_obj_name(scene_objects, 1, subset)`. If ``attrs<N>``
    exists it is replaced by `engine_utils.get_random_attrs` of the new object, and if
    ``obj<N>_id`` exists it is reset to None.

    NOTE: `scene_objects` is mutated in place: every drawn name is added to it before
    the next group is processed.

    Returns a modified deep copy of `assignment`, or None when a draw yields nothing
    (either None or an empty result).
    """
    new_tokens = dict()
    for number in token_group_numbers:
        context_objects = engine_utils.get_random_obj_name(scene_objects, 1, subset)
        # Treat None and an empty result alike instead of failing on len(None).
        if not context_objects:
            return None
        context_object = context_objects[0]
        if f'attrs{number}' in assignment:
            new_tokens[f'attrs{number}'] = engine_utils.get_random_attrs(context_object)
        new_tokens[f'obj{number}'] = context_object
        scene_objects.add(context_object)
        if f'obj{number}_id' in assignment:
            new_tokens[f'obj{number}_id'] = None
    new_assignment = deepcopy(assignment)
    new_assignment.update(new_tokens)
    return new_assignment


def get_swapped_assignment(assignment, token_group_1, token_group_2):
    """
    Return a deep copy of `assignment` with the values of two token groups exchanged.

    Keys are matched against `engine_utils.TOKEN`; regex group 4 is read as the token's
    group number and regex group 3 as the prefix used to build the counterpart key
    (``<prefix><other group number>``). Every token of one group must have a
    counterpart in the other group (asserted). The ``obj<N>_id`` entries of both groups
    are exchanged as well.
    """
    first, second = int(token_group_1), int(token_group_2)

    # token -> prefix (regex group 3), one mapping per token group
    prefixes_first, prefixes_second = {}, {}
    for token in assignment:
        found = re.match(engine_utils.TOKEN, token)
        if not found:
            continue
        group_number = int(found.group(4))
        if group_number == first:
            prefixes_first[token] = found.group(3)
        elif group_number == second:
            prefixes_second[token] = found.group(3)

    exchanged = {}
    for token, prefix in prefixes_first.items():
        counterpart = f'{prefix}{second}'
        assert counterpart in prefixes_second
        exchanged[counterpart] = assignment[token]
    for token, prefix in prefixes_second.items():
        counterpart = f'{prefix}{first}'
        assert counterpart in prefixes_first
        exchanged[counterpart] = assignment[token]

    first_id_key, second_id_key = f'obj{first}_id', f'obj{second}_id'
    exchanged[first_id_key] = assignment[second_id_key]
    exchanged[second_id_key] = assignment[first_id_key]

    swapped = deepcopy(assignment)
    swapped.update(exchanged)
    return swapped


def swap_choices_order(assignment):
    """
    Return a deep copy of `assignment` whose choice options are in a different order.

    The assignment must contain exactly one key with ``'category-options'`` in its name
    (asserted). Its value is reshuffled until the order differs from the original one
    and is then stored as a tuple.

    The order comparison is done on tuples, so list-valued choices are also guaranteed
    to change order. When no different order exists (fewer than two choices, or all
    choices equal) the choices are copied unchanged instead of looping forever.
    """
    keys = [key for key in assignment.keys() if 'category-options' in key]
    assert len(keys) == 1, f'expected exactly one category-options key, found {len(keys)}'
    key = keys[0]
    choices = assignment[key]
    original_order = tuple(choices)
    new_choices = deepcopy(choices)
    # A different ordering exists only when at least two choices differ.
    if any(choice != original_order[0] for choice in original_order[1:]):
        while tuple(new_choices) == original_order:
            new_choices = tuple(random.sample(new_choices, k=len(new_choices)))
    new_assignment = deepcopy(assignment)
    new_assignment[key] = new_choices
    return new_assignment



def generate_antonym_questions(questions, scenes, generator):
    """
    Derive antonym variants of attribute questions.

    Each question must carry exactly one ``attrs1`` attribute that has an entry in
    `engine_utils.train_attribute_antonyms` (asserted). The attribute is replaced by
    its antonym, the answer is recomputed through the generator's handler and must
    differ from the original label (asserted). Questions whose ``obj1_id`` scene object
    already lists the antonym among its attributes are skipped.

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-a``.
    """
    produced = []
    for base in questions:
        image = base['img_id']
        scene = scenes[image]
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        old_label = base['label']
        flipped = deepcopy(base['assignment'])
        current_attrs = flipped['attrs1']
        assert len(current_attrs) == 1 and current_attrs[0] in engine_utils.train_attribute_antonyms
        opposite = engine_utils.train_attribute_antonyms[current_attrs[0]]
        target_id = flipped['obj1_id']
        if opposite in scene['objects'][target_id]['attributes']:
            # the object carries both attributes; the flipped question is not usable
            continue
        flipped['attrs1'] = (opposite,)
        answer = generator.generator.handler.get_answer(scene, generator.program, flipped)
        assert answer not in old_label, AssertionError('Expected behavior is for the answer to change.')
        sentence = generator.generator.expand_text_template_multi(template, flipped)
        produced.append({
            'img_id': image,
            'label': {answer: 1.0},
            'question_id': f'{base_id}-a',
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': flipped
        })
    return produced


def generate_ontology_questions(questions, generator, token_group_numbers) -> list:
    """
    Derive ontology variants of questions while keeping their labels.

    Questions whose label contains ``'no'`` get a hyponymical assignment; all others
    (``yes`` or non-binary answers) get a hypernymical one. Questions for which no
    replacement is possible, or whose new sentence was already produced for the same
    image, are dropped.

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-h``.
    """
    results = []
    seen_per_image = defaultdict(set)
    for base in questions:
        image = base['img_id']
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        assignment = base['assignment']
        label = base['label']
        rewrite = get_hyponymical_assignment if 'no' in label else get_hypernymical_assignment
        new_assignment = rewrite(assignment, generator, token_group_numbers)
        if not new_assignment:
            continue
        sentence = generator.generator.expand_text_template_multi(template, new_assignment)
        seen = seen_per_image[image]
        if sentence in seen:
            continue
        seen.add(sentence)
        results.append({
            'img_id': image,
            'label': label,
            'question_id': f'{base_id}-h',
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': new_assignment
        })
    return results


def generate_hypernym_questions(questions, generator, token_group_numbers) -> object:
    """
    Derive hypernym variants of positively answered questions, keeping their labels.

    Every question's label must contain ``'yes'`` (asserted). Questions for which no
    hypernymical assignment exists, or whose new sentence was already produced for the
    same image, are dropped.

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-h``.
    """
    results = []
    seen_per_image = defaultdict(set)
    for base in questions:
        image = base['img_id']
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        assignment = base['assignment']
        label = base['label']
        assert 'yes' in label
        new_assignment = get_hypernymical_assignment(assignment, generator, token_group_numbers)
        if not new_assignment:
            continue
        sentence = generator.generator.expand_text_template_multi(template, new_assignment)
        seen = seen_per_image[image]
        if sentence in seen:
            continue
        seen.add(sentence)
        results.append({
            'img_id': image,
            'label': label,
            'question_id': f'{base_id}-h',
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': new_assignment
        })
    return results


def generate_hyponym_questions(questions, generator, token_group_numbers) -> object:
    """
    Derive hyponym variants of negatively answered questions, keeping their labels.

    Every question's label must contain ``'no'`` (asserted). Questions for which no
    hyponymical assignment exists, or whose new sentence was already produced for the
    same image, are dropped.

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-h``.
    """
    results = []
    seen_per_image = defaultdict(set)
    for base in questions:
        image = base['img_id']
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        assignment = base['assignment']
        label = base['label']
        assert 'no' in label
        new_assignment = get_hyponymical_assignment(assignment, generator, token_group_numbers)
        if not new_assignment:
            continue
        sentence = generator.generator.expand_text_template_multi(template, new_assignment)
        seen = seen_per_image[image]
        if sentence in seen:
            continue
        seen.add(sentence)
        results.append({
            'img_id': image,
            'label': label,
            'question_id': f'{base_id}-h',
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': new_assignment
        })
    return results


def generate_negative_sampling(
        base_questions,
        generator,
        scenes,
        token_group_numbers,
        sampling_func=get_negatively_sampled_assignment,
        **kwargs):
    """
    Turn base questions into negatively sampled questions.

    For each question, `sampling_func` is called with the assignment, the set of object
    names of the question's scene, `token_group_numbers` and any extra `kwargs`. When it
    returns an assignment, the sentence is re-expanded and the answer recomputed through
    the generator's handler. Sentences already produced for the same image are skipped.

    Returns a list of new question dicts that keep the original ``question_id``.
    """
    sampled = list()
    emitted_per_image = defaultdict(set)
    for base in tqdm(base_questions):
        image = base['img_id']
        scene = scenes[image]
        scene_objects = {obj['name'] for obj in scene['objects'].values()}
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        new_assignment = sampling_func(base['assignment'], scene_objects, token_group_numbers, **kwargs)
        if new_assignment is None:
            continue
        sentence = generator.generator.expand_text_template_multi(template, new_assignment)
        emitted = emitted_per_image[image]
        if sentence in emitted:
            continue
        emitted.add(sentence)
        answer = generator.generator.handler.get_answer(scene, generator.program, new_assignment)
        sampled.append({
            'img_id': image,
            'label': {answer: 1.0},
            'question_id': base_id,
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': new_assignment
        })
    return sampled


def generate_perturbation_data(questions, scenes, token_group_numbers):
    """
    Select (and if needed patch) the questions usable for perturbation.

    A question whose assignment has None object ids is patched through
    `get_phony_assignment`; it is dropped when patching fails or still leaves None ids.
    Any question referring to an object id outside `get_valid_objects(scene)` is
    dropped as well.

    Returns deep copies of the retained questions.
    """
    kept = list()
    for base in questions:
        scene = scenes[base['img_id']]
        usable_ids = get_valid_objects(scene)
        assignment = base['assignment']
        known_ids = get_object_token_ids(assignment) - {None}
        candidate = deepcopy(base)
        if contains_none_assignments(assignment):
            patched = get_phony_assignment(assignment, scene, known_ids, token_group_numbers)
            if patched is None or contains_none_assignments(patched):
                continue
            candidate['assignment'] = patched
        # also reject all assignments with non-compliant objects (too big, too small, etc.)
        if get_object_token_ids(candidate['assignment']).issubset(usable_ids):
            kept.append(candidate)
    return kept


def generate_template_data(questions, generator):
    """
    Re-express every question with a different template of the same generator.

    The question id is expected to end in ``-<assignment index>-<template index>``;
    everything before that is used as the image id (so image ids may contain dashes).
    A template index other than the original one is drawn at random and the sentence
    is expanded from a deep copy of the original assignment. The label is kept.

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-t`` and which record the chosen index as ``new_template_number``.

    Raises:
        ValueError: if the id does not have the expected shape, or if the generator
            offers no template other than the one already used.
    """
    template_count = len(generator.templates)
    new_questions = list()
    for question in questions:
        label = question['label']
        new_assignment = deepcopy(question['assignment'])
        q_id = question['question_id']
        img_id, _, t_ix = q_id.rsplit('-', 2)
        # random.sample() rejects sets since Python 3.11, so draw from a sorted list.
        alternatives = [ix for ix in range(template_count) if ix != int(t_ix)]
        if not alternatives:
            raise ValueError(f'No alternative template available for question {q_id!r}.')
        new_t_ix = random.sample(alternatives, k=1)[0]
        template = generator.templates[new_t_ix]
        sent = generator.generator.expand_text_template_multi(template, new_assignment)
        new_questions.append({
            'img_id': img_id,
            'label': label,
            'question_id': f'{q_id}-t',
            'sent': sent,
            'question_type': generator.question_type,
            'assignment': new_assignment,
            'new_template_number': new_t_ix,
        })
    return new_questions


def generate_negation_data(questions, generator, negation_generator, scenes):
    """
    Produce the negated counterpart of every question.

    Both generators must offer the same number of templates (asserted). The question
    id is split on ``-`` into image id, assignment index and template index; the
    negation template with the same index is expanded with the original assignment and
    the answer is recomputed with the negation generator's program. The new answer must
    differ from the original label (asserted).

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-n``. The assignment object is shared with the source question.
    """
    assert len(generator.templates) == len(negation_generator.templates)
    negated = []
    expand = generator.generator.expand_text_template_multi if questions else None
    for base in questions:
        old_label = base['label']
        base_id = base['question_id']
        assignment = base['assignment']
        image, _, template_index = base_id.split('-')
        scene = scenes[image]
        negation_template = negation_generator.templates[int(template_index)]
        sentence = expand(negation_template, assignment)
        answer = negation_generator.generator.handler.get_answer(scene, negation_generator.program, assignment)
        assert answer not in old_label, AssertionError('Expected behavior is for the answer to change.')
        negated.append({
            'img_id': image,
            'label': {answer: 1.0},
            'question_id': f'{base_id}-n',
            'sent': sentence,
            'question_type': negation_generator.question_type,
            'assignment': assignment,
        })
    return negated


def generate_symmetric_choices_data(questions, generator, scenes):
    """
    Produce, for every choice question, a variant with the choice options reordered.

    The options are reordered with `swap_choices_order`, the sentence is re-expanded
    with the question's own template and the answer is recomputed through the
    generator's handler. The recomputed answer must be one of the original labels and
    the new sentence must differ from the original one (both asserted).

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-s``.
    """
    symmetric_questions = list()
    for question in tqdm(questions):
        img_id = question['img_id']
        scene = scenes[img_id]
        q_id = question['question_id']
        t_ix = int(q_id.split('-')[-1])
        template = generator.templates[t_ix]
        assignment = question['assignment']
        new_assignment = swap_choices_order(assignment)
        sent = generator.generator.expand_text_template_multi(template, new_assignment)
        old_sent = generator.generator.expand_text_template_multi(template, assignment)
        answer = generator.generator.handler.get_answer(scene, generator.program, new_assignment)
        assert answer in question['label']
        # Reordering the options must be visible in the generated sentence.
        assert sent != old_sent
        symmetric_questions.append({
            'img_id': img_id,
            'label': {answer: 1.0},
            'question_id': f'{q_id}-s',
            'sent': sent,
            'question_type': generator.question_type,
            'assignment': new_assignment
        })
    return symmetric_questions


def generate_symmetric_data(questions, generator, scenes):
    """
    Produce, for every question, the variant with token groups 1 and 2 swapped.

    The assignment is rewritten with `get_swapped_assignment(assignment, 1, 2)`, the
    sentence is re-expanded with the question's own template and the answer is
    recomputed through the generator's handler. The recomputed answer must be one of
    the original labels (asserted).

    Returns a list of new question dicts whose ``question_id`` is the original id
    suffixed with ``-s``.
    """
    mirrored = []
    for base in tqdm(questions):
        image = base['img_id']
        scene = scenes[image]
        base_id = base['question_id']
        template = generator.templates[int(base_id.split('-')[-1])]
        swapped = get_swapped_assignment(base['assignment'], 1, 2)
        sentence = generator.generator.expand_text_template_multi(template, swapped)
        answer = generator.generator.handler.get_answer(scene, generator.program, swapped)
        assert answer in base['label']
        mirrored.append({
            'img_id': image,
            'label': {answer: 1.0},
            'question_id': f'{base_id}-s',
            'sent': sentence,
            'question_type': generator.question_type,
            'assignment': swapped
        })
    return mirrored


def _split_into_half(inlist):
    """Split `inlist` at ``len(inlist) // 2`` and return both slices; for an odd
    length the second slice holds the extra element."""
    midpoint = len(inlist) // 2
    return inlist[:midpoint], inlist[midpoint:]


def _split_into_half_quarter_quarter(inlist):
    """Split `inlist` into its first half and the two halves of its second half,
    each cut made with `_split_into_half`-style integer division."""
    leading_half, trailing_half = _split_into_half(inlist)
    cut = len(trailing_half) // 2
    return leading_half, trailing_half[:cut], trailing_half[cut:]


def entropy1(labels, base=None):
    """Return the entropy of the label distribution in `labels`, computed by
    `scipy.stats.entropy` from the per-label occurrence counts (`base` is passed
    through to scipy)."""
    import numpy as np
    from scipy.stats import entropy
    occurrence_counts = np.unique(labels, return_counts=True)[1]
    return entropy(occurrence_counts, base=base)


def get_label_func(question):
    """Return the first key of the question's ``label`` mapping."""
    label_names = list(question['label'])
    return label_names[0]


def get_2rel1_func(question):
    """Return the value stored under the ``'2rel1'`` key of the question's assignment."""
    assignment = question['assignment']
    return assignment['2rel1']


def _balance_dataset_on_func(questions, total_number, func):
    """
    Reduce `questions` to `total_number` items using weights derived from `func`.

    When there are no more than `total_number` questions the input list itself is
    returned. Otherwise `func` is applied to every question, the resulting values are
    counted, weights are obtained from `engine_utils.smoothed_weights` and the first
    `total_number` items of `engine_utils.weighted_shuffle` are returned.
    """
    if total_number >= len(questions):
        return questions
    keys = [func(item) for item in questions]
    frequency = Counter(keys)
    # plus chosen heuristically
    smoothing = round(sum(frequency.values()) / len(frequency))
    weights = engine_utils.smoothed_weights(keys, frequency, plus=smoothing)
    reordered = engine_utils.weighted_shuffle(questions, weights)
    return reordered[:total_number]


class DatasetsProgram:
    def __init__(self, scene_graph_file, output_dir):
        # returns names of own functions
        self._output_dir = output_dir
        self._scene_graph_file = scene_graph_file
        self.valid_datasets = set(filter(lambda x: not x.startswith('_'), dir(self)))
        logging.info(f'Found valid datasets: {self.valid_datasets}')

    def _save_data(self, questions, datset_name, suffix):
        filename = Path(self._output_dir, f'{datset_name}_{suffix}.json')
        logging.info(f'\tSaving {len(questions)} questions to {filename.as_posix()} .')
        logging.info(f'\twith {Counter([list(k["label"].keys())[0] for k in questions]).most_common()}.')
        random.shuffle(questions)
        json.dump(questions, filename.open('w+'))

    def _safe_save(self, questions, datset_name, suffix):
        """
        Shuffle `questions` in place and write them as JSON to a numbered path.

        The base path ``<output_dir>/<datset_name>_<suffix>.json`` is passed through
        `script_utils.get_numbered_path` and the returned path is written. The number
        of questions and the frequency of each question's first label key are logged
        before writing.
        """
        filename = Path(self._output_dir, f'{datset_name}_{suffix}.json')
        filename = script_utils.get_numbered_path(filename)
        label_summary = Counter([list(k["label"].keys())[0] for k in questions]).most_common()
        logging.info(f'\tSaving {len(questions)} questions to {filename.as_posix()} .')
        logging.info(f'\twith {label_summary}.')
        random.shuffle(questions)
        # Close the handle deterministically so the data is flushed to disk.
        with filename.open('w') as handle:
            json.dump(questions, handle)

    def attribute_verification(self):
        question_family_file = Path('question_families/attribute_verification.json')
        self._attribute_verification_subgenerator('attribute_verification', question_family_file, 2500)

    def attribute_verification_rel(self):
        question_family_file = Path('question_families/attribute_verification_rel1.json')
        self._attribute_verification_subgenerator('attribute_verification_rel', question_family_file, 2500, balance_rel=True)

    def _attribute_verification_subgenerator(self, dataset_base_name, question_family_file, cutoff, balance_rel=False):
        """
        Generate and save the base and antonym-test datasets of an attribute family.

        Steps:
          1. generate precursor questions for `question_family_file`;
          2. keep those with exactly one ``attrs1`` attribute that has an entry in
             `engine_utils.train_attribute_antonyms`, and shuffle them;
          3. build antonym questions and reduce them to at most `cutoff` items, either
             balanced on the ``2rel1`` value (`balance_rel=True`) or by random sampling;
          4. keep only the base questions whose antonym survived and save them with the
             ``base_questions`` suffix;
          5. save antonym plus base questions with the ``antonym_test`` suffix.

        `question_family_file` and the scene-graph file must offer ``.open()``
        (e.g. `pathlib.Path`).
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Read the JSON inputs through context managers so the handles get closed.
        with question_family_file.open() as family_handle:
            question_family = json.load(family_handle)
        with self._scene_graph_file.open() as scenes_handle:
            scenes = json.load(scenes_handle)
        generator = QuestionFamily(scenes, question_family)
        base_questions = [
            q for q in base_questions
            if len(q['assignment']['attrs1']) == 1
            and q['assignment']['attrs1'][0] in engine_utils.train_attribute_antonyms
        ]

        random.shuffle(base_questions)
        antonym_questions = generate_antonym_questions(base_questions, scenes, generator)
        if balance_rel:
            antonym_questions = _balance_dataset_on_func(antonym_questions, total_number=cutoff, func=get_2rel1_func)
        else:
            antonym_questions = random.sample(antonym_questions, k=min(cutoff, len(antonym_questions)))  # cutoff antonyms
        # Antonym ids are '<base id>-a'; a set keeps the membership test below O(1).
        kept_base_ids = {q['question_id'].rsplit('-', 1)[0] for q in antonym_questions}
        base_questions = [q for q in base_questions if q['question_id'] in kept_base_ids]
        self._save_data(base_questions, dataset_base_name, 'base_questions')
        antonym_questions += base_questions
        self._save_data(antonym_questions, dataset_base_name, 'antonym_test')

    def object_verification_rel(self):
        question_family_file = Path('question_families/object_verification_1_attrs_rel1.json')
        negation_question_family_file = Path('question_families/object_verification_1_attrs_rel1_negation.json')
        self._object_verification_subgenerator('object_verification_rel', question_family_file, negation_question_family_file, 2500, balance_rel=True)

    def object_verification(self):
        question_family_file = Path('question_families/object_verification_1_attrs.json')
        negation_question_family_file = Path('question_families/object_verification_1_attrs_negation.json')
        self._object_verification_subgenerator('object_verification', question_family_file, negation_question_family_file, 2500)

    def _object_verification_subgenerator(self, dataset_base_name, question_family_file, negation_question_family_file, cutoff, balance_rel=False):
        """
        Generate and save the base, template-test and negation-test datasets of an
        object-verification family.

        Steps:
          1. generate precursor questions for `question_family_file` and shuffle them;
          2. keep the first half as is and negatively sample the second half (token
             group 1);
          3. reduce the union to at most `cutoff` items, either balanced on the
             ``2rel1`` value (`balance_rel=True`) or by random sampling, and save it
             with the ``base_questions`` suffix;
          4. save base plus re-templated questions with the ``template_test`` suffix;
          5. save negated plus base questions with the ``negation_test`` suffix.

        The family files and the scene-graph file must offer ``.open()``
        (e.g. `pathlib.Path`).
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Read the JSON inputs through context managers so the handles get closed.
        with question_family_file.open() as family_handle:
            question_family = json.load(family_handle)
        with self._scene_graph_file.open() as scenes_handle:
            scenes = json.load(scenes_handle)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)
        positive_questions, negative_questions = _split_into_half(base_questions)
        negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [1])
        base_questions = positive_questions + negative_questions
        if balance_rel:
            base_questions = _balance_dataset_on_func(base_questions, total_number=cutoff, func=get_2rel1_func)
        else:
            base_questions = random.sample(base_questions, k=min(cutoff, len(base_questions)))
        self._save_data(base_questions, dataset_base_name, 'base_questions')

        template_questions = base_questions + generate_template_data(base_questions, generator)
        self._save_data(template_questions, dataset_base_name, 'template_test')

        with negation_question_family_file.open() as negation_handle:
            negation_question_family = json.load(negation_handle)
        negation_generator = QuestionFamily(scenes, negation_question_family)
        negation_questions = generate_negation_data(base_questions, generator, negation_generator, scenes)
        negation_questions += base_questions
        self._save_data(negation_questions, dataset_base_name, 'negation_test')

    def attribute_choice_rel(self):
        question_family_file = Path('question_families/attribute_rel1_choice.json')
        self._choice_subgenerator('attribute_choice_rel', question_family_file, 2250, balance_rel=True)

    def attribute_choice(self):
        question_family_file = Path('question_families/attribute_choice.json')
        self._choice_subgenerator('attribute_choice', question_family_file, 2250)

    def action_choice(self):
        question_family_file = Path('question_families/action_choice.json')
        self._choice_subgenerator('action_choice', question_family_file, 250)

    def action_choice_rel(self):
        question_family_file = Path('question_families/action_rel1_choice.json')
        self._choice_subgenerator('action_choice_rel', question_family_file, 250, balance_rel=True)

    def object_choice_rel(self):
        question_family_file = Path('question_families/object_rel1_choice.json')
        self._choice_subgenerator('object_choice_rel', question_family_file, 2500, balance_rel=True)

    # def object_choice_rel_strict(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_rel1_choice_strict_counts.json')
    #     self._choice_subgenerator(dataset_base_name, question_family_file, 2500, balance_rel=True)

    def object_choice_by_attr(self):
        question_family_file = Path('question_families/object_by_attr_choice.json')
        self._choice_subgenerator('object_choice_by_attr', question_family_file, 2500)

    def _choice_subgenerator(self, dataset_base_name, question_family_file, cutoff, balance_rel=False):
        """
        Generate and save the base, symmetric-test and template-test datasets of a
        choice family.

        Steps:
          1. generate precursor questions for `question_family_file`;
          2. with `balance_rel=True`, first reduce them to half their number, balanced
             on the ``2rel1`` value;
          3. reduce them to at most `cutoff` items balanced on the label, shuffle, and
             save with the ``base_questions`` suffix;
          4. save base plus choice-reordered questions with the ``symmetric_test``
             suffix;
          5. save base plus re-templated questions with the ``template_test`` suffix.

        `question_family_file` and the scene-graph file must offer ``.open()``
        (e.g. `pathlib.Path`).
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Read the JSON inputs through context managers so the handles get closed.
        with self._scene_graph_file.open() as scenes_handle:
            scenes = json.load(scenes_handle)
        with question_family_file.open() as family_handle:
            question_family = json.load(family_handle)
        generator = QuestionFamily(scenes, question_family)
        if balance_rel:
            base_questions = _balance_dataset_on_func(base_questions, total_number=len(base_questions) // 2, func=get_2rel1_func)
        base_questions = _balance_dataset_on_func(base_questions, total_number=cutoff, func=get_label_func)
        # NOTE(review): the second shuffle is redundant, but dropping it would change
        # the random stream and therefore the generated datasets; kept on purpose.
        random.shuffle(base_questions)
        random.shuffle(base_questions)
        self._save_data(base_questions, dataset_base_name, 'base_questions')

        symmetric_questions_1 = generate_symmetric_choices_data(base_questions, generator, scenes)
        symmetric_questions = base_questions + symmetric_questions_1
        random.shuffle(symmetric_questions)
        self._save_data(symmetric_questions, dataset_base_name, 'symmetric_test')

        template_questions_1 = generate_template_data(base_questions, generator)
        template_questions = base_questions + template_questions_1
        random.shuffle(template_questions)
        self._save_data(template_questions, dataset_base_name, 'template_test')

    def object_choice_rel_ontology(self):
        question_family_file = Path('question_families/object_rel1_choice.json')
        self._choice_ontology_subgenerator('object_choice_rel_ontology', question_family_file, cutoff=2500, balance_rel=True)

    def attribute_choice_rel_ontology(self):
        question_family_file = Path('question_families/attribute_rel1_choice.json')
        self._choice_ontology_subgenerator('attribute_choice_rel_ontology', question_family_file, 2250, balance_rel=True)

    def action_choice_rel_ontology(self):
        question_family_file = Path('question_families/action_rel1_choice.json')
        self._choice_ontology_subgenerator('action_choice_rel_ontology', question_family_file, 250, balance_rel=True)

    def _choice_ontology_subgenerator(self, dataset_base_name, question_family_file, cutoff, balance_rel):
        """
        Build a choice dataset together with its ontology variants and hand them to
        ``self._save_data`` as 'base_questions' and 'ontology_test' (base + ontology questions).

        Only base questions that produced at least one retained ontology question are saved.

        :param dataset_base_name: dataset name forwarded to ``self._save_data``.
        :param question_family_file: Path of the question family json file.
        :param cutoff: maximum number of ontology questions to keep.
        :param balance_rel: when True, the questions are first balanced on ``get_2rel1_func``.
        """
        from engine.constraint_handler import ConstraintHandler
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True,
        )
        # Context managers close both json files as soon as they have been parsed.
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        generator = QuestionFamily(scenes, question_family)
        if balance_rel:
            base_questions = _balance_dataset_on_func(base_questions, total_number=len(base_questions) // 2,
                                                      func=get_2rel1_func)
        base_questions = _balance_dataset_on_func(base_questions, total_number=len(base_questions) // 2,
                                                  func=get_label_func)
        random.shuffle(base_questions)
        ontology_questions = generate_ontology_questions(base_questions, generator, [2])
        if dataset_base_name.endswith('_rel'):
            # NOTE(review): the callers visible in this file all pass names ending in '_ontology',
            # so this constraint filter does not appear to run for them -- confirm the intent.
            ontology_questions = [
                q for q in ontology_questions
                if ConstraintHandler(scenes[q['img_id']]['objects']).check_constraints(generator.constraints,
                                                                                       q['assignment'])
            ]
        cutoff = min(cutoff, len(ontology_questions))
        ontology_questions = ontology_questions[:cutoff]
        # Dropping everything after the last '-' of an ontology question id is used here as the
        # id of the base question it belongs to.
        valid_base_question_ids = {q['question_id'].rsplit('-', 1)[0] for q in ontology_questions}
        base_questions = [q for q in base_questions if q['question_id'] in valid_base_question_ids]
        self._save_data(base_questions, dataset_base_name, 'base_questions')
        ontology_questions = base_questions + ontology_questions
        self._save_data(ontology_questions, dataset_base_name, 'ontology_test')

    # Disable to prevent new dataset being generated.
    def object_choice_rel_perturbation(self):
        question_family_file = Path('question_families/object_rel1_choice.json')
        self._choice_perturbation_subgenerator('object_choice_rel_perturbation', question_family_file, 2500, balance_rel=False)

    def object_choice_by_attr_perturbation(self):
        question_family_file = Path('question_families/object_by_attr_choice.json')
        self._choice_perturbation_subgenerator('object_choice_by_attr_perturbation', question_family_file, 2500, balance_rel=False)

    def attribute_choice_perturbation(self):
        question_family_file = Path('question_families/attribute_choice.json')
        self._choice_perturbation_subgenerator('attribute_choice_perturbation', question_family_file, 2250, balance_rel=False)

    def action_choice_rel_perturbation(self):
        """
        Generate the 'action_choice_rel_perturbation' dataset by delegating to
        ``_choice_perturbation_subgenerator`` (cutoff 250, no relation balancing).

        NOTE(review): the precursor data generated and shuffled in this method is never used
        afterwards; the sub-generator generates its own. It is kept because the extra shuffle
        may advance the global RNG state and thereby influence the saved dataset -- confirm
        before removing it.
        """
        family_path = Path('question_families/action_rel1_choice.json')
        precursor_options = dict(
            question_family=family_path,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=inf,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True,
        )
        discarded_questions = generate_precursor_data(**precursor_options)
        random.shuffle(discarded_questions)
        self._choice_perturbation_subgenerator(
            dataset_base_name='action_choice_rel_perturbation',
            question_family_file=family_path,
            cutoff=250,
            balance_rel=False,
        )

    def action_choice_perturbation(self):
        question_family_file = Path('question_families/action_choice.json')
        self._choice_perturbation_subgenerator('action_choice_perturbation', question_family_file, 250, balance_rel=False)

    def attribute_choice_rel_perturbation(self):
        question_family_file = Path('question_families/attribute_rel1_choice.json')
        self._choice_perturbation_subgenerator('attribute_choice_rel_perturbation', question_family_file, 2250, balance_rel=True)

    def _choice_perturbation_subgenerator(self, dataset_base_name, question_family_file, cutoff, balance_rel=False):
        """
        Generate, balance and shuffle the base questions of a choice perturbation dataset and
        hand them to ``self._save_data`` as 'base_questions'.

        The scene graph and question family json files used to be loaded here as well, through
        file handles that were never closed, but the parsed values were never used; those dead
        loads have been removed.

        :param dataset_base_name: dataset name forwarded to ``self._save_data``.
        :param question_family_file: Path of the question family json file.
        :param cutoff: ``total_number`` used when balancing the questions on ``get_label_func``.
        :param balance_rel: when True, the questions are first balanced on ``get_2rel1_func``
            with a budget of half the generated questions.
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=inf,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        if balance_rel:
            base_questions = _balance_dataset_on_func(base_questions, total_number=len(base_questions) // 2,
                                                      func=get_2rel1_func)
        base_questions = _balance_dataset_on_func(base_questions, total_number=cutoff, func=get_label_func)
        # The second shuffle is redundant, but it is kept on purpose: removing it would change
        # the order in which the questions are saved.
        random.shuffle(base_questions)
        random.shuffle(base_questions)
        self._save_data(base_questions, dataset_base_name, 'base_questions')

    # Disable to prevent new dataset being generated.
    def object_verification_perturbation(self):
        """
        Generate the 'object_verification_perturbation' dataset.

        The shuffled precursor questions are split in two halves; the second half is passed
        through ``generate_negative_sampling``. Perturbation data is generated for the union,
        balanced with ``script_utils.get_balanced_questions`` and handed to ``self._safe_save``
        as 'base_questions'.
        """
        question_family_file = Path('question_families/object_verification_1_attrs.json')
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=inf,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close both json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)
        positive_questions, negative_questions = _split_into_half(base_questions)
        negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [1])
        base_questions = positive_questions + negative_questions
        perturbation_questions = generate_perturbation_data(base_questions, scenes, [1])
        perturbation_questions = script_utils.get_balanced_questions(perturbation_questions)
        # perturbation_questions = random.sample(perturbation_questions, k=min(24000, len(perturbation_questions)))
        self._safe_save(perturbation_questions, 'object_verification_perturbation', 'base_questions')

    def object_verification_ontology(self):
        """
        Generate the 'object_verification_ontology' dataset.

        Three quarters of the shuffled precursor questions are kept unchanged and the remaining
        quarter is negatively sampled from ``engine_utils.minimum_hyponyms``. Questions
        passing ``contains_valid_ontology_assignment`` are expanded with
        ``generate_ontology_questions``, of which at most 5000 are sampled. The base questions
        that still have a sampled ontology question are handed to ``self._save_data`` as
        'base_questions', and base + ontology questions as 'ontology_test'.
        """
        # raise NotImplementedError('Ontological variants not implemented yet.')
        question_family_file = Path('question_families/object_verification_1_attrs.json')
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close both json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)
        positive_questions, positive_questions_1, negative_questions = _split_into_half_quarter_quarter(base_questions)
        positive_questions += positive_questions_1
        negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [1],
                                                        sampling_func=get_negatively_sampled_assignment_from_subset,
                                                        subset=engine_utils.minimum_hyponyms)
        base_questions = positive_questions + negative_questions
        base_questions = [q for q in base_questions if contains_valid_ontology_assignment(q)]
        ontology_questions = generate_ontology_questions(base_questions, generator, [1])
        ontology_questions = random.sample(ontology_questions, k=min(len(ontology_questions), 5000))
        # Dropping everything after the last '-' of an ontology question id is used here as the
        # id of the base question it belongs to.
        valid_base_question_ids = {q['question_id'].rsplit('-', 1)[0] for q in ontology_questions}
        base_questions = [q for q in base_questions if q['question_id'] in valid_base_question_ids]
        ontology_questions = base_questions + ontology_questions
        self._save_data(base_questions, 'object_verification_ontology', 'base_questions')
        self._save_data(ontology_questions, 'object_verification_ontology', 'ontology_test')

    # def object_verification_hypernym(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_verification_1_attrs.json')
    #     base_questions = generate_precursor_data(
    #         question_family=question_family_file,
    #         scene_graphs=[self._scene_graph_file],
    #         seed=666,
    #         max_assignments_per_scene=4,
    #         max_templates_per_assignment=1,
    #         negative_sampling_method=None,
    #         balance_answers=False,
    #         include_ids=True
    #     )
    #     question_family = json.load(question_family_file.open())
    #     scenes = json.load(self._scene_graph_file.open())
    #     generator = QuestionFamily(scenes, question_family)
    #     base_questions = list(filter(contains_hyponym_assignment, base_questions))
    #     hypernym_questions = generate_hypernym_questions(base_questions, generator, [1])
    #     valid_base_question_ids = set([q['question_id'].rsplit('-', 1)[0] for q in hypernym_questions])
    #     base_questions = list(filter(lambda x: x['question_id'] in valid_base_question_ids, base_questions))
    #     test_questions = base_questions + hypernym_questions
    #     self._save_data(test_questions, dataset_base_name, 'test')
    #
    # def object_verification_hyponym(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_verification_1_attrs.json')
    #     base_questions = generate_precursor_data(
    #         question_family=question_family_file,
    #         scene_graphs=[self._scene_graph_file],
    #         seed=666,
    #         max_assignments_per_scene=1,
    #         max_templates_per_assignment=1,
    #         negative_sampling_method=None,
    #         balance_answers=False,
    #         include_ids=True
    #     )
    #     question_family = json.load(question_family_file.open())
    #     scenes = json.load(self._scene_graph_file.open())
    #     generator = QuestionFamily(scenes, question_family)
    #     base_questions = generate_negative_sampling(base_questions, generator, scenes, [1],
    #                                                 sampling_func=get_negatively_sampled_assignment_from_subset,
    #                                                 subset=engine_utils.minimum_hyponyms)
    #     base_questions = list(filter(contains_hypernym_assignment, base_questions))
    #     hyponym_questions = generate_hyponym_questions(base_questions, generator, [1])
    #     valid_base_question_ids = set([q['question_id'].rsplit('-', 1)[0] for q in hyponym_questions])
    #     base_questions = list(filter(lambda x: x['question_id'] in valid_base_question_ids, base_questions))
    #     test_questions = base_questions + hyponym_questions
    #     self._save_data(test_questions, dataset_base_name, 'test')

    def object_verification_conjunction(self):
        question_family_file = Path('question_families/object_verification_conjunction_1_attrs.json')
        negation_question_family_file = Path('question_families/object_verification_conjunction_1_attrs_negation.json')
        self._object_verification_conjunction_subgenerator('object_verification_conjunction', question_family_file, negation_question_family_file, 2500)

    # def object_verification_conjunction_rel(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_verification_conjunction_1_attrs_rel1_both.json')
    #     negation_question_family_file = Path('question_families/object_verification_conjunction_1_attrs_rel_1_both_negation.json')
    #     self._object_verification_conjunction_subgenerator(dataset_base_name, question_family_file, negation_question_family_file, 2500)

    def _object_verification_conjunction_subgenerator(self, dataset_base_name, question_family_file, negation_question_family_file, cutoff):
        """
        Build a conjunction verification dataset and hand four variants of it to
        ``self._save_data``: 'base_questions', 'template_test', 'symmetric_test' and
        'negation_test' (each test split also contains the base questions).

        Half of the precursor questions are kept unchanged; the rest is negatively sampled on
        token index 1, on token index 2, or on both.

        :param dataset_base_name: dataset name forwarded to ``self._save_data``.
        :param question_family_file: Path of the question family json file.
        :param negation_question_family_file: Path of the negated question family json file.
        :param cutoff: maximum number of base questions to sample.
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=3,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close the json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        # NOTE(review): unlike the disjunction sub-generator, the questions are not shuffled
        # before being split -- confirm this is intended.
        positive_questions, negative_left, negative_right = _split_into_half_quarter_quarter(base_questions)
        negative_1, negative_2 = _split_into_half(negative_left)
        negative_questions = generate_negative_sampling(negative_1, generator, scenes, [1])
        negative_questions += generate_negative_sampling(negative_2, generator, scenes, [2])
        negative_questions += generate_negative_sampling(negative_right, generator, scenes, [1, 2])
        base_questions = positive_questions + negative_questions
        base_questions = random.sample(base_questions, k=min(cutoff, len(base_questions)))
        self._save_data(base_questions, dataset_base_name, 'base_questions')

        alternative_template_questions = base_questions + generate_template_data(base_questions, generator)
        self._save_data(alternative_template_questions, dataset_base_name, 'template_test')

        symmetric_questions = base_questions + generate_symmetric_data(base_questions, generator, scenes)
        self._save_data(symmetric_questions, dataset_base_name, 'symmetric_test')

        with negation_question_family_file.open() as negation_fh:
            negation_question_family = json.load(negation_fh)
        negation_generator = QuestionFamily(scenes, negation_question_family)
        negation_questions = generate_negation_data(base_questions, generator, negation_generator, scenes)
        negation_questions += base_questions
        self._save_data(negation_questions, dataset_base_name, 'negation_test')

    # def object_verification_conjunction_partial_negation(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_verification_conjunction_1_attrs.json')
    #     base_questions = generate_precursor_data(
    #         question_family=question_family_file,
    #         scene_graphs=[self._scene_graph_file],
    #         seed=666,
    #         max_assignments_per_scene=2,
    #         max_templates_per_assignment=1,
    #         negative_sampling_method=None,
    #         balance_answers=False,
    #         include_ids=True
    #     )
    #     question_family = json.load(question_family_file.open())
    #     scenes = json.load(self._scene_graph_file.open())
    #     generator = QuestionFamily(scenes, question_family)
    #     random.shuffle(base_questions)
    #     positive_questions, negative_questions = _split_into_half(base_questions)
    #     negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [2])
    #     base_questions = positive_questions + negative_questions
    #     # base_questions = random.sample(base_questions, k=min(12000, len(base_questions)))
    #     self._save_data(base_questions, dataset_base_name, 'base_questions')
    #
    #     partial_negation_question_family_file = Path(
    #         'question_families/object_verification_conjunction_1_attrs_partial_negation.json')
    #     partial_negation_question_family = json.load(partial_negation_question_family_file.open())
    #     partial_negation_generator = QuestionFamily(scenes, partial_negation_question_family)
    #     partial_negation_questions = generate_negation_data(base_questions, generator, partial_negation_generator,
    #                                                         scenes)
    #     partial_negation_questions += base_questions
    #     self._save_data(partial_negation_questions, dataset_base_name, 'test')

    def object_verification_conjunction_perturbation(self):
        """
        Generate the 'object_verification_conjunction_perturbation' dataset.

        Half of the shuffled precursor questions are kept unchanged; the rest is negatively
        sampled on token index 1, on token index 2, or on both. Perturbation data is generated
        for the union, balanced with ``script_utils.get_balanced_questions`` and handed to
        ``self._safe_save`` as 'base_questions'.
        """
        question_family_file = Path('question_families/object_verification_conjunction_1_attrs.json')
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=inf,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close both json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)
        positive_questions, negative_left, negative_right = _split_into_half_quarter_quarter(base_questions)
        negative_1, negative_2 = _split_into_half(negative_left)
        negative_questions = generate_negative_sampling(negative_1, generator, scenes, [1])
        negative_questions += generate_negative_sampling(negative_2, generator, scenes, [2])
        negative_questions += generate_negative_sampling(negative_right, generator, scenes, [1, 2])
        base_questions = positive_questions + negative_questions
        perturbation_questions = generate_perturbation_data(base_questions, scenes, [1, 2])
        perturbation_questions = script_utils.get_balanced_questions(perturbation_questions)
        # perturbation_questions = random.sample(perturbation_questions, k=min(24000, len(perturbation_questions)))
        self._safe_save(perturbation_questions, 'object_verification_conjunction_perturbation', 'base_questions')

    def object_verification_disjunction(self):
        question_family_file = Path('question_families/object_verification_disjunction_1_attrs.json')
        negation_question_family_file = Path('question_families/object_verification_disjunction_1_attrs_negation.json')
        self._object_verification_disjunction_subgenerator('object_verification_disjunction', question_family_file, negation_question_family_file, 2500)

    # def object_verification_disjunction_rel(self, dataset_base_name):
    #     question_family_file = Path('question_families/object_verification_disjunction_1_attrs_rel1_both.json')
    #     negation_question_family_file = Path('question_families/object_verification_disjunction_1_attrs_rel1_both_negation.json')
    #     self._object_verification_disjunction_subgenerator(dataset_base_name, question_family_file, negation_question_family_file, 2500)

    def _object_verification_disjunction_subgenerator(self, dataset_base_name, question_family_file, negation_question_family_file, cutoff):
        """
        Build a disjunction verification dataset and hand four variants of it to
        ``self._save_data``: 'base_questions', 'template_test', 'symmetric_test' and
        'negation_test' (each test split also contains the base questions).

        Half of the shuffled precursor questions are negatively sampled on both token indices
        1 and 2. Of the other half, one part is kept unchanged and the rest is negatively
        sampled on a single token index; these are asserted to still carry the label 'yes'.

        :param dataset_base_name: dataset name forwarded to ``self._save_data``.
        :param question_family_file: Path of the question family json file.
        :param negation_question_family_file: Path of the negated question family json file.
        :param cutoff: maximum number of base questions to sample.
        """
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=2,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close the json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)

        negative_questions, positive_left, positive_questions = _split_into_half_quarter_quarter(base_questions)
        positive_1, positive_2 = _split_into_half(positive_left)
        positive_questions += generate_negative_sampling(positive_1, generator, scenes, [1])
        positive_questions += generate_negative_sampling(positive_2, generator, scenes, [2])
        assert set(map(lambda x: list(x['label'].keys())[0], positive_questions)) == {'yes'}
        negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [1, 2])
        base_questions = positive_questions + negative_questions
        base_questions = random.sample(base_questions, k=min(cutoff, len(base_questions)))
        self._save_data(base_questions, dataset_base_name, 'base_questions')

        alternative_template_questions = base_questions + generate_template_data(base_questions, generator)
        self._save_data(alternative_template_questions, dataset_base_name, 'template_test')

        symmetric_questions = base_questions + generate_symmetric_data(base_questions, generator, scenes)
        self._save_data(symmetric_questions, dataset_base_name, 'symmetric_test')

        with negation_question_family_file.open() as negation_fh:
            negation_question_family = json.load(negation_fh)
        negation_generator = QuestionFamily(scenes, negation_question_family)
        negation_questions = generate_negation_data(base_questions, generator, negation_generator, scenes)
        negation_questions += base_questions
        self._save_data(negation_questions, dataset_base_name, 'negation_test')

    def object_verification_disjunction_perturbation(self):
        """
        Generate the 'object_verification_disjunction_perturbation' dataset.

        Half of the shuffled precursor questions are negatively sampled on both token indices
        1 and 2. Of the other half, one part is kept unchanged and the rest is negatively
        sampled on a single token index; these are asserted to still carry the label 'yes'.
        Perturbation data is generated for the union, balanced with
        ``script_utils.get_balanced_questions`` and handed to ``self._safe_save`` as
        'base_questions'.
        """
        question_family_file = Path('question_families/object_verification_disjunction_1_attrs.json')
        base_questions = generate_precursor_data(
            question_family=question_family_file,
            scene_graphs=[self._scene_graph_file],
            seed=666,
            max_assignments_per_scene=inf,
            max_templates_per_assignment=1,
            negative_sampling_method=None,
            balance_answers=False,
            include_ids=True
        )
        # Context managers close both json files as soon as they have been parsed.
        with question_family_file.open() as family_fh:
            question_family = json.load(family_fh)
        with self._scene_graph_file.open() as scene_fh:
            scenes = json.load(scene_fh)
        generator = QuestionFamily(scenes, question_family)
        random.shuffle(base_questions)

        negative_questions, positive_left, positive_questions = _split_into_half_quarter_quarter(base_questions)
        positive_1, positive_2 = _split_into_half(positive_left)
        positive_questions += generate_negative_sampling(positive_1, generator, scenes, [1])
        positive_questions += generate_negative_sampling(positive_2, generator, scenes, [2])
        assert set(map(lambda x: list(x['label'].keys())[0], positive_questions)) == {'yes'}
        # Replace (do not extend) the negative half with its negatively sampled version, as
        # _object_verification_disjunction_subgenerator does. The previous `+=` also left the
        # original, un-sampled questions in the "negative" list.
        negative_questions = generate_negative_sampling(negative_questions, generator, scenes, [1, 2])
        base_questions = positive_questions + negative_questions
        perturbation_questions = generate_perturbation_data(base_questions, scenes, [1, 2])
        perturbation_questions = script_utils.get_balanced_questions(perturbation_questions)
        # perturbation_questions = random.sample(perturbation_questions, k=min(24000, len(perturbation_questions)))
        self._safe_save(perturbation_questions, 'object_verification_disjunction_perturbation', 'base_questions')

    def _run(self, func):
        try:
            logging.info('Starting generation for dataset: %s' % func)
            getattr(self, func)()
            return [1]
        except:
            logging.warning(f'Dataset {func} failed with exception:\n\t{traceback.format_exc()}\n')
            return 0


def clean(instr):
    """
    returns instr with every '-' replaced by '_'.
    """
    return instr.replace('-', '_')


def main(datasets, scene_graph_file, output_dir, run_all):
    """
    Generate the requested datasets one after another and log which succeeded or failed.

    :param datasets: iterable of dataset names; '-' is replaced by '_' in each name.
        Ignored when ``run_all`` is true.
    :param scene_graph_file: passed to ``DatasetsProgram``.
    :param output_dir: passed to ``DatasetsProgram``.
    :param run_all: when true, every name in ``program.valid_datasets`` is generated.
    :raises ValueError: if a requested name is not in ``program.valid_datasets``.
    """
    program = DatasetsProgram(scene_graph_file, output_dir)
    if run_all:
        # Copy into a set so the set operations below work whatever container is exposed.
        datasets = set(program.valid_datasets)
    else:
        datasets = set(map(clean, datasets))
    invalid_datasets = datasets.difference(program.valid_datasets)
    if invalid_datasets:
        for dataset in sorted(invalid_datasets):
            logging.error('Dataset with name %s is not valid. Aborting.' % dataset)
        raise ValueError('Found invalid dataset name in argument list.')
    datasets = sorted(datasets)
    failed_datasets = list()
    success_datasets = list()
    # Disabled parallel variant:
    # from multiprocessing import Pool, cpu_count
    # pool = Pool(processes=cpu_count())
    # output = pool.map(program._run, datasets)
    # print(output)
    # for ix, code in output:
    #     if code == 1:
    #         success_datasets.append(datasets[ix])
    #     else:
    #         failed_datasets.append(datasets[ix])
    for dataset in datasets:
        try:
            # Re-seed before every dataset so each one is generated from the same RNG state
            # regardless of which other datasets were requested.
            random.seed(42)
            logging.info('Starting generation for dataset: %s' % dataset)
            getattr(program, dataset)()
            success_datasets.append(dataset)
        except Exception:
            # A failing dataset must not stop the others, but KeyboardInterrupt and SystemExit
            # are deliberately not caught so the whole run can still be aborted.
            logging.warning(f'Dataset {dataset} failed with exception:\n\t{traceback.format_exc()}\n')
            failed_datasets.append(dataset)
    logging.info('\n'.join(wrap(f'Finished datasets: {", ".join(success_datasets)}',
                                width=100, subsequent_indent=' ' * 4)))
    if failed_datasets:
        logging.warning('\n'.join(wrap(f'Failed on datasets: {", ".join(failed_datasets)}',
                                       width=100, subsequent_indent=' ' * 4)))


if __name__ == '__main__':
    # Command-line entry point: parse the arguments, configure logging, validate the paths and
    # run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('datasets', type=str, nargs='*', help='List of datasets to generate.')
    # Both paths are mandatory: without them Path(None) below would fail with a TypeError.
    parser.add_argument('--scene-graphs', type=str, required=True, help='Path to scene graphs dict json file.')
    parser.add_argument('--output-dir', type=str, required=True, help='Path to output directory. Will be created if '
                                                                      'doesn\'t exist.')
    parser.add_argument('--all', action='store_true')

    args = parser.parse_args()

    log_format = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.DEBUG)

    scene_graph_file = Path(args.scene_graphs)
    if not scene_graph_file.is_file():
        raise FileNotFoundError('Scene Graphs file %s passed not found. Aborting.' % scene_graph_file.as_posix())
    output_dir = Path(args.output_dir)
    if not output_dir.is_dir():
        try:
            output_dir.mkdir(parents=True)
        except Exception:
            logging.error('Cannot create directory with path %s .', output_dir.as_posix())
            # Bare raise keeps the original traceback.
            raise

    main(args.datasets, scene_graph_file, output_dir, args.all)
