import os
import json
import string
from typing import List

import numpy as np

from augmentex.base import BaseAug


# Lowercase Russian alphabet (33 letters, including "ё"), in alphabetical
# order. Used as the sampling vocabulary for inserted / substituted characters.
RUSSIAN_VOCAB = list("абвгдеёжзийклмнопрстуфхцчшщъыьэюя")


class CharAug(BaseAug):
    """Augmentation at the character level.

    A phrase is split into characters. A subset of positions is chosen with
    ``aug_indexing`` (inherited from ``BaseAug``), and one action from
    ``actions_list`` is applied at each chosen position.
    """

    def __init__(self, aug_rate: float = 0.3, min_aug: int = 1, max_aug: int = 5, mult_num: int = 5) -> None:
        """
        Args:
            aug_rate (float, optional): Percentage of the phrase to which augmentations will be applied. Defaults to 0.3.
            min_aug (int, optional): The minimum amount of augmentation. Defaults to 1.
            max_aug (int, optional): The maximum amount of augmentation. Defaults to 5.
            mult_num (int, optional): Maximum repetitions of a character produced by the
                "multiply" action. Defaults to 5.
        """
        super().__init__(min_aug=min_aug, max_aug=max_aug)
        static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static_data")

        # Lookup tables: char -> candidate typo chars, char -> probability
        # vector over RUSSIAN_VOCAB, and char -> case-shifted char.
        self.typo_dict = self._load_json(os.path.join(static_dir, "typos_chars_ru_en_digits.json"))
        self.orfo_dict = self._load_json(os.path.join(static_dir, "orfo_chars_ru.json"))
        self.shift_dict = self._load_json(os.path.join(static_dir, "shift_ru_en_digits.json"))

        self.mult_num = mult_num
        self.aug_rate = aug_rate
        self.__actions = ["shift", "orfo", "typo",
                          "delete", "multiply", "swap", "insert"]

    @staticmethod
    def _load_json(path: str) -> dict:
        """Loads a JSON resource file, decoding it explicitly as UTF-8.

        Args:
            path (str): Path to the JSON file.

        Returns:
            dict: The parsed JSON object.
        """
        # Do not rely on the platform default encoding (e.g. cp1251/cp1252 on
        # Windows): the resources presumably contain Cyrillic characters.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    @property
    def actions_list(self) -> List[str]:
        """
        Returns:
            List[str]: A list of possible methods.
        """
        return self.__actions

    def _typo(self, char: str) -> str:
        """A method that simulates a typo by an adjacent key.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: A random candidate from ``typo_dict`` for ``char``, or ``char``
            itself when it has no entry.
        """
        return np.random.choice(self.typo_dict.get(char, [char]))

    def _shift(self, char: str) -> str:
        """Changes the case of the symbol.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: The mapped character from ``shift_dict``, or ``char`` unchanged
            when it has no entry.
        """
        return self.shift_dict.get(char, char)

    def _orfo(self, char: str) -> str:
        """Changes the symbol depending on the error statistics.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: A letter from ``RUSSIAN_VOCAB`` sampled with the probabilities
            stored in ``orfo_dict`` for ``char``. Characters without statistics
            fall back to a uniform distribution over the vocabulary.
        """
        uniform = [1 / len(RUSSIAN_VOCAB)] * len(RUSSIAN_VOCAB)
        probs = self.orfo_dict.get(char, uniform)
        return np.random.choice(RUSSIAN_VOCAB, p=probs)

    def _delete(self) -> str:
        """Deletes a random character.

        Returns:
            str: Empty string.
        """
        return ""

    def _insert(self, char: str) -> str:
        """Inserts a random character.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: ``char`` followed by a random letter from ``RUSSIAN_VOCAB``.
        """
        return char + np.random.choice(RUSSIAN_VOCAB)

    def _multiply(self, char: str) -> str:
        """Repeats a randomly selected character.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: ``char`` repeated ``n`` times, with ``n`` drawn uniformly from
            ``[1, mult_num]`` (``n == 1`` leaves the character unchanged).
        """
        # np.random.randint excludes its upper bound; +1 makes mult_num the
        # true maximum number of repetitions, as documented in __init__.
        n = np.random.randint(1, self.mult_num + 1)
        return char * n

    def _clean_punc(self, text: str) -> str:
        """Clears the text from punctuation.

        Args:
            text (str): Original text.

        Returns:
            str: Text with every ``string.punctuation`` character removed.
        """
        return text.translate(str.maketrans("", "", string.punctuation))

    def clean_punc_batch(self, batch: List[str], prob: float) -> List[str]:
        """Removes punctuation from a randomly selected subset of a batch.

        Args:
            batch (List[str]): Texts to process. The input is not modified.
            prob (float): Forwarded to ``aug_indexing`` to select which texts
                are cleaned.

        Returns:
            List[str]: A copy of ``batch`` with the selected texts cleaned.
        """
        cleaned = batch.copy()
        for idx in self.aug_indexing(cleaned, prob):
            cleaned[idx] = self._clean_punc(cleaned[idx])
        return cleaned

    def aug_batch(self, batch: List[str], aug_prob: float = 0.3, clean_punc: float = 0.3, action: str = None) -> List[str]:
        """Augments a randomly selected subset of a batch.

        Punctuation is first stripped from some texts (see ``clean_punc_batch``).
        Then a subset is augmented with ``action``, or with a random action per
        text when ``action`` is falsy.

        Args:
            batch (List[str]): Texts to augment. The input is not modified.
            aug_prob (float, optional): Forwarded to ``aug_indexing`` to select
                the texts to augment. Defaults to 0.3.
            clean_punc (float, optional): Forwarded to ``clean_punc_batch``.
                Defaults to 0.3.
            action (str, optional): One of ``actions_list``. Defaults to None
                (random action).

        Returns:
            List[str]: The augmented copy of the batch.
        """
        # clean_punc_batch already returns a copy, so the caller's batch is untouched.
        augmented = self.clean_punc_batch(batch, clean_punc)
        for idx in self.aug_indexing(augmented, aug_prob):
            if action:
                augmented[idx] = self.augment(augmented[idx], action)
            else:
                augmented[idx] = self.random(augmented[idx])
        return augmented

    def augment(self, text: str, action: str) -> str:
        """Applies one character-level action to selected positions of ``text``.

        Args:
            text (str): Text to augment.
            action (str): One of ``actions_list``.

        Raises:
            NameError: If ``action`` is not in ``actions_list``. The check runs
                even when no positions end up being selected.

        Returns:
            str: The augmented text.
        """
        if action not in self.__actions:
            raise NameError(
                f"Augmentation {action!r} is not available, please check "
                f"{type(self).__name__}.actions_list to see available "
                f"augmentations: {self.__actions}"
            )
        action = str(action)  # normalize e.g. numpy.str_ coming from np.random.choice

        per_char = {
            "typo": self._typo,
            "shift": self._shift,
            "insert": self._insert,
            "orfo": self._orfo,
            "multiply": self._multiply,
        }
        chars = list(text)
        aug_idxs = self.aug_indexing(chars, self.aug_rate, clip=True)
        for idx in aug_idxs:
            if action == "delete":
                chars[idx] = self._delete()
            elif action == "swap":
                # Swap with the previous position; at index 0 this is a no-op.
                prev = max(0, idx - 1)
                chars[prev], chars[idx] = chars[idx], chars[prev]
            else:
                chars[idx] = per_char[action](chars[idx])
        return "".join(chars)

    def random(self, text: str) -> str:
        """Augments ``text`` with an action chosen uniformly from ``actions_list``.

        Args:
            text (str): Text to augment.

        Returns:
            str: The augmented text.
        """
        action = np.random.choice(self.__actions)
        return self.augment(text, action)

    def __call__(self, text: str) -> str:
        """Shortcut for ``random(text)``."""
        return self.random(text)
