import numpy as np
import torch

def word_shuffle(tokenizer, x, k):   # slight shuffle such that |sigma[i]-i| <= k
    """Locally shuffle every row of ``x`` so tokens move only a few places.

    Each position ``i`` receives the sort key ``i + U(0, k+1)``. Sorting the
    keys row-wise gives a permutation that keeps each token close to its
    original position. The start symbol ``<SOS>`` gets a zero offset, so it is
    never pushed backwards and stays first when it is at position 0.
    ``[PAD]`` tokens get the maximal offset ``k + 1``, so trailing padding stays
    behind every real token.

    Args:
        tokenizer: object exposing ``token_to_id(str)``. It is used to look up
            the ids of ``"<SOS>"`` and ``"[PAD]"``.
        x: 2-D tensor of token ids, one sequence per row (batch-first).
        k: shuffle strength (maximum intended displacement).

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    batch_size, seq_length = x.size()
    go_id = tokenizer.token_to_id("<SOS>")
    pad_id = tokenizer.token_to_id("[PAD]")
    # Build the sort keys on x's device so CUDA inputs work too.
    base = torch.arange(seq_length, dtype=torch.float, device=x.device).expand(
        batch_size, seq_length
    )
    inc = (k + 1) * torch.rand(x.size(), device=x.device)
    if go_id is not None:
        inc[x == go_id] = 0        # do not shuffle the start sentence symbol
    if pad_id is not None:
        inc[x == pad_id] = k + 1   # do not shuffle end paddings
    _, sigma = (base + inc).sort(dim=1)
    # sigma[b, j] is the source column of output column j in row b (batch-first).
    return x.gather(1, sigma)

def word_drop(special_token_ids, x, p):     # drop words with probability p
    """Randomly drop non-special tokens, then sort the batch by new length.

    Each token is dropped independently with probability ``p``; the random
    draws come from ``np.random.rand``, one call per row. Tokens whose id is in
    ``special_token_ids`` are never dropped. Survivors keep their relative
    order, and each row is right-padded with ``special_token_ids[0]`` back to
    its original length.

    Args:
        special_token_ids: sequence of ids that must never be dropped. Its
            first element is used as the padding id.
        x: 2-D tensor of token ids, one sequence per row.
        p: per-token drop probability.

    Returns:
        A tuple ``(new_x, new_l, original_indices)``:

        * ``new_x``: the noised rows, sorted by ``new_l`` in descending order.
        * ``new_l``: the number of kept tokens per row, in the same order. Kept
          special tokens are counted, including any padding already present in
          ``x``.
        * ``original_indices``: the inverse permutation, so that
          ``new_x[original_indices]`` restores the original row order.

        All three are LongTensors on ``x.device``.
    """
    batch_size, seq_length = x.size()
    pad_id = special_token_ids[0]
    protected = set(special_token_ids)  # O(1) membership per token

    rows = []
    lengths = []
    for words in x.tolist():
        keep = np.random.rand(len(words)) > p
        # do not drop any of the special symbols
        sent = [w for w, kept in zip(words, keep) if kept or w in protected]
        lengths.append(len(sent))
        rows.append(sent + [pad_id] * (len(words) - len(sent)))

    # reshape keeps the result 2-D even for an empty batch.
    new_x = torch.tensor(rows, dtype=torch.long, device=x.device).reshape(
        batch_size, seq_length
    )
    new_l = torch.tensor(lengths, dtype=torch.long, device=x.device)
    _, sorted_indices = new_l.sort(dim=0, descending=True)

    # remember original sorting: invert the permutation in one vectorised step
    original_indices = torch.empty_like(sorted_indices)
    original_indices[sorted_indices] = torch.arange(batch_size, device=x.device)

    return new_x[sorted_indices, :], new_l[sorted_indices], original_indices
           

def word_blank(tokenizer, x, p, blank_token="<unk>"):     # blank words with probability p
    """Replace each ordinary token with a blank token with probability ``p``.

    ``<SOS>`` and ``[PAD]`` tokens are never blanked.

    Args:
        tokenizer: object exposing ``token_to_id(str)``.
        x: tensor of token ids.
        p: per-token blanking probability.
        blank_token: token string written in place of blanked tokens. The
            special tokens used elsewhere in this module include no dedicated
            blank token, so this defaults to ``"<unk>"``.
            NOTE(review): confirm ``"<unk>"`` is the intended mask token.

    Returns:
        A new tensor; ``x`` itself is not modified.

    Raises:
        ValueError: if ``blank_token`` is not in the tokenizer's vocabulary.
    """
    blank_id = tokenizer.token_to_id(blank_token)
    if blank_id is None:
        raise ValueError(f"blank token {blank_token!r} is not in the vocabulary")
    protected = torch.zeros_like(x, dtype=torch.bool)
    for special in ("<SOS>", "[PAD]"):
        special_id = tokenizer.token_to_id(special)
        if special_id is not None:
            protected |= x == special_id
    blank = (torch.rand(x.size(), device=x.device) < p) & ~protected
    x_ = x.clone()
    x_[blank] = blank_id
    return x_

def word_substitute(tokenizer, x, p):     # substitute words with probability p
    """Replace each ordinary token with a uniformly random id with probability ``p``.

    ``<SOS>`` and ``[PAD]`` tokens are never substituted. Replacement ids are
    drawn uniformly from ``[0, vocab_size)``, where the vocabulary size comes
    from ``tokenizer._tokenizer.get_vocab_size()``. Special ids may therefore
    be drawn as replacements.

    Args:
        tokenizer: object exposing ``token_to_id(str)`` and a wrapped
            ``_tokenizer`` with ``get_vocab_size()``.
        x: tensor of token ids.
        p: per-token substitution probability.

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    vocab_size = tokenizer._tokenizer.get_vocab_size()
    keep = torch.rand(x.size(), device=x.device) > p
    for special in ("<SOS>", "[PAD]"):
        special_id = tokenizer.token_to_id(special)
        if special_id is not None:
            keep |= x == special_id
    x_ = x.clone()
    x_.random_(0, vocab_size)  # upper bound is exclusive
    x_[keep] = x[keep]
    return x_

def get_special_token_ids(tokenizer):
    """Return the ids of ``[PAD]``, ``<unk>``, ``<SOS>`` and ``<EOS>``, in that order.

    The order matters: ``word_drop`` pads with the first entry. Each id is
    whatever ``tokenizer.token_to_id`` returns for the token.
    """
    lookup = tokenizer.token_to_id
    return list(map(lookup, ("[PAD]", "<unk>", "<SOS>", "<EOS>")))

def get_vocab_size(tokenizer):
    """Return the vocabulary size reported by the tokenizer's wrapped ``_tokenizer``."""
    backend = tokenizer._tokenizer
    return backend.get_vocab_size()

def noisy(special_token_ids, x, drop_prob):
    """Apply word-drop noise to a batch of token ids.

    Args:
        special_token_ids: ids that are never dropped. The first entry is used
            as the padding id by ``word_drop``.
        x: 2-D tensor of token ids, one sequence per row.
        drop_prob: per-token drop probability. Values <= 0 disable dropping.

    Returns:
        ``(x, lens, indices)``, with the same contract as ``word_drop``: rows
        are sorted by ``lens`` in descending order, and ``x[indices]`` restores
        the original row order. When dropping is disabled, ``x`` is returned
        unchanged, every length equals the row length (what ``word_drop``
        reports when nothing is dropped), and ``indices`` is the identity.
    """
    # NOTE: shuffle / blank / substitute noise (word_shuffle, word_blank,
    # word_substitute) is currently disabled. Re-enabling it needs extra
    # parameters (shuffle_dist, blank_prob, sub_prob).
    if drop_prob > 0:
        return word_drop(special_token_ids, x, drop_prob)
    batch_size, seq_length = x.size()
    # Nothing is dropped, so every row keeps its full length. The batch is
    # therefore already sorted, and the inverse permutation is the identity.
    lens = torch.full((batch_size,), seq_length, dtype=torch.long, device=x.device)
    indices = torch.arange(batch_size, device=x.device)
    return x, lens, indices
