# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0

import tensorflow as tf
import argparse
import json

from da4er.formats.blink_converter import BlinkInputConverter, BlinkOutputConverter
from da4er.gan.utils import add_gan_preprocessor_args, build_gan_preprocessor
from da4er.utils import InputConverter, OutputConverter, InputProcessor


def main(args):
    """Generate GAN-based augmentations for a text dataset and write them out.

    Loads the GAN preprocessor and a saved Keras generator model, wraps them
    in the selected augmentation method, and runs that augmentation over
    ``args.input_json`` through ``InputProcessor``, writing the result to
    ``args.output``.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``gan``, ``gan_loc``, ``dimension``, ``bart_score_low``,
            ``bart_score_high``, ``format``, ``input_json`` and ``output``.
            ``build_gan_preprocessor`` may read further attributes.

    Raises:
        SystemExit: If ``args.gan`` is not a supported GAN method. This is
            checked before any preprocessor or model is loaded.
    """
    # Fail fast on an unsupported method rather than loading models and then
    # continuing with a ``None`` augmentation that crashes further down.
    # Only BART is currently enabled; the FastText-GAN
    # (da4er.gan.fasttext_gan.FastTextGAN) and BERT-GAN
    # (da4er.gan.transformers_gan.BERTGAN) variants are disabled.
    if args.gan != 'bart':
        raise SystemExit('Wrong GAN method. Please try bart')

    print("Loading the GAN preprocessor")
    preprocessor = build_gan_preprocessor(args)
    print("Loading the GAN model")
    generator = tf.keras.models.load_model(args.gan_loc)

    # Imported lazily so the module is only loaded when this method is selected.
    from da4er.gan.transformers_gan import BARTGAN
    # NOTE(review): ``args.dimension`` is not registered by this file's own
    # parser; presumably it is added by add_gan_preprocessor_args — confirm.
    augmentations = [BARTGAN(preprocessor, generator, args.dimension,
                             args.bart_score_low, args.bart_score_high)]

    # Initialising formatters
    if args.format == "blink":
        input_converter = BlinkInputConverter()
        output_converter = BlinkOutputConverter()
    else:
        input_converter = InputConverter()
        output_converter = OutputConverter()

    # Running augmentations. Report the JSON file, since that is the file
    # actually passed to augment_file below.
    print("Running data augmentation for input file \"%s\" with the following augmentations (%d):"
          % (args.input_json, len(augmentations)))
    for augmentation in augmentations:
        print("  %s" % str(augmentation))

    processor = InputProcessor(augmentations, input_converter=input_converter,
                               output_converter=output_converter)
    processor.augment_file(args.input_json, args.output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='GAN training and generating synthetic data',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input', type=str, required=True, help='Original embedding data. Only receive npz file.')
    parser.add_argument('--input-json', type=str, required=True, help='Original text data. Only receive json file.')
    # Restrict to the methods main() supports so argparse rejects anything else
    # up front with a usage error. FastText and BERT GANs are currently disabled.
    parser.add_argument('--gan', type=str, required=True, choices=['bart'],
                        help='Choose GAN methods: bart.')
    parser.add_argument('--output', type=str, required=True, help='Name of output file. Only generate json file.')
    # parser.add_argument('--dict-bert', type=str, help='Dictionary for BERT from translated traffic')
    parser.add_argument('--gan-loc', type=str, required=True, help='Location of the gan model')
    # parser.add_argument('--ft-score', type=float, default=1.3,
    #                     help='Threshold for measuring similarity in FastText-GAN')
    parser.add_argument('--bart-score-low', type=float, default=0.4,
                        help='Low threshold for measuring similarity in BART-GAN')
    parser.add_argument('--bart-score-high', type=float, default=0.9,
                        help='High threshold for measuring similarity in BART-GAN')
    parser.add_argument('--format', type=str,
                        default='default',
                        choices=['default', 'blink'],
                        help='What format to expect as input'
                        )
    # Registers the preprocessor-specific options consumed by build_gan_preprocessor.
    add_gan_preprocessor_args(parser)

    args = parser.parse_args()
    main(args)
