# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  
# SPDX-License-Identifier: CC-BY-NC-4.0

import argparse
from da4er.gan.preprocessing import GANPreprocessor


def add_gan_preprocessor_args(parser: argparse.ArgumentParser):
    """Register the BART-GAN preprocessing options on *parser*.

    Adds two integer options, in this order:

    * ``--dimension`` (default 6): number of tokens considered in BART-GAN.
    * ``--hidden`` (default 768): hidden unit size for BART-GAN.

    The parser is modified in place; nothing is returned.
    """
    # (flag, default, help) for each integer option, registered in order.
    option_specs = (
        ('--dimension', 6, 'Number of tokens considered in BART-GAN'),
        ('--hidden', 768, 'Hidden unit for BART-GAN.'),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)


def build_gan_preprocessor(args) -> GANPreprocessor:
    """Construct the GAN preprocessor selected by ``args.gan``.

    Only the ``'bart'`` GAN type is currently supported.

    :param args: parsed command-line namespace. It must provide ``gan``. When
        ``gan == 'bart'`` it must also provide ``dimension`` and ``hidden``
        (the options registered by ``add_gan_preprocessor_args``).
    :return: a ``BARTGAN`` built from ``args.dimension`` and ``args.hidden``.
    :raises ValueError: if ``args.gan`` is not a supported GAN type.
    """
    if args.gan == 'bart':
        # Imported inside the branch so the backend module is only loaded
        # when this GAN type is actually selected.
        from da4er.gan.preprocessing.transformers_preprocessor import BARTGAN
        return BARTGAN(args.dimension, args.hidden)

    # Report the rejected value so misconfigured runs are easy to diagnose;
    # the original message text is kept as the prefix.
    raise ValueError(f"Incorrect GAN type: {args.gan!r} (supported: 'bart')")
