#encoding:utf-8
import numpy as np
import random

class Augmentator(object):
    """Randomly apply one token-level augmentation to a whitespace-tokenized text.

    In training mode, each call applies (with probability ``proba``) one
    augmentation chosen uniformly at random from ``self.augs``: either a
    random shuffle of the tokens or a random token dropout.
    """

    def __init__(self, is_train_mode=True, proba=0.5, dropout_p=0.5):
        """
        Args:
            is_train_mode: when False, ``__call__`` returns the text unchanged.
            proba: probability that an augmentation is applied on a call.
            dropout_p: fraction of tokens removed by the dropout augmentation.
        """
        self.mode = is_train_mode
        self.proba = proba
        self.dropout_p = dropout_p
        self.augs = []
        self._reset()

    def _reset(self):
        """(Re)build the list of available augmentations.

        Assigns a fresh list instead of appending, so calling this more than
        once does not register the same augmentation several times.
        """
        self.augs = [
            self._shuffle,
            lambda text: self._dropout(text, p=self.dropout_p),
        ]

    def _shuffle(self, text):
        """Return ``text`` with its whitespace-separated tokens randomly permuted."""
        tokens = np.random.permutation(text.split())
        return ' '.join(tokens)

    def _dropout(self, text, p=0.5):
        """Randomly delete ``int(len(tokens) * p)`` distinct tokens from ``text``.

        The surviving tokens keep their original order and are joined by single
        spaces. ``p`` is clamped so that values outside [0, 1] neither raise
        nor drop more tokens than exist.
        """
        tokens = text.split()
        n_tokens = len(tokens)
        n_drop = min(n_tokens, max(0, int(n_tokens * p)))
        if n_drop == 0:
            return ' '.join(tokens)
        # replace=False: sample *distinct* positions so exactly n_drop tokens go.
        dropped = set(np.random.choice(n_tokens, n_drop, replace=False).tolist())
        # Filter instead of blanking to '' so no double spaces are left behind.
        return ' '.join(tok for i, tok in enumerate(tokens) if i not in dropped)

    def __call__(self, text, aug_type=None):
        """Possibly augment ``text``.

        Args:
            text: input string.
            aug_type: accepted for backward compatibility; currently unused
                (the original code checked its range but took no action).

        Returns:
            The augmented text when in training mode and the random draw is
            below ``proba``; otherwise ``text`` unchanged.
        """
        if self.mode and random.random() < self.proba:
            aug = random.choice(self.augs)
            text = aug(text)
        return text
