""" Augmentation by TextAttack """
from typing import List, Optional, Union

from textattack import augmentation
from tqdm import tqdm

# os.environ["TA_CACHE_DIR"] = '~/.cache'
# os.makedirs('~/.cache', exist_ok=True)


def augment_by_textattack(samples: Union[List[Optional[str]], str],
                          method: str = 'embedding',
                          pct_words_to_swap: float = 0.1,
                          transformations_per_example: int = 5):
    """ Augment text with one of TextAttack's augmentation recipes.

    :param samples: a single sentence, or a list of sentences; `None` entries in a list are allowed and
        produce an empty list of augmentations for that position
    :param method: `word_embedding` (or its alias `embedding`, the default) to use `EmbeddingAugmenter`,
        or `synonym` to use `WordNetAugmenter`
    :param pct_words_to_swap: forwarded unchanged to the augmenter's `pct_words_to_swap` argument
    :param transformations_per_example: forwarded unchanged to the augmenter's
        `transformations_per_example` argument
    :return: for a single string, the result of `augmenter.augment` on it; for a list, a nested list where
        each inner list corresponds to the augmented sentences for the original sentence at that position
    :raises ValueError: if `method` is not one of the supported names
    """
    # The default value is 'embedding' while only 'word_embedding' used to be accepted, so a call with
    # default arguments always raised. Both spellings are accepted to stay backward compatible.
    use_embedding = method in ('embedding', 'word_embedding')
    if not use_embedding and method != 'synonym':
        raise ValueError('unknown method: {}'.format(method))

    # Nothing to augment: skip building the augmenter (and the progress bar) altogether.
    if isinstance(samples, (list, tuple)) and not samples:
        return []

    if use_embedding:
        # https://github.com/QData/TextAttack/blob/master/textattack/augmentation/recipes.py#L112
        augmenter_cls = augmentation.EmbeddingAugmenter
    else:
        # https://github.com/QData/TextAttack/blob/master/textattack/augmentation/recipes.py#L94
        augmenter_cls = augmentation.WordNetAugmenter
    augmenter = augmenter_cls(
        pct_words_to_swap=pct_words_to_swap, transformations_per_example=transformations_per_example)

    # isinstance (rather than an exact type check) keeps str subclasses from being iterated per character.
    if isinstance(samples, str):
        return augmenter.augment(samples)
    return [augmenter.augment(i) if i is not None else [] for i in tqdm(samples)]
