import math
import random
from nltk.tokenize import word_tokenize
from nltk.corpus import wordnet, stopwords
from nltk import pos_tag

class EDABuilder:
  """Easy Data Augmentation (EDA) builder for short text utterances.

  Based off of the implementation of the original EDA algorithm
  (Wei and Zou, 2019). ``augment`` produces three variants of an utterance:
  random swap, random deletion and random synonym insertion. Synonyms come
  from NLTK WordNet and stopwords from NLTK's English stopword list, so those
  corpora must be available when the class is used.

  Args:
    threshold: fraction of tokens affected by each operation (stored as
      ``self.prob``). It is the per-token deletion probability in
      ``delete_tokens``; multiplied by the token count and rounded up, it is
      the number of swaps/insertions in ``swap_tokens``/``insert_tokens``.
  """
  def __init__(self, threshold):
    self.name = 'eda'
    self.prob = threshold
    self.stopwords = stopwords.words('english')

  def _swap_word(self, new_words):
    """Swap the tokens at two distinct random positions of ``new_words``.

    The list is mutated in place and also returned. Lists with fewer than two
    tokens are returned untouched, since there is nothing to swap.
    """
    if len(new_words) < 2:
      return new_words
    # sample() guarantees two distinct indices, so no redraw loop is needed.
    idx_1, idx_2 = random.sample(range(len(new_words)), 2)
    new_words[idx_1], new_words[idx_2] = new_words[idx_2], new_words[idx_1]
    return new_words

  @staticmethod
  def _get_synonyms(word):
    """Return the WordNet synonyms of ``word``, excluding ``word`` itself.

    Each lemma name is lowercased, and ``_``/``-`` are turned into spaces.
    Every character other than ``a-z`` and space is dropped, and runs of
    whitespace are collapsed. Candidates that end up empty are discarded.

    Returns:
      A sorted list of unique synonyms. The list is sorted rather than taken
      in set iteration order because string hashing (and hence set order) is
      randomized per process. Sorting keeps ``random.choice`` over the result
      reproducible under a fixed ``random.seed``.
    """
    synonyms = set()
    for syn in wordnet.synsets(word):
      for lemma in syn.lemmas():
        synonym = lemma.name().replace("_", " ").replace("-", " ").lower()
        synonym = "".join(char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm')
        # Filtering can leave leading/trailing or doubled spaces (e.g. "3-D" -> " d").
        synonym = " ".join(synonym.split())
        if synonym:
          synonyms.add(synonym)
    synonyms.discard(word)
    return sorted(synonyms)

  def _add_word(self, new_words):
    """Insert a synonym of a random non-stopword at a random position.

    Up to ten non-stopword tokens are sampled, looking for one that has
    WordNet synonyms. If none is found, the last sampled token itself is
    inserted (i.e. duplicated). If every token is a stopword, the list is
    returned unchanged.

    Returns:
      ``new_words``, mutated in place.
    """
    stop = set(self.stopwords)
    candidates = [word for word in new_words if word not in stop]
    if not candidates:
      return new_words

    for _ in range(10):
      random_word = random.choice(candidates)
      synonyms = self._get_synonyms(random_word)
      if synonyms:
        break
    else:
      # No synonyms found in ten tries: fall back to duplicating the word.
      synonyms = [random_word]

    random_synonym = random.choice(synonyms)
    # list.insert accepts 0..len inclusive; include len so appending is possible.
    random_idx = random.randint(0, len(new_words))
    new_words.insert(random_idx, random_synonym)
    return new_words

  def swap_tokens(self, words):
    """Apply ``ceil(len(words) * prob)`` random pair swaps to ``words``.

    ``words`` itself is not modified. Returns the swapped tokens joined by
    single spaces.
    """
    new_words = list(words)
    num_swaps = math.ceil(len(words) * self.prob)
    for _ in range(num_swaps):
      new_words = self._swap_word(new_words)
    return " ".join(new_words)

  def delete_tokens(self, words):
    """Drop each token independently with probability ``prob``.

    Returns:
      The surviving tokens joined by single spaces. If every token is
      deleted, one randomly chosen original token is returned instead. If
      ``words`` is empty, ``""`` is returned.
    """
    if not words:
      return ""
    new_words = [word for word in words if random.uniform(0, 1) > self.prob]
    # If all words are deleted, just return a random word.
    if not new_words:
      return random.choice(words)
    return " ".join(new_words)

  def insert_tokens(self, words):
    """Insert ``ceil(len(words) * prob)`` synonyms at random positions.

    ``words`` itself is not modified. Returns the augmented tokens joined by
    single spaces.
    """
    num_insertions = math.ceil(len(words) * self.prob)
    new_words = list(words)
    for _ in range(num_insertions):
      new_words = self._add_word(new_words)
    return " ".join(new_words)

  def augment(self, utterance):
    """Return EDA variants of ``utterance``.

    The utterance is lowercased and split on whitespace. Utterances with at
    most one token (including empty or whitespace-only strings) cannot be
    meaningfully augmented, so ``[utterance]`` is returned unchanged.
    Otherwise the result is ``[swapped, deleted, inserted]``.
    """
    words = utterance.lower().split()
    if len(words) <= 1:
      return [utterance]

    swapped = self.swap_tokens(words)
    deletion = self.delete_tokens(words)
    insertion = self.insert_tokens(words)
    return [swapped, deletion, insertion]
	
# class WordNetBuilder(BaseBuilder):
#   def __init__(self, args, encoder):
#     super().__init__(args)
#     self.name = 'wordnet'
#     self.encoder = encoder
#     self.pos_mapper = { 'noun': ('N', wordnet.NOUN), 
#                         'verb': ('V', wordnet.VERB),
#                         'adjective': ('J', wordnet.ADJ) }
#     self.threshold = args.threshold

#   def geometric(self, data):
#     """ Input is a list, and output is a synonym drawn from a geometric distribution """
#     data = np.array(data)
#     first_trial = np.random.geometric(p=0.5, size=data.shape[0]) == 1  
#     top_option = data[first_trial]  # Capture success after first trial
#     return top_option

#   def swap_synonym(self, text, tokens, part_of_speech):
#     pos_char, pos_obj = self.pos_mapper[part_of_speech]
#     words = [[i, x] for i, x, y in tokens if y.startswith(pos_char)]
#     if len(words) > 0:
#       selected_word = random.choice(words)  # select a word to replace
#     else:
#       return text

#     matches = wordnet.synsets(selected_word[1], pos_obj)  # Return matches for the pos
#     synonyms = list(set(chain.from_iterable([syn.lemma_names() for syn in matches])))
#     synonyms_ = []  # Synonyms with no underscores goes here
#     for w in synonyms:
#       if '_' not in w and w != selected_word[1]:
#         synonyms_.append(w)  # Remove words with underscores
#     if len(synonyms_) >= 1:
#       synonym = self.geometric(data=synonyms_).tolist()
#       if synonym:  # There is a synonym
#         text[int(selected_word[0])] = synonym[0].lower()  # Take the first success
#     return text
  
#   def augment(self, utterance):
#     augmentations = []
#     for _ in range(3):
#       text = word_tokenize(utterance.lower())
#       tokens = [[i, x, y] for i, (x, y) in enumerate(pos_tag(text))]  # Convert tuple to list
#       for part_of_speech in ['noun', 'verb', 'adjective']:
#         if random.random() > self.threshold:
#           text = self.swap_synonym(text, tokens, part_of_speech)
#       augmentations.append(" ".join(text))
#     return augmentations