import math
import torch
import numpy as np
from transformers import AutoTokenizer


def get_whole_word_mask(tokenizer: AutoTokenizer):
    """Build a per-vocabulary-id mask flagging tokens that begin a word.

    A token id is treated as a word start when its token string

    * starts with ``"madeupword"``,
    * is one of the tokenizer's special tokens, or
    * starts with ``"Ġ"``.

    Args:
        tokenizer: Object providing ``__len__``, ``convert_ids_to_tokens(id)``
            and a ``special_tokens_map`` dict.

    Returns:
        A 1-D ``torch.uint8`` tensor of length ``len(tokenizer)`` holding 1
        for word-start ids and 0 otherwise.
    """
    # Collect special tokens once instead of re-reading the map for every id.
    # An entry may be a single token string or a list/tuple of token strings;
    # flatten so that tokens inside such lists are recognised too (a plain
    # ``tok in map.values()`` would compare a string against the list itself).
    special_tokens = set()
    for value in tokenizer.special_tokens_map.values():
        entries = value if isinstance(value, (list, tuple)) else (value,)
        special_tokens.update(entry for entry in entries if isinstance(entry, str))

    def is_beginning_of_word(i):
        tok = tokenizer.convert_ids_to_tokens(i)
        return (
            tok.startswith("madeupword")
            or tok in special_tokens
            or tok.startswith("Ġ")
        )

    flags = [is_beginning_of_word(i) for i in range(len(tokenizer))]
    # uint8 (byte) dtype is kept on purpose: callers in this file store the
    # sentinel value 255 into tensors gathered from this mask.
    mask_whole_words = torch.tensor(flags, dtype=torch.uint8)
    return mask_whole_words


class DenoisingProcessor:
    """Apply BART-style denoising noise to a 1-D tensor of token ids.

    Calling an instance returns ``(source, target)`` where ``target`` is a
    copy of the input and ``source`` is the input after, in order, optional
    sentence permutation, whole-word span masking, token insertion and
    rotation.

    NOTE: when sentence permutation is disabled
    (``permute_sentence_ratio <= 0``) the masking step writes into the tensor
    that was passed in, so the caller's tensor is modified in place. The
    returned ``target`` is cloned before any noise is applied.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        poisson_lambda: int = 3,
        mask_ratio: float = 0.3,
        random_ratio: float = 0.1,
        replace_length: int = 1,
        permute_sentence_ratio: float = 1.0,
        insert_ratio: float = 0.0,
        rotate_ratio: float = 0.0,
    ):
        """Store the noise settings and precompute the span-length distribution.

        Args:
            tokenizer: Tokenizer supplying the vocabulary size (``len``), the
                ``"."`` id (``tokenizer.vocab``), ``mask_token_id`` and the
                ``bos``/``eos`` ids checked in ``__call__``.
            poisson_lambda: Rate of the Poisson distribution that span
                lengths are sampled from.
            mask_ratio: Fraction of word starts to mask.
            random_ratio: Fraction of noised positions that receive a random
                token id instead of the mask id.
            replace_length: ``0`` deletes a masked span entirely, ``-1`` keeps
                one noise token per masked token, any other value collapses
                each span to a single noise token.
            permute_sentence_ratio: Fraction of sentences to shuffle.
            insert_ratio: Fraction (of sequence length) of noise tokens to
                insert.
            rotate_ratio: Probability of rotating the sequence.
        """
        self.tokenizer = tokenizer
        self.random_ratio = random_ratio
        self.mask_ratio = mask_ratio
        self.replace_length = replace_length
        self.mask_whole_word = get_whole_word_mask(tokenizer)
        self.full_stop_index = tokenizer.vocab["."]
        self.mask_idx = tokenizer.mask_token_id
        self.permute_sentence_ratio = permute_sentence_ratio
        self.insert_ratio = insert_ratio
        self.rotate_ratio = rotate_ratio

        # span-poisson distribution in BART
        # Tabulate P(k) = e^-lambda * lambda^k / k! for k = 0, 1, ... until the
        # probability drops below 1e-7 (or k reaches 127).
        self.mask_span_distribution = None
        lambda_to_the_k = 1
        e_to_the_minus_lambda = math.exp(-poisson_lambda)
        k_factorial = 1
        ps = []
        for k in range(0, 128):
            ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
            lambda_to_the_k *= poisson_lambda
            k_factorial *= k + 1
            if ps[-1] < 0.0000001:
                break
        ps = torch.FloatTensor(ps)
        self.mask_span_distribution = torch.distributions.Categorical(ps)

    def _replace_with_noise(self, source, indices, mask_random):
        """Overwrite ``source[indices]`` in place with mask / random token ids.

        Every position gets ``self.mask_idx``; positions where ``mask_random``
        is true are then overwritten with ids drawn uniformly from
        ``[1, len(self.tokenizer))``.
        """
        source[indices] = self.mask_idx
        source[indices[mask_random]] = torch.randint(
            1, len(self.tokenizer), size=(int(mask_random.sum()),)
        )

    def word_starts(self, source):
        """Return a per-position word-start indicator for ``source``.

        The first and last positions are always cleared so they are never
        selected for masking.
        """
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def permute_sentences(self, source: torch.Tensor, p=1.0):
        """Return a copy of ``source`` with a fraction ``p`` of sentences shuffled.

        Sentences are delimited by the full-stop token id; the first position
        is skipped and the span before the final token is treated as ending a
        sentence.
        """
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def add_whole_word_mask(self, source: torch.Tensor, p):
        """Mask whole-word spans covering roughly a fraction ``p`` of word starts.

        ``source`` is written to in place for the replaced positions; the
        returned tensor is a filtered (and possibly insertion-noised) result
        and may be shorter or longer than the input.

        Args:
            source: 1-D tensor of token ids.
            p: Fraction of word starts to mask.

        Returns:
            The noised token-id tensor.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))
            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()

        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            self._replace_with_noise(source, indices, mask_random)

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                # Crossing into the next word consumes one unit of span length.
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._replace_with_noise(source, indices, mask_random)
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    self._replace_with_noise(source, indices, mask_random)
                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens: torch.Tensor, p):
        """Shuffle a fraction ``p`` of the interior tokens among themselves.

        The first and last positions are never selected. ``tokens`` is
        modified in place and also returned.
        """
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens: torch.Tensor):
        """Rotate the interior of ``tokens`` by a random offset.

        The first and last tokens stay in place; a new tensor is returned.
        """
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]), dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens: torch.Tensor, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random positions.

        Noise is never placed at the first or last position of the result.
        A ``random_ratio`` share of the inserted tokens are random ids from
        ``[1, len(self.tokenizer))``; the rest are the mask id.

        Returns:
            ``tokens`` itself when ``p == 0.0``, otherwise a new, longer
            tensor.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        # -1 marks slots not yet filled; checked by the assertion below.
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.tokenizer), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def __call__(self, tokens: torch.Tensor):
        """Noise ``tokens`` and return ``(source, target)``.

        Args:
            tokens: 1-D tensor of token ids that starts with the tokenizer's
                bos id and ends with its eos id (both are asserted on the
                noised result).

        Returns:
            ``source``: the noised sequence; ``target``: a clone of the input
            taken before any noise is applied.

        Raises:
            AssertionError: if the noised sequence contains an id outside
                ``[0, len(tokenizer))`` or lost its bos/eos boundary tokens.
        """
        source, target = tokens, tokens.clone()
        if self.permute_sentence_ratio > 0.0:
            source = self.permute_sentences(source, self.permute_sentence_ratio)

        if self.mask_ratio > 0:
            source = self.add_whole_word_mask(source, self.mask_ratio)

        if self.insert_ratio > 0:
            source = self.add_insertion_noise(source, self.insert_ratio)

        if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
            source = self.add_rolling_noise(source)

        assert (source >= 0).all().item()
        # Valid ids are 0 .. len(tokenizer) - 1, so the bound is strict.
        assert (source < len(self.tokenizer)).all()
        assert source[0].item() == self.tokenizer.bos_token_id
        assert source[-1].item() == self.tokenizer.eos_token_id

        return source, target
