from typing import Any, List
import random
import copy
import torch
import torch.nn.functional as F
import numpy as np
from .denoising_processor import DenoisingProcessor


class InfillingSkipProcessor:
    """Yield a denoising (infilling) example, then optional "skip" replays.

    Every call runs ``DenoisingProcessor`` on the chunk and yields one
    example built from its ``(source, decoder_inputs)`` output. It then keeps
    drawing: while ``random.random() < config.skip_context_prob`` it yields an
    extra example whose encoder input is the fixed ``skip_prompt`` and whose
    target is a randomly chosen entry of ``self.history`` (the processor's
    decoder-side outputs for the most recent ``config.keep_recent_k`` calls,
    cleared whenever ``reset`` is true).
    """

    def __init__(self, config, tokenizer):
        self.recent_k = config.keep_recent_k
        self.tokenizer = tokenizer
        self.processor = DenoisingProcessor(
            tokenizer,
            config.poisson_lambda,
            config.mask_ratio,
            config.random_ratio,
            config.replace_length,
            config.permute_sentence_ratio,
            config.insert_ratio,
            config.rotate_ratio,
        )

        self.prob = config.skip_context_prob
        self.ignore_seen_words = config.ignore_seen_words

        # Hard-coded encoder prompt: 0, eight 50262s, 128 50264s, then 2.
        # NOTE(review): presumably special tokens of the tokenizer vocabulary;
        # an earlier version encoded the literal "<skip>" string instead.
        self.skip_prompt = torch.LongTensor([0] + [50262] * 8 + [50264] * 128 + [2])
        self.history = []

    def __call__(self, tokens: torch.Tensor, reset: bool) -> Any:
        """Yield ``((encoder_input, decoder_inputs, labels), reset)`` tuples.

        The first tuple is the denoising example and carries ``reset`` through;
        any following skip replays carry ``False``.
        """
        if reset:
            self.history = []

        source, target = self.processor(tokens)

        # Keep a private copy of the target and retain only the newest
        # ``recent_k`` entries. ``lst[-0:]`` is the whole list, so a
        # non-positive ``recent_k`` must be handled explicitly.
        self.history.append(target.clone())
        if self.recent_k > 0:
            self.history = self.history[-self.recent_k:]
        else:
            self.history = []

        decoder_inputs, labels = self.get_labels(source, target)
        yield (source, decoder_inputs, labels), reset

        # Geometric number of replays of earlier (or the current) target.
        # The emptiness guard avoids ``random.choice([])`` when nothing is kept.
        while self.history and random.random() < self.prob:
            replay = random.choice(self.history)
            decoder_inputs, labels = self.get_labels(None, replay)
            yield (self.skip_prompt, decoder_inputs, labels), False

    def get_labels(self, source, decoder_inputs):
        """Build teacher-forcing decoder inputs and loss labels from a target.

        Args:
            source: encoder input tensor, or ``None`` to skip seen-word masking.
            decoder_inputs: target sequence. Assumed 1-D — the pad/slice below
                only line up for a single dimension (TODO confirm).

        Returns:
            ``(shifted_inputs, labels)``: the target right-shifted by one with
            ``eos_token_id`` prepended (last token dropped), and a copy of the
            target in which, when ``ignore_seen_words`` is set and ``source`` is
            given, every target token that also occurs in ``source`` is set to
            -100 (the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
            presumably how the loss consumes it).

        The input tensor is never modified.
        """
        # Copy so masking cannot write into the caller's tensor (e.g. history).
        labels = decoder_inputs.clone()
        shifted = F.pad(decoder_inputs, (1, 0), "constant", self.tokenizer.eos_token_id)[:-1]

        if self.ignore_seen_words and source is not None:
            seen = set(source.tolist())
            # Test the token being *predicted* (labels[i]), not the shifted
            # input at the same position, which is the previous token.
            for i, token in enumerate(labels.tolist()):
                if token in seen:
                    labels[i] = -100

        return shifted, labels

class ContinuationProcessor:
    """Turn consecutive chunks of a document into continuation examples.

    The encoder input for a chunk is the previous chunk passed to this
    processor, or ``start_prompt`` for the first chunk after a reset. The
    decoder sees ``tokens[:-1]`` and is trained to predict ``tokens[1:]``.
    """

    def __init__(self, config, tokenizer) -> None:
        self.config = config
        self.tokenizer = tokenizer
        # Hard-coded encoder prompt: 0, eight 50263s, 128 50264s, then 2.
        # NOTE(review): presumably special tokens of the tokenizer vocabulary — confirm.
        self.start_prompt = torch.LongTensor([0] + [50263] * 8 + [50264] * 128 + [2])
        self.prev_tokens = None

    def __call__(self, tokens: torch.Tensor, reset: bool) -> Any:
        """Yield one ``((context, decoder_inputs, labels), reset)`` tuple.

        Args:
            tokens: one chunk, as produced by ``SplitProcessor``.
            reset: true when this chunk starts a new document; clears the
                remembered previous chunk.
        """
        if reset:
            self.prev_tokens = None

        context = self.start_prompt if self.prev_tokens is None else self.prev_tokens
        # Record state before yielding so it is updated even if the caller
        # takes the single item with ``next()`` and never exhausts the generator.
        self.prev_tokens = tokens

        yield (context, tokens[:-1], tokens[1:]), reset

class SplitProcessor:
    """Cut a token list into BOS/EOS-wrapped windows with random strides.

    Each window holds up to ``config.context_length`` tokens. Consecutive
    window starts are a random stride apart, drawn from
    ``[context_stride_min, context_stride_max)``.
    """

    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.context_stride_min = config.context_stride_min
        self.context_stride_max = config.context_stride_max
        self.context_length = config.context_length

    def _sample_stride(self) -> int:
        """Return the offset to the next window start (always at least 1)."""
        low, high = self.context_stride_min, self.context_stride_max
        if high > low:
            # ``np.random.randint`` excludes ``high``.
            stride = int(np.random.randint(low, high))
        else:
            # randint raises when low >= high; treat this as a fixed stride.
            stride = low
        # Guarantee forward progress: a stride of 0 would re-emit the same
        # window (forever, if 0 is the only value that can be drawn).
        return max(stride, 1)

    def __call__(self, tokens: List) -> Any:
        """Yield ``(window, reset)`` pairs covering ``tokens``.

        ``window`` is a LongTensor ``[bos, *chunk, eos]``. ``reset`` is true
        only for the first window. Windows continue until one reaches the end
        of ``tokens``; an empty input yields a single ``[bos, eos]`` window.
        """
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        start_pos = 0
        reset = True
        while True:
            chunk = tokens[start_pos : start_pos + self.context_length]
            yield torch.LongTensor([bos, *chunk, eos]), reset

            if start_pos + self.context_length >= len(tokens):
                break
            start_pos += self._sample_stride()
            reset = False


class PretrainProcessor:
    """Split a document into windows and turn them into training examples.

    One sub-processor (infilling/skip or continuation) is sampled per call,
    i.e. per document, using ``config.processor_probs`` as weights in the
    same order as ``self.processors`` (``None`` means uniform). Every window
    of the document is then fed through that sub-processor.
    """

    def __init__(self, tokenizer, config):
        self.split_processor = SplitProcessor(config, tokenizer)
        self.processors = [
            InfillingSkipProcessor(config, tokenizer),
            ContinuationProcessor(config, tokenizer),
        ]
        # Sampling weights aligned with ``self.processors``; None -> uniform.
        self.processor_probs = config.processor_probs

    def __call__(self, tokens: List):
        """Yield ``(example, reset)`` pairs for every window of ``tokens``.

        Raises:
            ValueError: from ``np.random.choice`` if ``processor_probs`` has
                the wrong length or does not sum to 1.
        """
        # Sample an index (not the object list) so the configured weights apply.
        index = np.random.choice(len(self.processors), p=self.processor_probs)
        processor = self.processors[index]
        for sub_tokens, reset in self.split_processor(tokens):
            yield from processor(sub_tokens, reset)
