# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  
# SPDX-License-Identifier: CC-BY-NC-4.0

import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char.random as nac
import random
from da4er.augmentation import BaseAugmentation


class LexAugmentation(BaseAugmentation):
    """
    Synonym-based, word-level augmentation wrapper.

    Delegates to nlpaug's ``SynonymAug``, always configured with ``aug_p=0.1``.
    """

    def __init__(self, src_aug='wordnet', src_lang='eng'):
        """
        Create the underlying synonym augmenter.

        :param src_aug: forwarded to ``SynonymAug`` as ``aug_src``
        :param src_lang: forwarded to ``SynonymAug`` as ``lang``
        """
        super().__init__()
        synonym_options = {'aug_src': src_aug, 'lang': src_lang, 'aug_p': 0.1}
        self.lex_aug = naw.SynonymAug(**synonym_options)

    def augment(self, txt: str) -> str:
        """
        Augment ``txt`` and return the first item of the augmenter's result.
        """
        candidates = self.lex_aug.augment(txt)
        return candidates[0]


class SpellAugmentation(BaseAugmentation):
    """
    Spelling-error augmentation wrapper.

    Delegates to nlpaug's ``SpellingAug``, always configured with ``aug_p=0.1``.
    """

    def __init__(self):
        """
        Create the underlying spelling augmenter.
        """
        super().__init__()
        spelling_options = {'aug_p': 0.1}
        self.spel_aug = naw.SpellingAug(**spelling_options)

    def augment(self, txt: str) -> str:
        """
        Augment ``txt`` and return the first item of the augmenter's result.
        """
        candidates = self.spel_aug.augment(txt)
        return candidates[0]


class CharacterAugmentation(BaseAugmentation):
    """
    Character-based augmentation wrapper

    Applies one nlpaug ``RandomCharAug`` action per call. The action is either
    the one requested at construction time or, when none (or an unsupported
    one) was requested, a fresh random pick for every call.
    """
    # Actions this wrapper is willing to hand to ``RandomCharAug``.
    RECOMMENDED_CHARACTER_METHODS = ("insert", "substitute", "swap", "delete")

    def __init__(self, method=None):
        """
        :param method: one of ``RECOMMENDED_CHARACTER_METHODS`` to always use
            that action; any other value (including the default ``None``)
            makes ``augment`` choose an action at random on each call.
        """
        super().__init__()
        # Previously this argument was accepted but discarded.
        self.method = method

    def augment(self, txt: str) -> str:
        """
        Augment ``txt`` at character level and return the first item of the
        augmenter's result.
        """
        if self.method in self.RECOMMENDED_CHARACTER_METHODS:
            chr_act = self.method
        else:
            # No usable fixed action: keep the original per-call random pick.
            chr_act = random.choice(self.RECOMMENDED_CHARACTER_METHODS)
        chr_aug = nac.RandomCharAug(
            action=chr_act,
            aug_char_p=0.1,
            include_upper_case=False,
            include_numeric=False,
            spec_char=' ',
        )
        return chr_aug.augment(txt)[0]


class GPTAugmentation(BaseAugmentation):
    """
    Text-generation augmentation wrapper backed by the ``gpt2`` model.
    """

    def __init__(self):
        """
        Build the ``transformers`` text-generation pipeline for ``gpt2``.

        ``transformers`` is imported here rather than at module level, so it is
        only needed when this class is instantiated.
        """
        from transformers import pipeline
        super().__init__()
        generator = pipeline('text-generation', model='gpt2')
        self.gpt_aug = generator

    def augment(self, txt: str) -> str:
        """
        Run the pipeline on ``txt`` (``max_length=20``, one returned sequence)
        and return the ``generated_text`` of the first result.
        """
        generations = self.gpt_aug(txt, max_length=20, num_return_sequences=1)
        first_generation = generations[0]
        return first_generation['generated_text']


class OPTAugmentation(BaseAugmentation):
    """
    Text-generation augmentation wrapper backed by ``facebook/opt-125m``.
    """

    def __init__(self):
        """
        Build the ``transformers`` text-generation pipeline for
        ``facebook/opt-125m``.

        ``transformers`` is imported here rather than at module level, so it is
        only needed when this class is instantiated.
        """
        from transformers import pipeline
        super().__init__()
        generator = pipeline('text-generation', model='facebook/opt-125m')
        self.opt_aug = generator

    def augment(self, txt: str) -> str:
        """
        Run the pipeline on ``txt`` (``max_length=20``, sampling enabled) and
        return the ``generated_text`` of the first result.
        """
        generations = self.opt_aug(txt, max_length=20, do_sample=True)
        first_generation = generations[0]
        return first_generation['generated_text']


class ParaAugmentation(BaseAugmentation):
    """
    Paraphrase-based text generation wrapper

    Uses the ``Vamsi/T5_Paraphrase_Paws`` seq2seq checkpoint: the input is
    prefixed with ``"paraphrase: "`` and a single sampled output is decoded.
    """
    # Checkpoint shared by the tokenizer and the model.
    MODEL_NAME = "Vamsi/T5_Paraphrase_Paws"

    def __init__(self):
        """
        Load the tokenizer and model for ``MODEL_NAME``.

        ``transformers`` is imported here rather than at module level, so it is
        only needed when this class is instantiated.
        """
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        super().__init__()
        self.para_tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.para_model = AutoModelForSeq2SeqLM.from_pretrained(self.MODEL_NAME)

    def augment(self, txt: str) -> str:
        """
        Generate one paraphrase of ``txt``.

        :param txt: text to paraphrase
        :return: the decoded first generated sequence, special tokens removed
        """
        text = "paraphrase: " + txt + " </s>"
        # ``padding="longest"`` replaces the deprecated ``pad_to_max_length=True``;
        # with no ``max_length`` given the deprecated flag resolves to the same
        # strategy (see the transformers padding documentation).
        encoding = self.para_tokenizer.encode_plus(text, padding="longest", return_tensors="pt")
        input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]

        outputs = self.para_model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            max_length=256,
            do_sample=True,
            top_k=120,
            top_p=0.95,
            early_stopping=True,
            num_return_sequences=1,
        )
        # Exactly one sequence is requested, so decode the first row directly
        # instead of looping and returning on the first iteration.
        return self.para_tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )


class BackTranslationAugmentation(BaseAugmentation):
    """
    Back-Translation-based text generation wrapper

    One (forward, backward) translation model pair is chosen at random when
    the instance is created and is then used for every ``augment`` call.
    """
    # Candidate (English -> X, X -> English) model pairs.
    LANGUAGE_PAIRS = (
        ('Helsinki-NLP/opus-mt-en-es', 'Helsinki-NLP/opus-mt-es-en'),
        ('Helsinki-NLP/opus-mt-en-fr', 'Helsinki-NLP/opus-mt-fr-en'),
        ('Helsinki-NLP/opus-mt-en-de', 'Helsinki-NLP/opus-mt-de-en'),
        ('Helsinki-NLP/opus-mt-en-zh', 'Helsinki-NLP/opus-mt-zh-en'),
        ('Helsinki-NLP/opus-mt-en-it', 'Helsinki-NLP/opus-mt-it-en'),
        ('Helsinki-NLP/opus-mt-en-ru', 'Helsinki-NLP/opus-mt-ru-en'),
    )

    def __init__(self):
        """
        Pick a random model pair and build the nlpaug back-translation
        augmenter for it.
        """
        super().__init__()
        # ``random.choice`` tracks the table length automatically, unlike the
        # previous hard-coded ``range(6)``.
        from_model, to_model = random.choice(self.LANGUAGE_PAIRS)
        # No ``device`` argument is passed, so nlpaug's default device applies.
        self.back_aug = naw.BackTranslationAug(from_model_name=from_model, to_model_name=to_model)

    def augment(self, txt: str) -> str:
        """
        Back-translate ``txt`` and return the first item of the augmenter's
        result.
        """
        return self.back_aug.augment(txt)[0]
