import argparse
import logging
import json
import numpy as np

from pathlib import Path


# class IdentityDict(dict):
#     def __missing__(self, key):
#         return key
#

def main(input_file, output_file, with_perturbed_images):
    """Convert one question JSON file into an MMF-style ``.npy`` annotation file.

    The input is read as a JSON list of question records. Each record must have
    the keys ``question_id``, ``img_id``, ``sent`` and ``label``. ``label`` must
    be a non-empty mapping whose first key is used as the answer; presumably it
    maps answer -> score, so verify this against the data producer.

    Records are emitted sorted by ``question_id``, preceded by a fixed metadata
    header dict. They are saved with :func:`numpy.save` as a 1-D object array,
    which must be read back with ``allow_pickle=True``.

    Args:
        input_file: ``pathlib.Path`` of the questions JSON file to read.
        output_file: Path the ``.npy`` array is written to.
        with_perturbed_images: If true, the running index of each question
            (after sorting) replaces ``img_id`` as image name, image id and
            feature-file stem.

    Raises:
        ValueError: If a record's ``label`` mapping is empty.
    """
    # Use a context manager so the input handle is closed deterministically.
    with input_file.open(encoding='utf-8') as f:
        questions = json.load(f)

    mmf_questions = [
        {'create_time': '2018-03-29 16:39', 'dataset_name': 'gqa',
         'version': 1, 'has_answer': True, 'has_gt_layout': False},
    ]

    # Sort by question_id so the running index ix is deterministic (added 2020-11-13).
    ordered = sorted(questions, key=lambda q: q['question_id'])
    for ix, question in enumerate(ordered):
        img_id = question['img_id']
        sent = question['sent']
        labels = question['label']
        if not labels:
            raise ValueError('question %s has an empty label mapping.'
                             % question['question_id'])
        # Only the first answer key is kept (dicts preserve JSON key order).
        answer_str = next(iter(labels))
        # In perturbed mode each question is tied to its own image, addressed by ix.
        image_key = ix if with_perturbed_images else img_id

        mmf_questions.append({
            'json_question_id': question['question_id'],
            'image_name': image_key,
            'image_id': image_key,
            'question_id': ix,
            # str() so that integer image ids do not break the concatenation.
            'feature_path': str(image_key) + '.npy',
            'question_str': sent,
            'question_tokens': sent.lower().strip('?.,\'":;').split(' '),
            'all_answers': [answer_str] * 10,
            'answers': [answer_str] * 10,
        })

    # A list of dicts becomes a 1-D object array; dtype=object makes that explicit.
    np.save(output_file, np.array(mmf_questions, dtype=object))
    logging.info('Saved to %s .', output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', type=str, nargs='+', required=True,
                        help='Filepaths to input questions JSON file.')
    parser.add_argument('--with-perturbed-images', action='store_true',
                        help='Including sets image id\'s to question_id')
    # Supplying the default here (rather than in a separate else-branch) sends
    # the default directory through the same is_file check and mkdir as a
    # user-supplied one. Previously the default was never created.
    parser.add_argument('--output-dir', type=str, required=False,
                        default='mmf_datasets',
                        help='Filepath to output directory.')

    args = parser.parse_args()
    log_format = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.DEBUG)

    input_files = [Path(p) for p in args.input_files]
    for input_file in input_files:
        if not input_file.exists():
            raise FileNotFoundError('input file path %s does not exist. Aborting.'
                                    % input_file.as_posix())

    output_dir = Path(args.output_dir)
    if output_dir.is_file():
        raise FileExistsError('Output directory %s is a file. Aborting.'
                              % output_dir.as_posix())
    output_dir.mkdir(parents=True, exist_ok=True)

    # Validate every output path before converting anything, so a clash
    # aborts the run without leaving partial results behind.
    output_files = []
    for input_file in input_files:
        # Append '.npy' to the stem rather than using with_suffix(). That would
        # replace a dotted part of the stem ('a.b.json' -> 'a.npy').
        output_file = output_dir / (input_file.stem + '.npy')
        if output_file.exists():
            raise FileExistsError("output file path %s already exists. "
                                  "Won't overwrite an existing file. Aborting."
                                  % output_file)
        if output_file in output_files:
            raise FileExistsError('output file path %s would be written by more '
                                  'than one input file. Aborting.' % output_file)
        output_files.append(output_file)

    for input_file, output_file in zip(input_files, output_files):
        main(input_file, output_file, args.with_perturbed_images)
