# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0


import glob
import tensorflow as tf
import argparse
import json
import numpy as np
from tqdm import tqdm

from da4er.formats.blink_converter import BlinkInputConverter, BlinkOutputConverter
from da4er.utils import InputConverter, OutputConverter, InputProcessor
from da4er.gan.utils import build_gan_preprocessor, add_gan_preprocessor_args
from da4er.gan.transformers_bart import BARTAUG

# Threshold on the number of pending embeddings: once main() has accumulated
# more than this many, it packs them into a single float32 NumPy array and
# starts a fresh list.
BATCH_SIZE = 50000

def generate_samples(input_path: str):
    """Lazily yield every line of every file matching a glob pattern.

    Args:
        input_path: A path or glob pattern, expanded with ``glob.glob``.

    Yields:
        Each line of each matching file, exactly as read in text mode
        (trailing newline included when present). Files are visited in the
        order returned by ``glob.glob``; nothing is yielded when the pattern
        matches no file.
    """
    for matched_path in glob.glob(input_path):
        with open(matched_path, 'r') as handle:
            # Delegate straight to the file iterator; one item per line.
            yield from handle

def main(args):
    """Embed the input texts, stash the embeddings on disk and run BART augmentation.

    The function works in two phases:

    1. Every line yielded by ``generate_samples(args.input)`` is parsed with
       the input converter; the query text (and, unless ``args.query_only``
       is set, the entity texts) is passed through the GAN preprocessor.
       Texts for which the preprocessor returns ``None`` are counted as
       skipped. All remaining embeddings are stacked into one float32 array
       and written to a scratch ``.npz`` file under the key ``train_data``.
    2. ``args.input`` is repointed to that ``.npz`` file, a second
       preprocessor is built from the updated ``args``, and a ``BARTAUG``
       augmentation is applied to the *original* input, writing to
       ``args.output``.

    Args:
        args: Parsed command-line namespace. This function reads ``input``,
            ``output``, ``format``, ``query_only``, ``dimension``,
            ``bart_score_low`` and ``bart_score_high``; ``args`` is also
            handed to ``build_gan_preprocessor``, which may read more.
            NOTE: ``args.input`` is overwritten in place with the scratch
            ``.npz`` path and is not restored on return.

    Raises:
        ValueError: If no embedding could be produced (no input lines, or
            every text was skipped by the preprocessor).
    """
    # Local import: only needed here to make sure the scratch directory exists.
    import os

    gan_preprocessor = build_gan_preprocessor(args)

    # Initialising formatters
    if args.format == "blink":
        input_converter = BlinkInputConverter()
        output_converter = BlinkOutputConverter()
    else:
        input_converter = InputConverter()
        output_converter = OutputConverter()

    batches = []      # finished float32 arrays, one per flushed chunk
    embeddings = []   # embeddings accumulated since the last flush
    skipped = 0
    total = 0
    for line in tqdm(generate_samples(args.input), desc="Extracting text for BART augmentation"):
        sample = input_converter.process(line)
        cur_texts = [sample.work_sample.original.query]
        if not args.query_only:
            cur_texts += sample.work_sample.original.entities

        for text in cur_texts:
            total += 1
            embedding = gan_preprocessor.preprocess(text)
            if embedding is None:
                # The preprocessor could not embed this text; count and move on.
                skipped += 1
                continue
            embeddings.append(embedding)

        # Flush periodically into a compact array instead of keeping one
        # ever-growing Python list of embeddings.
        if len(embeddings) > BATCH_SIZE:
            batches.append(np.array(embeddings, dtype=np.float32))
            embeddings = []

    if embeddings:
        batches.append(np.array(embeddings, dtype=np.float32))
    del embeddings

    if not batches:
        # np.vstack([]) would fail here with an opaque "need at least one
        # array" error; fail with the same exception type but say why.
        raise ValueError(
            "No embeddings could be extracted from \"%s\" (%d texts seen, %d skipped)"
            % (args.input, total, skipped))
    batches = np.vstack(batches)

    # total > 0 is guaranteed here because at least one embedding was kept.
    print("Processed %d embeddings (skipped %.2f%%)" % (total, 100 * float(skipped) / total))

    # Scratch location for the stacked embeddings (relative to the CWD).
    loc_tmp_bart = '../paper_work/data-augmentation-for-entity-resolution/tmp_bart/'
    # Create it up front so that np.savez does not fail after the whole
    # embedding pass merely because the directory is missing.
    os.makedirs(loc_tmp_bart, exist_ok=True)
    # np.savez appends the ".npz" suffix to the file name.
    np.savez(loc_tmp_bart + 'tmp_bart', train_data=batches)

    # Keep the original input for the augmentation pass, then repoint
    # args.input at the saved embeddings before building the second
    # preprocessor (presumably it loads its data from args.input -- confirm
    # against build_gan_preprocessor).
    tmp_arg_input = args.input
    args.input = loc_tmp_bart + 'tmp_bart.npz'

    preprocessor = build_gan_preprocessor(args)

    data = [BARTAUG(preprocessor, args.dimension, args.bart_score_low, args.bart_score_high)]

    # Running augmentations
    for augmentation in data:
        print("  %s" % str(augmentation))
    processor = InputProcessor(data, input_converter=input_converter, output_converter=output_converter)
    processor.augment_file(tmp_arg_input, args.output)


if __name__ == "__main__":
    # Command-line options as (flag, keyword-arguments) pairs. They are
    # registered below in exactly this order, which is also the order shown
    # by --help.
    # NOTE: a '--model-loc' option (pre-trained model location) used to be
    # declared right after '--gan'; it is currently disabled.
    cli_options = [
        ('--input', dict(type=str, required=True,
                         help='Original text data. Only receive json file.')),
        ('--output', dict(type=str, required=True,
                          help='Name of output file. Only generate json file.')),
        ('--query-only', dict(action="store_true", help='Only use queries')),
        ('--bart-score-low', dict(type=float, default=0.4,
                                  help='Low threshold for measuring similarity in BART-GAN')),
        ('--bart-score-high', dict(type=float, default=0.9,
                                   help='High threshold for measuring similarity in BART-GAN')),
        ('--dimension', dict(type=int, default=6,
                             help='Number of tokens considered in BART')),
        ('--gan', dict(type=str, required=True,
                       help='Choose models: fasttext / bert / bart.')),
        ('--hidden', dict(type=int, default=768, help='Hidden unit for BART-GAN.')),
        ('--format', dict(type=str, default='default', choices=['default', 'blink'],
                          help='What format to expect as input')),
    ]

    parser = argparse.ArgumentParser(
        description='Generating synthetic data from BART',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
