import os
import re
import json
from typing import List, Optional

import numpy as np

from augmentex.base import BaseAug


class WordAug(BaseAug):
    """Augmentation at the level of words.

    Each call to ``augment`` applies a single action (given explicitly or
    chosen at random) to a subset of the whitespace-separated words of the
    input text. Supported action names are listed in ``actions_list``.
    """

    # Word-like token targeted by ``_replace``: Cyrillic/Latin letters, digits
    # and apostrophes. Characters around it (punctuation, quotes, brackets,
    # hyphens) are preserved.
    _WORD_RE = re.compile(r"[а-яА-ЯёЁa-zA-Z0-9']+")

    def __init__(
        self,
        min_aug: int = 1,
        max_aug: int = 5,
        unit_prob: float = 0.3,
    ) -> None:
        """
        Args:
            min_aug (int, optional): The minimum amount of augmentation. Defaults to 1.
            max_aug (int, optional): The maximum amount of augmentation. Defaults to 5.
            unit_prob (float, optional): Percentage of the phrase to which augmentations will be applied. Defaults to 0.3.
        """
        super().__init__(min_aug=min_aug, max_aug=max_aug)
        dir_path = os.path.dirname(os.path.abspath(__file__))

        # The bundled Russian dictionaries may contain non-ASCII text; decode
        # them as UTF-8 explicitly instead of relying on the platform default
        # encoding (e.g. cp1252 on Windows).
        stopwords_path = os.path.join(dir_path, "static_data", "stopwords_ru.json")
        with open(stopwords_path, encoding="utf-8") as f:
            # Sequence of stop words sampled by ``_stopword``.
            self.stopwords = json.load(f)
        orfo_path = os.path.join(dir_path, "static_data", "orfo_ru_words.json")
        with open(orfo_path, encoding="utf-8") as f:
            # Mapping word -> [candidate spellings, their probabilities],
            # as consumed by ``_replace``.
            self.orfo_words = json.load(f)

        self.unit_prob = unit_prob
        self.__actions = ["replace", "delete", "swap", "stopword", "reverse"]

    @property
    def actions_list(self) -> List[str]:
        """
        Returns:
            List[str]: Names of the supported actions, accepted by
            ``augment(action=...)``. This is the internal list itself, so
            mutating it changes what ``augment`` accepts and samples from.
        """

        return self.__actions

    def _reverse_case(self, word: str) -> str:
        """Flips the case of the word based on its first character.

        If the first character is uppercase, the whole word is lowercased;
        otherwise the word is capitalized (first character uppercased, the
        rest lowercased). An empty string is returned unchanged.

        Args:
            word (str): The initial word.

        Returns:
            str: The word with its case changed.
        """
        if not word:
            return word
        if word[0].isupper():
            return word.lower()
        return word.capitalize()

    def _replace(self, word: str, rng: np.random.Generator) -> str:
        """Replaces a correctly spelled word with a misspelled variant.

        The first word-like token (letters, digits, apostrophes) in ``word`` is
        looked up in ``self.orfo_words`` and replaced by a candidate sampled
        with the stored probabilities; tokens missing from the dictionary are
        kept as-is. All characters around the token are preserved.

        Args:
            word (str): A word with the correct spelling, possibly with
                attached punctuation.
            rng (np.random.Generator): Random generator used for sampling.

        Returns:
            str: A misspelled word, or ``word`` unchanged if it contains no
            word-like token (e.g. a lone dash).
        """
        match = self._WORD_RE.search(word)
        if match is None:
            # Nothing to misspell (pure punctuation/symbols).
            return word

        token = match.group()
        entry = self.orfo_words.get(token, [[token], [1.0]])
        candidates, probas = entry[0], entry[1]
        # rng.choice returns a numpy scalar; convert back to a plain str.
        new_token = str(rng.choice(candidates, p=probas))

        return word[: match.start()] + new_token + word[match.end():]

    def _delete(self) -> str:
        """Deletes a random word.

        Returns:
            str: Empty string (the caller collapses the leftover spaces).
        """

        return ""

    def _stopword(self, word: str, rng: np.random.Generator) -> str:
        """Adds a randomly chosen stop word before the word.

        Args:
            word (str): Just word.
            rng (np.random.Generator): Random generator used to pick the stop word.

        Returns:
            str: Stopword + space + word.
        """
        stopword = rng.choice(self.stopwords)

        return " ".join([stopword, word])

    def augment(
        self,
        text: str,
        seed: int = 42,
        rng: Optional[np.random.Generator] = None,
        action: Optional[str] = None,
    ) -> str:
        """Applies one word-level augmentation to ``text``.

        Args:
            text (str): Input text; it is split on whitespace into words.
            seed (int, optional): Seed for a fresh generator, used only when
                ``rng`` is None. Defaults to 42, so repeated calls without
                ``rng`` give the same result.
            rng (np.random.Generator, optional): Generator to use. Defaults to None.
            action (str, optional): One of ``actions_list``. If None, an action
                is chosen at random. Defaults to None.

        Raises:
            NameError: If ``action`` is not one of the available augmentations.

        Returns:
            str: Augmented text, stripped, with runs of spaces collapsed.
        """
        if rng is None:
            rng = np.random.default_rng(seed)
        if action is None:
            action = rng.choice(self.__actions)
        # Validate before doing any work so an invalid action is reported
        # even when no word gets selected (e.g. empty text).
        if action not in self.__actions:
            raise NameError(
                f"Augmentation type {action!r} is not available, please check "
                "WordAug.actions_list to see available augmentations"
            )

        words = text.split()
        # Positions to augment, selected by the base-class helper.
        aug_idxs = self._aug_indexing(words, self.unit_prob, rng, clip=True)
        for idx in aug_idxs:
            if action == "delete":
                words[idx] = self._delete()
            elif action == "reverse":
                words[idx] = self._reverse_case(words[idx])
            elif action == "swap":
                # A swap needs at least one other position.
                if len(words) < 2:
                    continue
                candidate_ids = [i for i in range(len(words)) if i != idx]
                swap_idx = rng.choice(candidate_ids)
                words[swap_idx], words[idx] = words[idx], words[swap_idx]
            elif action == "stopword":
                words[idx] = self._stopword(words[idx], rng)
            elif action == "replace":
                words[idx] = self._replace(words[idx], rng)

        # Deleted words leave empty strings, i.e. repeated spaces after the
        # join; collapse them.
        return re.sub(" +", " ", " ".join(words).strip())
