""" produce augmented samples """
import argparse
import os
import logging
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')

import pandas as pd

from nlpaug_util import augment_by_nlpaug
from textattack_util import augment_by_textattack
from data_util import get_data


def _split_long_sentence(text: str, max_char_length: int) -> tuple:
    """ split a text into the part to augment and the remainder to re-attach

    :param text: the input text
    :param max_char_length: texts with fewer characters than this are kept whole
    :return: (head, remainder); `head + remainder == text` always holds
    """
    if len(text) < max_char_length:
        return text, ''
    # Cut at the first '.' only. The remainder keeps that '.' and everything after it,
    # so later repetitions of the leading segment are not stripped from the remainder.
    head, dot, tail = text.partition('.')
    return head, dot + tail


def run_augmentation(method: str,
                     transformations_per_example: int,
                     output_dir: str,
                     data: str,
                     max_char_length: int = 512,
                     data_dir: str = './dataNew'):
    """ produce augmented samples and export them as `{output_dir}/{data}.{method}.csv`

    :param method: augmentation method; 'word_embedding' and 'synonym' are dispatched to
        `augment_by_textattack`, 'bert' and 'bt' to `augment_by_nlpaug`
    :param transformations_per_example: number of augmentations per sample, forwarded to the augmenter
    :param output_dir: directory the csv file is written to (created if it does not exist)
    :param data: dataset name passed to `get_data`; also used as the csv file name prefix
    :param max_char_length: texts with at least this many characters are cut at their first '.';
        only the leading part is augmented and the rest is appended back to every augmentation
    :param data_dir: data directory passed to `get_data`
    :raises ValueError: if `method` is unknown or a sentence ends up with no augmentation
    """
    logging.info('\n *** run_augmentation ***\n - method: {}\n - num per sample: {}'.
                 format(method, transformations_per_example))

    # validate the method before loading any data so a typo fails immediately
    if method in ('word_embedding', 'synonym'):
        augmenter = augment_by_textattack
    elif method in ('bert', 'bt'):
        augmenter = augment_by_nlpaug
    else:
        raise ValueError('unknown method: {}'.format(method))

    sentences, meta_dict = get_data(data, data_dir=data_dir)
    shared_config = dict(method=method, transformations_per_example=transformations_per_example)

    logging.info('start augmentation: {} sentences'.format(len(sentences)))
    pair_sent = [_split_long_sentence(i, max_char_length) for i in sentences]
    # a comprehension (rather than zip(*pairs)[0]) also works when there are no sentences
    heads = [head for head, _ in pair_sent]
    augmented = augmenter(heads, **shared_config)

    # add the remaining text part
    augmented = [[_a + r for _a in a] for a, (_, r) in zip(augmented, pair_sent)]

    logging.info('format as a csv')
    # every metadata column is repeated once per augmentation of the sentence it belongs to
    new_meta_dict = {k: [] for k in meta_dict}
    new_meta_dict['clData_aug'] = []
    for n, aug in enumerate(augmented):
        logging.info(' - {}th sentence has {} augmentations'.format(n, len(aug)))
        if not aug:
            raise ValueError('no augmentations found')
        logging.debug('aug: %s', aug)
        for k, v in meta_dict.items():
            new_meta_dict[k] += [v[n]] * len(aug)
        new_meta_dict['clData_aug'] += aug

    file_name = '{}.{}.csv'.format(data, method)
    logging.info('export to {}/{}'.format(output_dir, file_name))
    os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(new_meta_dict).to_csv(os.path.join(output_dir, file_name))


def get_options():
    """ parse the command line options of this script

    :return: the namespace produced by argparse, with attributes
        `data`, `output_dir`, `method`, `num` and `max_char_length`
    """
    cli = argparse.ArgumentParser(description='Benchmark algorithms on SemEval2010')
    # (flags, keyword settings) for each option, registered in this order
    option_specs = (
        (('-d', '--data'), dict(help='dataset name', default='comments', type=str)),
        (('-o', '--output-dir'), dict(help='output dir', default='./outputs2', type=str)),
        (('-m', '--method'), dict(help='method', default='word_embedding', type=str)),
        (('-n', '--num'), dict(help='augment per sample', default=2, type=int)),
        (('--max-char-length',), dict(help='max char length', default=512, type=int)),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()


if __name__ == '__main__':
    # script entry point: forward the parsed command line options to run_augmentation
    cli_options = get_options()
    run_kwargs = {
        'method': cli_options.method,
        'transformations_per_example': cli_options.num,
        'output_dir': cli_options.output_dir,
        'data': cli_options.data,
        'max_char_length': cli_options.max_char_length,
    }
    run_augmentation(**run_kwargs)
