# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0

import argparse

from da4er.basic_augmentations import LexAugmentation, SpellAugmentation, CharacterAugmentation, GPTAugmentation, OPTAugmentation, ParaAugmentation, BackTranslationAugmentation
from da4er.formats.blink_converter import BlinkInputConverter, BlinkOutputConverter
from da4er.utils import InputProcessor, InputConverter, OutputConverter

RECOMMENDED_CHARACTER_METHODS = ["insert", "substitute", "swap", "delete"]
RECOMMENDED_AUG_NAME = 'recommended'


def main(args):
    """Build the augmentations named in ``args.aug`` and apply them to ``args.input``.

    Recognised augmentation names are ``lexical``, ``spelling``, ``character``,
    ``gpt``, ``opt``, ``para`` and ``back``. Unrecognised names are reported and
    skipped. If none of the requested names is recognised, an error message is
    printed and the function returns without touching any file.

    Args:
        args: Parsed command-line namespace. Always reads ``aug``; reads
            ``src_lex``/``src_lang_lex`` for ``lexical`` and ``chr_specific``
            for ``character``; once at least one augmentation is built, reads
            ``format``, ``input`` and ``output``.

    Returns:
        None. Processing is delegated to
        ``InputProcessor.augment_file(args.input, args.output)``.
    """
    data = []

    # Initialising augmentations
    for augmentation in args.aug:
        if augmentation == 'lexical':
            data.append(LexAugmentation(args.src_lex, args.src_lang_lex))
        elif augmentation == 'spelling':
            data.append(SpellAugmentation())
        elif augmentation == 'character':
            # Expand the 'recommended' shortcut into the concrete method names.
            character_augmentations = set(args.chr_specific)
            if RECOMMENDED_AUG_NAME in character_augmentations:
                character_augmentations.update(RECOMMENDED_CHARACTER_METHODS)
                character_augmentations.remove(RECOMMENDED_AUG_NAME)
            # NOTE(review): a set of strings has no stable iteration order across
            # runs (string hashing is randomised); confirm CharacterAugmentation
            # does not depend on the order of its methods.
            data.append(CharacterAugmentation(character_augmentations))
        elif augmentation == 'gpt':
            data.append(GPTAugmentation())
        elif augmentation == 'opt':
            data.append(OPTAugmentation())
        elif augmentation == 'para':
            data.append(ParaAugmentation())
        elif augmentation == 'back':
            data.append(BackTranslationAugmentation())
        else:
            # Previously such names vanished silently; tell the user.
            print('Ignoring unknown augmentation method "%s"' % augmentation)

    # Sanity checks
    if not data:
        print('Wrong augmentation method. '
              'Please try one of lexical / spelling / character / gpt / opt / para / back')
        return

    # Initialising formatters
    if args.format == "blink":
        input_converter = BlinkInputConverter()
        output_converter = BlinkOutputConverter()
    else:
        input_converter = InputConverter()
        output_converter = OutputConverter()

    # Running augmentations
    print("Running data augmentation for input file \"%s\" with the following augmentations (%d):"
          % (args.input, len(data)))
    for augmentation in data:
        print("  %s" % str(augmentation))

    processor = InputProcessor(data, input_converter=input_converter, output_converter=output_converter)
    processor.augment_file(args.input, args.output)



if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(
        description='Conventional augmentations',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input', type=str, required=True, help='Original text data.')
    # `choices` lists exactly the names main() knows how to build, so a typo is
    # rejected by argparse with a usage error instead of being skipped.
    parser.add_argument('--aug', type=str, required=True, nargs='+',
                        choices=['lexical', 'spelling', 'character', 'gpt', 'opt', 'para', 'back'],
                        help='Choose augmentation methods: '
                             'lexical / spelling / character / gpt / opt / para / back.')
    parser.add_argument('--output', type=str, required=True, help='Name of output file.')
    parser.add_argument('--src-lex', type=str, default='wordnet', help='Source for lexical.')
    parser.add_argument('--src-lang-lex', type=str, default='eng', help='Source language for lexical.')
    # The 'recommended' shortcut is expanded by main() into all of
    # RECOMMENDED_CHARACTER_METHODS.
    parser.add_argument('--chr-specific', type=str, nargs='+', default=[RECOMMENDED_AUG_NAME],
                        help='Character-level methods to use: insert / substitute / swap / delete, '
                             'or "%s" for all four.' % RECOMMENDED_AUG_NAME)

    # NOTE: the translation options below are currently unused; main() does not
    # build a translation augmentation that would read them.
    # parser.add_argument('--src-nmt', type=str, help='Source NMT location')
    # parser.add_argument('--tar-nmt', type=str, help='Target NMT location')
    # parser.add_argument('--src', type=str, help='Source language for translation / back-translation')
    # parser.add_argument('--inter', type=str, default='en',
    #                     help='Intermediate language for translation / back-translation')
    # parser.add_argument('--tar', type=str, default='fr', help='Target language for translation')
    parser.add_argument('--format', type=str,
                        default='default',
                        choices=['default', 'blink'],
                        help='What format to expect as input'
                        )

    args = parser.parse_args()
    main(args)
