import argparse
import json
import logging
import pathlib
import pprint
import random
import re
from math import inf
from copy import deepcopy
from glob import glob
from collections import defaultdict, Counter

from tqdm import tqdm

import engine
import script_utils


class QuestionFamily:
    def __init__(
            self,
            scenes_dict: dict,
            question_family: dict,
    ):
        """Set up a question family over a collection of scenes.

        Args:
            scenes_dict: Mapping from image id to scene; stored as-is on the
                instance.
            question_family: Question family specification. The keys
                ``'program'``, ``'templates'``, ``'name'``, ``'tokens'`` and
                ``'constraints'`` are read; a missing key raises ``KeyError``.
        """
        self.scenes_dict = scenes_dict
        self.program = question_family['program']
        self.templates = list(question_family['templates'])
        self.question_type = question_family['name']
        self.generator = engine.QGenerator(self.question_type)
        self.tokens = question_family['tokens']

        # Collect the distinct values of capture group 4 of engine.utils.TOKEN
        # (presumably the numeric suffix of a token such as ``obj1`` -- confirm
        # against the pattern definition in engine.utils).
        #
        # A dict is used as an insertion-ordered set: converting a ``set`` of
        # strings to a list yields an order that changes between interpreter
        # runs (string hash randomisation), which would make every later
        # ``random.choice(self.token_numbers)`` irreproducible even with a
        # fixed random seed.
        ordered_numbers = dict()
        for token in self.tokens:
            match = re.search(engine.utils.TOKEN, token)
            if match:
                ordered_numbers[match.group(4)] = None
        self.token_numbers = list(ordered_numbers)
        self.constraints = question_family['constraints']

        # Log the family definition, indenting continuation lines by one tab.
        for title, value in (
                ('Constraints', self.constraints),
                ('Program', self.program),
                ('Templates', self.templates),
        ):
            logging.info(f'\t{title}:' + '\n\t' + pprint.pformat(value).replace('\n', '\n\t'))

    def generate_all_questions(
            self,
            negative_sampling_method: str,
            max_assignments_per_scene: int,
            max_templates_per_assignment: int,
            balance_answers: bool,
            include_ids: bool,
            min_objects: int = 7,
            max_attributes: int = 15,
    ):
        """Generate questions for every eligible scene in ``self.scenes_dict``.

        For each scene that passes the size filters, assignments are produced
        by ``self.generator.generate_assignments``, optionally extended with
        negatively sampled assignments, randomly down-sampled to at most
        ``max_assignments_per_scene`` and expanded into questions by
        ``generate_scene_questions``.

        Args:
            negative_sampling_method: Falsy to disable negative sampling,
                otherwise one of ``'swap-scenes'``,
                ``'negative-context-object'`` or ``'negative-color-token'``.
            max_assignments_per_scene: Upper bound on the assignments kept per
                scene; also forwarded to the assignment generator as
                ``max_num``.
            max_templates_per_assignment: Upper bound on the templates
                expanded per assignment.
            balance_answers: If true, keep only the questions whose ids are
                returned by ``script_utils.get_balanced_question_ids``.
            include_ids: Forwarded to the assignment generators.
            min_objects: Scenes with fewer objects than this are skipped.
            max_attributes: Scenes in which any object has more attributes
                than this are skipped.

        Returns:
            A list of question dicts.

        Raises:
            ValueError: If ``negative_sampling_method`` is truthy but not one
                of the known methods. The check runs per scene, so it only
                fires once a scene passes the filters.
        """
        logging.info('Starting question generation.')
        all_questions = dict()
        for img_id, scene in tqdm(list(self.scenes_dict.items())):
            objects = scene['objects']
            if len(objects) < min_objects:
                continue
            most_attributes = max((len(obj['attributes']) for obj in objects.values()), default=0)
            if most_attributes > max_attributes:
                continue

            all_assignments = self.generator.generate_assignments(scene, self.tokens, self.constraints,
                                                                  include_ids=include_ids,
                                                                  img_id=img_id,
                                                                  max_num=max_assignments_per_scene)
            if negative_sampling_method:
                # Materialise the negatives before extending so the sampler
                # never iterates a list that is growing underneath it.
                negatives = list(self._sample_negative_assignments(
                    method=negative_sampling_method,
                    all_assignments=all_assignments,
                    img_id=img_id,
                    include_ids=include_ids,
                    max_num=max_assignments_per_scene))
                all_assignments.extend(negatives)

            # sample max after full generation for greater diversity
            assignments = random.sample(all_assignments, k=min(len(all_assignments), max_assignments_per_scene))
            for q_id, question_dict in self.generate_scene_questions(scene, img_id, assignments,
                                                                     max_templates=max_templates_per_assignment):
                all_questions[q_id] = question_dict

        if balance_answers:
            logging.info('Balancing answers...')
            keep_ids = set(script_utils.get_balanced_question_ids(all_questions))
            all_questions = {q_id: question for q_id, question in all_questions.items() if q_id in keep_ids}

        questions = list(all_questions.values())
        logging.info(f'Generated {len(questions)} questions.')
        # Each label dict is keyed by the answer; its first key is the answer.
        most_common_answers = Counter([next(iter(q['label'])) for q in questions]).most_common()
        logging.info(f'Answers distribution: {most_common_answers}')
        return questions

    def _sample_negative_assignments(self, method, all_assignments, img_id, include_ids, max_num):
        """Run the negative-sampling routine named by ``method`` for one scene."""
        if method == 'swap-scenes':
            # NOTE(review): generate_swapped_assignments only exists as
            # commented-out code in this class, so unless it is provided
            # elsewhere this branch raises AttributeError.
            return self.generate_swapped_assignments(
                all_assignments=all_assignments,
                this_img_id=img_id,
                max_num=max_num,
                include_ids=include_ids)
        if method == 'negative-context-object':
            return self.generate_negative_context_object_assignments(
                all_assignments=all_assignments,
                img_id=img_id,
                include_ids=include_ids)
        if method == 'negative-color-token':
            return self.generate_negative_color_token_assignments(
                all_assignments=all_assignments,
                img_id=img_id,
                include_ids=include_ids)
        raise ValueError('negative_sampling_method %s unknown. Aborting' % method)

    # def generate_swapped_assignments(self, all_assignments, this_img_id, include_ids, max_num):
    #     new_assignments = list()
    #     non_id_assignments = list(map(lambda x: self.get_sub_dict(x, self.tokens), all_assignments))
    #     random_img_id = random.sample(self.scenes_dict.keys() - {this_img_id}, k=1)[0]
    #     for assignment in self.random_scene_assignments(random_img_id, include_ids=include_ids, max_num=max_num):
    #         if self.get_sub_dict(assignment, self.tokens) not in non_id_assignments:
    #             new_assignments.append(assignment)
    #     return new_assignments

    def generate_negative_context_object_assignments(self, all_assignments, img_id, include_ids):
        """Derive new assignments by replacing one object slot with another object name.

        Replacement names come from ``engine.utils.get_random_obj_name``, which
        is given the set of object names present in the scene and the number of
        assignments (presumably it returns names absent from the scene --
        confirm in engine.utils). Each name is paired with one existing
        assignment; the assignment is deep-copied and, for a randomly chosen
        entry of ``self.token_numbers``, its ``obj<N>`` value is overwritten.
        ``obj<N>_id`` is set to ``None`` when ``include_ids`` is true, and
        ``attrs<N>`` is redrawn via ``engine.utils.get_random_attrs`` when that
        token belongs to this family.

        Returns an empty list when ``all_assignments`` is empty; the input
        assignments themselves are never modified.
        """
        if len(all_assignments) == 0:
            return []

        present_names = set()
        for scene_object in self.scenes_dict[img_id]['objects'].values():
            present_names.add(scene_object['name'])
        replacement_names = engine.utils.get_random_obj_name(present_names, len(all_assignments))
        pair_count = min(len(replacement_names), len(all_assignments))

        negatives = []
        for replacement, source_assignment in zip(replacement_names, all_assignments[:pair_count]):
            negative = deepcopy(source_assignment)
            # Pick the slot first, then draw attributes, so the sequence of
            # random draws stays fixed for a given seed.
            number = random.choice(self.token_numbers)
            negative[f'obj{number}'] = replacement
            if include_ids:
                negative[f'obj{number}_id'] = None
            attrs_key = f'attrs{number}'
            if attrs_key in self.tokens:
                negative[attrs_key] = engine.utils.get_random_attrs(replacement)
            negatives.append(negative)
        return negatives

    # def random_scene_assignments(
    #         self,
    #         random_img_id,
    #         include_ids,
    #         max_num,
    # ):
    #     random_scene = self.scenes_dict[random_img_id]
    #     for assignment in self.generator.generate_assignments(random_scene, self.tokens, self.constraints,
    #                                                           include_ids=include_ids,
    #                                                           max_num=max_num):
    #         assignment['orig_img_id'] = random_img_id
    #         yield assignment

    @staticmethod
    def get_sub_dict(full_dict, keys):
        return {key: full_dict[key] for key in keys}

    def generate_scene_questions(
            self,
            scene,
            img_id,
            assignments,
            max_templates
    ):
        """Yield ``(question_id, question_dict)`` pairs for one scene.

        For every assignment the answer is obtained from
        ``self.generator.handler.get_answer`` and up to ``max_templates``
        distinct templates are drawn at random and expanded with
        ``self.generator.expand_text_template_multi``. Question ids have the
        form ``<img_id>-<assignment index>-<template index>``.

        All questions produced from one assignment share the same ``label``
        dict and the same ``assignment`` object.
        """
        template_count = len(self.templates)
        per_assignment = min(max_templates, template_count)
        for assignment_index, assignment in enumerate(assignments):
            label = {self.generator.handler.get_answer(scene, self.program, assignment): 1.0}
            chosen = random.sample(range(template_count), k=per_assignment)
            for template_index in chosen:
                text = self.generator.expand_text_template_multi(self.templates[template_index], assignment)
                question_id = f'{img_id}-{assignment_index}-{template_index}'
                yield question_id, {
                    'img_id': img_id,
                    'label': label,
                    'question_id': question_id,
                    'sent': text,
                    'question_type': self.question_type,
                    'assignment': assignment,
                }

    # def _get_color_assignment(self, scene, attribute_map, assignment, object_number):
    #     obj_name = assignment[f'obj{object_number}']
    #     color_key = f'color{object_number}'
    #     new_assignment = dict()
    #     try:
    #         obj_ids = attribute_map[obj_name, tuple(sorted(assignment[f'attrs{object_number}']))]
    #         exclude_colors = {s for l in [engine.utils.get_color(scene['objects'][obj_id]) for obj_id in obj_ids] for s
    #                           in l}
    #         # exclude_colors = set(reduce(lambda x, y: x |= set(y), ))
    #     except:
    #         exclude_colors = set(assignment[color_key], )
    #     if color_key in self.tokens:
    #         color = engine.utils.get_random_color(obj_name, exclude=exclude_colors)
    #         new_assignment[color_key] = color
    #     return new_assignment

    def generate_negative_color_token_assignments(self, all_assignments, img_id, include_ids):
        """Derive one new assignment per input assignment with a color token replaced.

        Each assignment is shallow-copied and then updated with whatever
        ``self._get_color_assignment`` returns for a randomly chosen entry of
        ``self.token_numbers``. ``include_ids`` is accepted for signature
        parity with the other negative samplers but is not used here.

        Returns an empty list when ``all_assignments`` is empty.

        NOTE(review): ``_get_color_assignment`` is only present as
        commented-out code in this class; unless it is supplied elsewhere,
        calling this method with a non-empty list raises ``AttributeError``.
        """
        if len(all_assignments) == 0:
            return []

        scene = self.scenes_dict[img_id]
        attribute_map = engine.utils.get_attribute_map(scene['objects'])

        negatives = []
        for source_assignment in all_assignments:
            negative = source_assignment.copy()
            number = random.choice(self.token_numbers)
            color_update = self._get_color_assignment(scene, attribute_map, source_assignment, number)
            negative.update(color_update)
            negatives.append(negative)
        return negatives


def get_default_output_file(scene_file, question_type, seed, balance_answers, negative_sampling_method,
                            max_assignments_per_scene, max_templates_per_assignment, include_ids, suffix=None):
    """Build the default output filename from the generation settings.

    The name starts with ``question_type`` and is followed, in order, by the
    seed, ``_balanced`` (if ``balance_answers``), the negative sampling method
    (if any), ``_with_obj_ids`` (if ``include_ids``), both maxima, the scene
    file stem, the optional ``suffix`` and finally ``.json``.
    """
    pieces = [f'_seed{seed}']
    if balance_answers:
        pieces.append('_balanced')
    if negative_sampling_method:
        pieces.append(f'_{negative_sampling_method}')
    if include_ids:
        pieces.append('_with_obj_ids')
    pieces.extend((
        f'_max_assignments{max_assignments_per_scene}',
        f'_max_templates{max_templates_per_assignment}',
        f'_{scene_file}',
    ))
    if suffix:
        pieces.append(f'_{suffix}')
    pieces.append('.json')
    return question_type + ''.join(pieces)


def generate_precursor_data(
        question_family,
        scene_graphs,
        negative_sampling_method,
        max_assignments_per_scene,
        max_templates_per_assignment,
        balance_answers,
        include_ids,
        seed,
):
    """Load scene graphs and a question family from disk and generate questions.

    Args:
        question_family: Path to the JSON question family file.
        scene_graphs: Iterable of paths to JSON scene graph files. Their
            top-level mappings are merged; on duplicate keys a later file
            overrides an earlier one.
        negative_sampling_method: Forwarded to
            ``QuestionFamily.generate_all_questions``.
        max_assignments_per_scene: Forwarded likewise.
        max_templates_per_assignment: Forwarded likewise.
        balance_answers: Forwarded likewise.
        include_ids: Forwarded likewise.
        seed: Seed for the global ``random`` module, set before anything else.

    Returns:
        The list of questions returned by ``generate_all_questions``.
    """
    random.seed(seed)

    logging.info('Loading scene graphs...')
    scenes = dict()
    for scene_file in scene_graphs:
        logging.info(f'\t{scene_file}')
        # Close each file as soon as it is parsed instead of leaking handles.
        with open(scene_file) as scene_handle:
            scenes.update(json.load(scene_handle))
    logging.info(f'Loaded {len(scenes)} scenes total.\n')

    with open(question_family) as family_handle:
        family_spec = json.load(family_handle)
    logging.info(f'Loaded question family: {family_spec["name"]}')

    generator = QuestionFamily(scenes, family_spec)
    return generator.generate_all_questions(negative_sampling_method=negative_sampling_method,
                                            max_assignments_per_scene=max_assignments_per_scene,
                                            max_templates_per_assignment=max_templates_per_assignment,
                                            balance_answers=balance_answers, include_ids=include_ids)


if __name__ == '__main__':
    # Command-line entry point: parse arguments, generate the questions and
    # write them to a JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--question-family-file', type=str, required=True,
                        help='Filepath to question family.')
    parser.add_argument('--scene-files', nargs='+', required=True,
                        help='Scene graph files used to generate dataset.')
    parser.add_argument('--random-seed', type=int, required=False, default=42,
                        help='Initialize random seed for random sampling.')
    # argparse applies ``type`` only to string values, so the float ``inf``
    # default is passed through unchanged and acts as "no limit".
    parser.add_argument('--max-assignments-per-scene', type=int, required=False, default=inf,
                        help='Maximum number of assignments allowed per scene. Set low number (4 to 8) to reduce '
                             'number of questions produced.')
    parser.add_argument('--max-templates-per-assignment', type=int, required=False, default=inf,
                        help='Maximum number of templates allowed per assignment. Set low number (2 or 3) to reduce '
                             'number of questions produced.')
    output_group = parser.add_mutually_exclusive_group()
    output_group.add_argument('--output-dir', type=str,
                              help='Directory to write the output file to; the filename is derived from the '
                                   'other arguments.')
    output_group.add_argument('--output-file', type=str,
                              help='Filepath to output file.')
    parser.add_argument('--negative-sampling-method', type=str,
                        help='Method for sampling likely negative assignments.')
    parser.add_argument('--balance-answers', action='store_true',
                        help='Down-sample more common answers such that all answers are equally represented.')

    args = parser.parse_args()

    question_family_file = args.question_family_file
    scene_graph_files = args.scene_files

    if args.output_file:
        output_filename = pathlib.Path(args.output_file).name
        # Raise explicitly: an ``assert`` would be stripped when Python runs
        # with -O and the check would silently disappear.
        if not output_filename.endswith('.json'):
            raise ValueError('Output filename %s must end with \'.json\'. Aborting.' % output_filename)
        output_directory = pathlib.Path(args.output_file).absolute().parent.as_posix()
        script_utils.verify_directories([output_directory])  # throws exception if directory doesn't exist
    else:
        output_filename = None
        output_directory = args.output_dir if args.output_dir else pathlib.Path('./').absolute().as_posix()

    script_utils.verify_files([question_family_file, *scene_graph_files])  # throws exception if doesn't exist

    log_format = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.DEBUG)

    precursor_data = generate_precursor_data(
        question_family=question_family_file,
        scene_graphs=scene_graph_files,
        seed=args.random_seed,
        max_assignments_per_scene=args.max_assignments_per_scene,
        max_templates_per_assignment=args.max_templates_per_assignment,
        negative_sampling_method=args.negative_sampling_method,
        balance_answers=args.balance_answers,
        include_ids=True
    )

    if not output_filename:
        scene_file_stem = '_'.join(pathlib.Path(scene_path).stem for scene_path in scene_graph_files)
        output_filename = get_default_output_file(scene_file=scene_file_stem,
                                                  question_type=pathlib.Path(args.question_family_file).stem,
                                                  seed=args.random_seed, include_ids=True,
                                                  balance_answers=args.balance_answers,
                                                  negative_sampling_method=args.negative_sampling_method,
                                                  max_assignments_per_scene=args.max_assignments_per_scene,
                                                  max_templates_per_assignment=args.max_templates_per_assignment)
    output_path = pathlib.Path(output_directory, output_filename)
    # Create the containing directory, not the output path itself: calling
    # mkdir() on the file path would create a directory named like the file.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # presumably returns a numbered variant of the path when it already
    # exists -- confirm in script_utils
    output_path = script_utils.get_numbered_path(output_path)
    logging.info('Saving JSON questions to %s .' % output_path)
    with open(output_path, 'w+') as output_handle:
        json.dump(precursor_data, output_handle)
