from typing import List, Tuple

from args import ProgramArgs
from utils.mask import mask_sentence
from utils.safer import WordSubstitude


class Augmentor(object):
    """Produce augmented copies of a sentence according to ``args.training_type``.

    Behaviour by training type:

    * ``'mask'``  -- delegates to ``mask_sentence(sentence, 0.7, '[MASK]', n)``.
    * ``'dne'``   -- ``n`` unchanged copies of the sentence.
    * ``'safer'`` -- ``WordSubstitude.get_perturbed_batch`` on the lower-cased
      sentence with ``rep=n``.
    * anything else -- the sentence alone, as a one-element list (``n`` ignored).
    """

    def __init__(self, args: 'ProgramArgs'):
        """Read the training configuration from *args*.

        Args:
            args: program arguments; only ``training_type`` and ``workspace``
                are read here.
        """
        self.training_type = args.training_type
        self.mask_token = '[MASK]'
        self.safer_aug_set = f'{args.workspace}/cache/embed/perturbation_constraint_pca0.8_100.pkl'
        # The SAFER substitution table is only used by the 'safer' branch of
        # ``augment``. It is loaded lazily by the ``aug`` property so that other
        # training types neither need the cache file nor pay to load it.
        self._aug = None

    @property
    def aug(self) -> 'WordSubstitude':
        """The SAFER word-substitution augmenter, constructed on first access."""
        if self._aug is None:
            self._aug = WordSubstitude(self.safer_aug_set)
        return self._aug

    @aug.setter
    def aug(self, value: 'WordSubstitude') -> None:
        # Kept assignable so code that previously set ``obj.aug`` directly still works.
        self._aug = value

    def augment(self, sentence: str, n: int) -> List[str]:
        """Return augmented variants of *sentence* for the configured training type.

        Args:
            sentence: the input sentence.
            n: number of variants requested. Honoured by 'mask', 'dne' and
                'safer'; ignored by any other training type.

        Returns:
            The list of variants. For unrecognised training types this is
            ``[sentence]``.
        """
        if self.training_type == 'mask':
            # NOTE(review): 0.7 is passed straight to mask_sentence; presumably a
            # masking rate -- confirm against utils.mask before tuning.
            return mask_sentence(sentence, 0.7, self.mask_token, n)
        elif self.training_type == 'dne':
            return [sentence] * n
        elif self.training_type == 'safer':
            # The SAFER augmenter is fed lower-cased text.
            return self.aug.get_perturbed_batch(sentence.lower(), rep=n)
        else:
            return [sentence]


