""" Augmentation by nlpaug """
# import os
from typing import List, Optional, Union

import nlpaug.augmenter.word as naw
import torch
from tqdm import tqdm

# os.environ["MODEL_DIR"] = '~/.cache'


def drop_duplicate(_list):
    """ Remove duplicated elements, keeping the order of first occurrence

    :param _list: an iterable of hashable elements
    :return: a new list with every distinct element exactly once, in the order it first appeared
    """
    # `dict` keeps insertion order, whereas iterating a `set` of strings can give a different order on every
    # interpreter run (hash randomisation), which made the output non-reproducible.
    return list(dict.fromkeys(_list))


def _augment_one(aug, sample, transformations_per_example, chained, show_progress=False):
    """ Augment a single sentence and return its de-duplicated augmentations

    :param aug: augmenter object exposing an `augment` method
    :param sample: the sentence to augment; `None` yields an empty list
    :param transformations_per_example: number of times `aug.augment` is called
    :param chained: if True each call receives the output of the previous call, otherwise every call receives
        the original sentence
    :param show_progress: if True the loop over the transformations is wrapped with tqdm
    :return: a list of augmented sentences with duplicates removed
    """
    if sample is None:
        return []
    steps = range(transformations_per_example)
    if show_progress:
        steps = tqdm(steps)
    current = sample
    augmented = []
    for _ in steps:
        output = aug.augment(current)
        # NOTE(review): depending on the installed nlpaug release `augment` may return a list instead of a
        # single string; flatten it so that the items stay hashable for de-duplication -- confirm against
        # the pinned nlpaug version.
        if isinstance(output, (list, tuple)):
            augmented.extend(output)
        else:
            augmented.append(output)
        if chained:
            current = output
    return drop_duplicate(augmented)


def augment_by_nlpaug(samples: Union[str, List[Optional[str]]],
                      method: str = 'bert',
                      transformers_name: str = 'bert-base-uncased',
                      transformations_per_example: int = 3) -> Union[List[str], List[List[str]]]:
    """ Augmentation by nlpaug

    :param samples: a single sentence, or a list of sentences; `None` entries of a list are mapped to `[]`
    :param method: `bert` (contextual word substitution) or `bt` (back translation en -> de -> en)
    :param transformers_name: model name passed to the contextual word augmenter (used by `bert` only)
    :param transformations_per_example: number of augmentation calls per sentence; the number of returned
        sentences can be smaller since duplicates are dropped
    :return: a flat list of augmented sentences if `samples` is a string, otherwise a nested list where each
        inner list holds the augmented sentences of the corresponding original sentence
    :raises ValueError: if `method` is neither `bert` nor `bt`
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if method == 'bert':
        # https://github.com/makcedward/nlpaug/blob/master/nlpaug/augmenter/word/context_word_embs.py
        aug = naw.ContextualWordEmbsAug(
            model_path=transformers_name,
            action="substitute",
            device=device)
    elif method == 'bt':
        # https://github.com/makcedward/nlpaug/blob/master/nlpaug/augmenter/word/back_translation.py
        aug = naw.BackTranslationAug(
            from_model_name='transformer.wmt19.en-de',
            to_model_name='transformer.wmt19.de-en',
            device=device)
    else:
        raise ValueError('unknown method: {}'.format(method))

    # For back translation each round is applied to the output of the previous round (as the original
    # implementation did), whereas `bert` always augments the original sentence.
    chained = method == 'bt'

    if isinstance(samples, str):
        return _augment_one(aug, samples, transformations_per_example, chained, show_progress=True)
    return [_augment_one(aug, sample, transformations_per_example, chained) for sample in tqdm(samples)]


