from typing import Any, List
import random
import copy
import torch
import numpy as np
from .denoising_processor import DenoisingProcessor


class RandomSkipProcessor:
    """Wrap a ``(source, target)`` processor and occasionally emit a skip example.

    Every call first yields the wrapped processor's ``(source, target)`` pair
    together with the caller's ``reset`` flag. Then, with probability ``prob``,
    it yields the same ``target`` a second time, paired with a fixed skip
    prompt as the encoder input and ``reset=False``.
    """

    # Token ids of the skip prompt: 50262, then 50264 repeated 100 times, then 2.
    # NOTE(review): these ids are hard-coded for one particular vocabulary and
    # `tokenizer` is currently unused; confirm they match the tokenizer in use.
    _SKIP_PROMPT_IDS = [50262, *([50264] * 100), 2]

    def __init__(self, tokenizer, processor, prob: float):
        self.prob = prob
        self.processor = processor
        self.skip_prompt = torch.LongTensor(self._SKIP_PROMPT_IDS)

    def __call__(self, tokens: torch.Tensor, reset: bool) -> Any:
        """Yield ``((source, target), reset)`` items for one chunk of tokens.

        Args:
            tokens: Chunk passed unchanged to the wrapped processor.
            reset: Flag forwarded with the first (always emitted) item.

        Yields:
            The processed pair with ``reset``. Then, when ``random.random()``
            falls below ``prob``, ``(skip_prompt, target)`` with ``False``.
        """
        encoder_input, target = self.processor(tokens)
        yield (encoder_input, target), reset
        # The random draw happens only once the consumer asks for another item.
        if random.random() < self.prob:
            # Same target as before; only the encoder input is swapped out.
            yield (self.skip_prompt, target), False

class SplitProcessor:
    """Split a token sequence into overlapping, BOS/EOS-wrapped windows.

    Each window holds at most ``config.context_length`` tokens from the input
    and is wrapped as ``[bos] + window + [eos]`` using the tokenizer's ids.
    Consecutive windows start a random stride apart. The stride is drawn with
    ``np.random.randint(context_stride_min, context_stride_max)``, whose upper
    bound is exclusive. If ``context_stride_min == context_stride_max`` the
    stride is fixed at that value.
    """

    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.context_stride_min = config.context_stride_min
        self.context_stride_max = config.context_stride_max
        self.context_length = config.context_length

    def _sample_stride(self) -> int:
        """Return the offset between the starts of two consecutive windows.

        Raises:
            ValueError: if ``context_stride_max < context_stride_min``, or if
                no positive stride is possible (the loop would never advance).
        """
        low, high = self.context_stride_min, self.context_stride_max
        if high < low:
            raise ValueError(
                f"context_stride_max ({high}) must be >= context_stride_min ({low})"
            )
        # np.random.randint excludes `high` and rejects low == high, so an
        # equal pair is interpreted as a fixed stride of `low`.
        largest = low if high == low else high - 1
        if largest < 1:
            # Every possible stride is <= 0: the window could never reach the
            # end of the sequence and the generator would loop forever.
            raise ValueError(
                "context stride range must allow a positive stride "
                f"(context_stride_min={low}, context_stride_max={high})"
            )
        if high == low:
            return low
        return int(np.random.randint(low, high))

    def __call__(self, tokens: List) -> Any:
        """Yield ``(window, reset)`` pairs covering ``tokens``.

        Args:
            tokens: Sequence of token ids. A ``torch.Tensor`` or ``np.ndarray``
                is converted to a list first. Without the conversion,
                ``[bos] + array`` dispatches to the array's reflected ``+`` and
                adds element-wise instead of concatenating.

        Yields:
            A ``torch.LongTensor`` of ``[bos] + window + [eos]`` and a flag that
            is ``True`` for the first window and ``False`` afterwards. An empty
            input yields a single ``[bos, eos]`` window.
        """
        if isinstance(tokens, (torch.Tensor, np.ndarray)):
            tokens = tokens.tolist()
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id

        start_pos = 0
        reset = True
        while True:
            window = tokens[start_pos : start_pos + self.context_length]
            yield torch.LongTensor([bos] + list(window) + [eos]), reset

            # Stop once this window already reaches the end of the sequence.
            if start_pos + self.context_length >= len(tokens):
                break
            start_pos += self._sample_stride()
            reset = False

class PretrainProcessor:
    """Turn one tokenized document into a stream of pretraining examples.

    The document is cut into windows by ``SplitProcessor``. Each window is then
    passed to a ``RandomSkipProcessor`` wrapping a ``DenoisingProcessor``, which
    is built from the noising options in ``config``.
    """

    # Config attributes forwarded positionally to DenoisingProcessor, in order.
    _DENOISING_FIELDS = (
        "poisson_lambda",
        "mask_ratio",
        "random_ratio",
        "replace_length",
        "permute_sentence_ratio",
        "insert_ratio",
        "rotate_ratio",
    )

    def __init__(self, tokenizer, config):
        self.split_processor = SplitProcessor(config, tokenizer)
        denoiser = DenoisingProcessor(
            tokenizer,
            *(getattr(config, field) for field in self._DENOISING_FIELDS),
        )
        self.processor = RandomSkipProcessor(
            tokenizer, denoiser, prob=config.skip_context_prob
        )

    def __call__(self, tokens: torch.Tensor):
        """Yield ``(example, reset)`` items for every window of ``tokens``.

        NOTE(review): annotated as ``torch.Tensor``, but ``SplitProcessor``
        annotates its input as ``List`` and concatenates it with lists.
        Confirm what callers actually pass.
        """
        for window, window_reset in self.split_processor(tokens):
            for example, example_reset in self.processor(window, window_reset):
                yield example, example_reset
