import random
from metrics import get_metric_dict


def word_level_splits(string, tokenizer=None):
    """Return (start, length) character spans for the space-delimited words of ``string``.

    Keyword arguments:
    string -- the input text to be split
    tokenizer -- unused; accepted only so every splitter shares the same signature (default=None)

    Words are delimited by the literal space character " " only; tabs, newlines and
    other whitespace are not treated as delimiters. Each space is kept at the start of
    the word that follows it. So concatenating ``string[start:start + length]`` over the
    spans, in order, reproduces ``string`` exactly. Applying a perturbation function to
    the list of spans before concatenating perturbs the text on word level.

    Edge cases: the empty string yields ``[(0, 0)]``, and a string that starts with a
    space yields a leading zero-length span ``(0, 0)``.
    """
    # A new word starts at the beginning of the string and at every space.
    starts = [0] + [i for i, char in enumerate(string) if char == " "]
    # Each word runs up to the next start; the last one runs to the end of the string.
    ends = starts[1:] + [len(string)]
    return [(start, end - start) for start, end in zip(starts, ends)]


def subword_level_splits(string, tokenizer):
    """Return (start, length) character spans for the subword tokens of ``string``.

    Keyword arguments:
    string -- the input text to be split
    tokenizer -- the tokenizer used to produce the subwords

    When the tokenizer reports ``is_fast`` (it can return character offsets), the spans
    are built from its ``offset_mapping`` and partition ``string`` exactly.
    Concatenating ``string[start:start + length]`` over them, in order, reproduces
    ``string``. Characters the tokenizer skips (e.g. whitespace between words) are
    attached to the start of the following token, and any trailing characters to the
    last token. If nothing is tokenized, a non-empty string becomes a single span. Applying a
    perturbation function to the list of spans perturbs the text on subword level.

    Otherwise the spans fall back to the cumulative character lengths of the token
    strings returned by ``convert_ids_to_tokens``.
    NOTE(review): that fallback is only exact when every token string has as many
    characters as the text it covers. Tokenizers that add markers (such as a "##"
    continuation prefix) or drop whitespace make the spans drift from ``string``.
    """
    if getattr(tokenizer, "is_fast", False):
        encoding = tokenizer.encode_plus(string, add_special_tokens=False, return_offsets_mapping=True)
        spans = []
        covered = 0  # end of the text already assigned to a span
        for _, end in encoding["offset_mapping"]:
            # Clamp so spans never overlap or run past the string, even if offsets
            # repeat or overlap between consecutive tokens.
            end = min(max(end, covered), len(string))
            # The span starts where the previous one ended, absorbing any skipped gap.
            spans.append((covered, end - covered))
            covered = end
        if not spans:
            # Nothing was tokenized: keep the whole text as one unit (or none if empty).
            return [(0, len(string))] if string else []
        # Attach any trailing characters the tokenizer skipped to the last span.
        last_start = spans[-1][0]
        spans[-1] = (last_start, len(string) - last_start)
        return spans

    # No character offsets available: use the lengths of the token strings.
    id_string = tokenizer.encode_plus(string, add_special_tokens=False)["input_ids"]
    subword_string = tokenizer.convert_ids_to_tokens(id_string)

    splits_and_lengths = []
    total_len = 0
    for word in subword_string:
        splits_and_lengths.append((total_len, len(word)))
        total_len += len(word)

    return splits_and_lengths


def char_level_splits(string, tokenizer=None):
    """Return one (start, length) span per character of ``string``.

    Keyword arguments:
    string -- the input text to be split
    tokenizer -- unused; accepted only so every splitter shares the same signature (default=None)

    The output has the same format as word_level_splits and subword_level_splits. Every
    span has length 1, and the empty string yields an empty list.
    """
    return [(position, 1) for position in range(len(string))]


# Lookup from a granularity name to the function that splits a string into
# (start, length) character spans at that granularity.
granularity_to_splitter = dict(
    word=word_level_splits,
    subword=subword_level_splits,
    char=char_level_splits,
)


# Baseline: performs no shuffle at all.
def benchmark_shuffle(string, granularity="word", tokenizer=None, rho=None):
    """Return ``string`` unchanged together with its metrics (the no-perturbation baseline).

    Keyword arguments:
    string -- the input text, returned as-is
    granularity -- ignored; present so all shuffles share one signature (default='word')
    tokenizer -- only forwarded to get_metric_dict (used for the compression rate)
    rho -- ignored (default=None)
    """
    # The text is compared against itself, since nothing was perturbed.
    return string, get_metric_dict(string, string, tokenizer)


def full_shuffle(string, granularity="word", tokenizer=None, rho=None):
    """Return ``string`` with all its units shuffled at ``granularity``, plus its metrics.

    Keyword arguments:
    string -- the input text to be perturbed
    granularity -- the granularity of the perturbation: 'word', 'subword' or 'char' (default='word')
    tokenizer -- the tokenizer containing the subword vocabulary (default=None)
    rho -- ignored (default=None)
    """
    splitter = granularity_to_splitter[granularity]
    units = splitter(string, tokenizer)
    # Uniform random permutation of all units, in place.
    random.shuffle(units)
    shuffled = "".join(string[start : start + length] for start, length in units)
    return shuffled, get_metric_dict(string, shuffled, tokenizer)


def neighbour_flip_shuffle(string, granularity="word", tokenizer=None, rho=0.5):
    """Return ``string`` perturbed by neighbour flips at ``granularity``, plus its metrics.

    Keyword arguments:
    string -- the input text to be perturbed
    granularity -- the granularity of the perturbation: 'word', 'subword' or 'char' (default='word')
    tokenizer -- the tokenizer containing the subword vocabulary (default=None)
    rho -- per-step probability of NOT flipping (default=0.5); see below

    The text is split into units. A "carried" unit, initially the first one, walks over
    the remaining units from left to right. For each next unit, one random draw decides:
      * with probability ``rho``: the carried unit is emitted in place and the next
        unit becomes the carried one (no flip);
      * with probability ``1 - rho``: the next unit is emitted first, i.e. it is flipped
        with the carried unit, which keeps moving right.
    The carried unit is emitted last. ``rho=1`` returns the text unchanged. ``rho=0``
    moves the first unit to the end. Lower ``rho`` therefore means more displacement.

    NOTE(review): rho is effectively the probability of *not* flipping. Confirm that
    callers do not treat it as a flip rate.
    """
    splits_and_lengths = granularity_to_splitter[granularity](string, tokenizer)
    if not splits_and_lengths:
        # Nothing to flip (e.g. "" at char granularity); indexing [0] would raise IndexError.
        return string, get_metric_dict(string, string, tokenizer)

    new_order = []
    carried = splits_and_lengths[0]
    for current in splits_and_lengths[1:]:
        if random.random() < rho:
            # Drop the carried unit here and start carrying the current one.
            new_order.append(carried)
            carried = current
        else:
            # Flip: the current unit goes before the carried one.
            new_order.append(current)
    new_order.append(carried)

    new_string = "".join(string[start : start + length] for start, length in new_order)
    return new_string, get_metric_dict(string, new_string, tokenizer)


def phrase_shuffle(string, granularity="word", tokenizer=None, rho=0.5):
    """Return ``string`` perturbed by shuffling random phrases, plus its metrics.

    Keyword arguments:
    string -- the input text to be perturbed
    granularity -- the granularity of the perturbation: 'word', 'subword' or 'char' (default='word')
    tokenizer -- the tokenizer containing the subword vocabulary (default=None)
    rho -- probability that a unit is added to the current phrase instead of starting
           a new one (default=0.5)

    Consecutive units are grouped into phrases. The phrases are then shuffled, while
    units keep their order inside each phrase. Higher ``rho`` gives longer phrases and
    hence less perturbation.
    """
    splits_and_lengths = granularity_to_splitter[granularity](string, tokenizer)
    if not splits_and_lengths:
        # Nothing to shuffle (e.g. "" at char granularity); indexing [0] would raise IndexError.
        return string, get_metric_dict(string, string, tokenizer)

    phrases = []
    phrase = [splits_and_lengths[0]]
    for unit in splits_and_lengths[1:]:
        if random.random() < rho:
            phrase.append(unit)
        else:
            phrases.append(phrase)
            phrase = [unit]
    phrases.append(phrase)
    random.shuffle(phrases)

    new_string = "".join(
        string[start : start + length] for phrase in phrases for start, length in phrase
    )
    return new_string, get_metric_dict(string, new_string, tokenizer)
