from typing import List
import copy
import os
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pyarrow as pa
import random

from transformers import BartTokenizerFast


def get_mask_tokens(tokens, mask_token_id, p: float = 0.15, lam: int = 3):
    """Replace random spans of ``tokens`` with a single mask token each.

    The sequence is scanned left to right. At every position a uniform value is
    drawn; if it is below ``p`` and the previous step did not just emit a mask,
    a span length is drawn from Poisson(``lam``). A zero length is discarded and
    the same position is tried again. Otherwise the span is removed, one
    ``mask_token_id`` is emitted in its place, and the removed tokens are
    recorded. The position right after a masked span is always kept, so two
    masks are never emitted back to back.

    Args:
        tokens: sliceable token sequence (e.g. a list of ids).
        mask_token_id: id emitted in place of each masked span.
        p: probability of starting a span at an eligible position.
        lam: Poisson mean for the span length.

    Returns:
        ``(new_tokens, masked_tokens)``: the masked sequence, and the list of
        removed spans in order. The last span may be shorter than the drawn
        length if it ran past the end of ``tokens``.
    """
    output = []
    spans = []
    pos = 0
    just_masked = False
    total = len(tokens)
    while pos < total:
        # A uniform value is drawn on every iteration, even right after a mask.
        start_span = np.random.rand() < p
        if just_masked or not start_span:
            just_masked = False
            output.append(tokens[pos])
            pos += 1
            continue

        length = np.random.poisson(lam=lam)
        if length == 0:
            # Empty span: retry at the same position.
            continue

        just_masked = True
        output.append(mask_token_id)
        spans.append(tokens[pos:pos + length])
        pos += length

    return output, spans


class TextContinuation:
    """Produce sentence-continuation items from a tokenized document.

    For every sentence the encoder receives
    ``[cls] + context + [sep] + [sep] + task_prompt + [sep]``. ``context`` starts
    as the "<start>" prompt, grows with each preceding sentence, and is cut to its
    last ``context_length`` tokens. The decoder is given
    ``[bos] + decoder_prompt + sentence`` and must predict the same sequence
    shifted left, followed by ``eos``.
    """

    def __init__(self, tokenizer, context_length, max_position_embeddings, task_prompt_position):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.max_position_embeddings = max_position_embeddings
        self.task_prompt_position = task_prompt_position

        def _encode(text):
            return tokenizer.encode(text, add_special_tokens=False)

        self.start_prompt = _encode("<start>")
        self.skip_prompt = _encode("<skip>")
        self.task_prompt = _encode("Sentence continuation")
        self.decoder_prompt = _encode("<continued sentences>:")

    def process_document(self, document):
        """Yield one training item per sentence of ``document``.

        Context position ids are contiguous and start at a random offset in
        ``[0, task_prompt_position - len(encoder_context))``. ``np.random.randint``
        raises if that range is empty. The task prompt always sits at
        ``task_prompt_position`` onward. Decoder position ids start at 0.
        The ``task_prompt`` and ``task_position_ids`` lists are shared by all
        items of one document.
        """
        tok = self.tokenizer
        cls_id, sep_id = tok.cls_token_id, tok.sep_token_id

        # Identical for every sentence of the document, so built once.
        task_prompt = [sep_id, *self.task_prompt, sep_id]
        task_position_ids = list(np.arange(len(task_prompt)) + self.task_prompt_position)

        history = list(self.start_prompt)
        for sentence in document:
            encoder_context = [cls_id, *history, sep_id]
            offset = np.random.randint(self.task_prompt_position - len(encoder_context))
            encoder_position_ids = list(np.arange(len(encoder_context)) + offset) + task_position_ids

            decoder_input_ids = [tok.bos_token_id, *self.decoder_prompt, *sentence]

            yield {
                "task_prompt": task_prompt,
                "task_position_ids": task_position_ids,
                "encoder_input_ids": encoder_context + task_prompt,
                "encoder_position_ids": encoder_position_ids,
                "decoder_input_ids": decoder_input_ids,
                "decoder_output_ids": decoder_input_ids[1:] + [tok.eos_token_id],
                "decoder_position_ids": list(np.arange(len(decoder_input_ids))),
                "reset": False,
            }

            # Slide the context window forward after the item has been consumed.
            history = (history + list(sentence))[-self.context_length:]


class TextInfilling:
    """Produce span-infilling items from a tokenized document.

    For every sentence, a span-masked copy is built with ``get_mask_tokens``.
    The encoder sees the last ``context_length`` tokens of
    ``history + mask_prompt + masked_sentence``, wrapped in ``[cls] ... [sep]``
    and followed by the task prompt. ``history`` starts as the "<start>" prompt
    and accumulates the original, unmasked sentences, capped at
    ``2 * context_length`` tokens. The decoder target is the original sentence
    after the decoder prompt.
    """

    def __init__(self, tokenizer, context_length, max_position_embeddings, task_prompt_position, mask_token_prob, mask_possion_lambda):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.max_position_embeddings = max_position_embeddings
        self.task_prompt_position = task_prompt_position
        self.mask_token_prob = mask_token_prob
        self.mask_possion_lambda = mask_possion_lambda
        self.mask_token_id = tokenizer.mask_token_id

        def _encode(text):
            return tokenizer.encode(text, add_special_tokens=False)

        self.start_prompt = _encode("<start>")
        self.skip_prompt = _encode("<skip>")
        self.task_prompt = _encode("Fill masked tokens")
        self.decoder_prompt = _encode("<filled sentences>:")
        self.mask_prompt = _encode("<Masked Sentences:>")

    def process_document(self, document):
        """Yield one training item per sentence of ``document``.

        Position ids follow the same scheme as the continuation task. Context
        positions are contiguous from a random offset in
        ``[0, task_prompt_position - len(encoder_context))``. Task-prompt
        positions start at ``task_prompt_position``. Decoder positions start at 0.
        """
        tok = self.tokenizer
        cls_id, sep_id = tok.cls_token_id, tok.sep_token_id

        # Identical for every sentence of the document, so built once.
        task_prompt = [sep_id, *self.task_prompt, sep_id]
        task_position_ids = list(np.arange(len(task_prompt)) + self.task_prompt_position)

        history = list(self.start_prompt)
        for sentence in document:
            masked_sentence = get_mask_tokens(sentence,
                                              self.mask_token_id,
                                              p=self.mask_token_prob,
                                              lam=self.mask_possion_lambda)[0]
            # Keep only the most recent tokens; a very long masked sentence can
            # push the mask prompt out of the window.
            window = (history + self.mask_prompt + masked_sentence)[-self.context_length:]
            encoder_context = [cls_id, *window, sep_id]

            offset = np.random.randint(self.task_prompt_position - len(encoder_context))
            encoder_position_ids = list(np.arange(len(encoder_context)) + offset) + task_position_ids

            # The decoder reconstructs the full, unmasked sentence.
            decoder_input_ids = [tok.bos_token_id, *self.decoder_prompt, *sentence]

            yield {
                "task_prompt": task_prompt,
                "task_position_ids": task_position_ids,
                "encoder_input_ids": encoder_context + task_prompt,
                "encoder_position_ids": encoder_position_ids,
                "decoder_input_ids": decoder_input_ids,
                "decoder_output_ids": decoder_input_ids[1:] + [tok.eos_token_id],
                "decoder_position_ids": list(np.arange(len(decoder_input_ids))),
                "reset": False,
            }

            # History keeps the original (unmasked) sentences.
            history = (history + list(sentence))[-self.context_length * 2:]


class RandomSkipProcessor:
    """Wrap another processor and randomly insert "skip" items.

    Items from the wrapped processor are passed through unchanged. Before each
    item except the first, with probability ``prob``, an extra item is emitted.
    It is a shallow copy of the previous item whose encoder input is replaced
    by ``[cls] + "<skip>" + [sep] + task_prompt``, so its decoder side repeats
    the previous item's target.
    """

    def __init__(self, tokenizer, processor, prob: float = 0.2, task_prompt_position: int = 960):
        self.tokenizer = tokenizer
        self.skip_prompt = self.tokenizer.encode("<skip>", add_special_tokens=False)
        self.processor = processor
        self.prob = prob
        self.task_prompt_position = task_prompt_position

    def process_document(self, document):
        """Yield the wrapped processor's items, interleaved with random skip items."""
        # The skip encoder context is the same for every skip item.
        encoder_context = [self.tokenizer.cls_token_id] + self.skip_prompt + [self.tokenizer.sep_token_id]

        previous_item = None
        for item in self.processor.process_document(document):
            # random.random() is drawn on every iteration, even the first, which
            # can never produce a skip item.
            if random.random() < self.prob and previous_item is not None:
                # Shallow copy: only the two encoder keys are replaced below, so
                # the previous item itself is left untouched.
                skip_item = copy.copy(previous_item)
                skip_item["encoder_input_ids"] = encoder_context + skip_item["task_prompt"]
                random_offset = np.random.randint(self.task_prompt_position - len(encoder_context))
                context_position_ids = list(np.arange(len(encoder_context)) + random_offset)
                skip_item["encoder_position_ids"] = context_position_ids + skip_item["task_position_ids"]
                yield skip_item

            yield item
            previous_item = item


class ResetTimeStep:
    """ResetTimeStep happens when a task is over and needs to reset the memory.

    ``get_reset_item`` builds a single item whose encoder input is the "<reset>"
    marker followed by the "Reset Memory" task prompt. Its decoder target is
    just the "<memory reset>:" prompt, and it is flagged with ``"reset": True``.
    """

    def __init__(self, tokenizer, task_prompt_position):
        self.tokenizer = tokenizer
        self.task_prompt_position = task_prompt_position
        self.empty_encoder_input_ids = tokenizer.encode("<reset>", add_special_tokens=False)
        self.task_prompt = tokenizer.encode("Reset Memory", add_special_tokens=False)
        self.decoder_prompt = tokenizer.encode("<memory reset>:", add_special_tokens=False)

    def get_reset_item(self):
        """Return a reset item with the same keys as the other processors' items."""
        task_prompt = [self.tokenizer.sep_token_id] + self.task_prompt + [self.tokenizer.sep_token_id]
        # Positions must cover the whole wrapped prompt (both separators included).
        # Using len(self.task_prompt) left encoder_position_ids two entries
        # shorter than encoder_input_ids.
        task_position_ids = list(np.arange(len(task_prompt)) + self.task_prompt_position)

        encoder_context = [self.tokenizer.cls_token_id] + self.empty_encoder_input_ids + [self.tokenizer.sep_token_id]
        encoder_input_ids = encoder_context + task_prompt
        # Context positions are contiguous from a random offset below the task prompt.
        random_offset = np.random.randint(self.task_prompt_position - len(encoder_context))
        context_position_ids = list(np.arange(len(encoder_context)) + random_offset)
        encoder_position_ids = context_position_ids + task_position_ids

        decoder_input_ids = [self.tokenizer.bos_token_id] + self.decoder_prompt
        decoder_output_ids = decoder_input_ids[1:] + [self.tokenizer.eos_token_id]
        decoder_position_ids = list(np.arange(len(decoder_input_ids)))

        item = {
            "task_prompt": task_prompt,
            "task_position_ids": task_position_ids,
            "encoder_input_ids": encoder_input_ids,
            "encoder_position_ids": encoder_position_ids,
            "decoder_input_ids": decoder_input_ids,
            "decoder_output_ids": decoder_output_ids,
            "decoder_position_ids": decoder_position_ids,
            "reset": True,
        }
        return item


class MultiTaskProcessor:
    """Route each document to one processor chosen at random.

    Args:
        processors: objects exposing ``process_document(document)``.
        probs: selection probability of each processor, in the same order as
            ``processors``. It is passed to ``np.random.choice`` as ``p``, so it
            must have the same length as ``processors`` and sum to 1.
    """

    def __init__(self, processors: List, probs: List[float]):
        self.processors = processors
        self.probs = probs

    def process_document(self, document):
        """Pick one processor for this document and return its ``process_document`` result."""
        task_idx = np.random.choice(len(self.processors), p=self.probs)
        return self.processors[task_idx].process_document(document)


class PretrainProcessor:
    """Pretraining item pipeline: task mixing, random skip items and optional reset.

    ``config`` must provide ``context_length``, ``max_position_embeddings``,
    ``task_prompt_position``, ``mask_token_prob``, ``mask_possion_lambda``,
    ``probs`` and ``skip_reset_prob``.
    """

    def __init__(self, tokenizer, config):
        proc1 = TextContinuation(tokenizer,
                                 context_length=config.context_length,
                                 max_position_embeddings=config.max_position_embeddings,
                                 task_prompt_position=config.task_prompt_position)

        proc2 = TextInfilling(tokenizer,
                              context_length=config.context_length,
                              max_position_embeddings=config.max_position_embeddings,
                              task_prompt_position=config.task_prompt_position,
                              mask_token_prob=config.mask_token_prob,
                              mask_possion_lambda=config.mask_possion_lambda)

        processor = MultiTaskProcessor([proc1, proc2], probs=config.probs)
        # Use the configured task-prompt position. Without it, RandomSkipProcessor
        # falls back to its default of 960. Skip items would then draw context
        # positions against a different bound than the task_position_ids they
        # copy from the wrapped items.
        self.processor = RandomSkipProcessor(tokenizer, processor,
                                             task_prompt_position=config.task_prompt_position)
        self.reseter = ResetTimeStep(tokenizer, config.task_prompt_position)
        self.skip_reset_prob = config.skip_reset_prob

    def process_document(self, document):
        """Yield training items for ``document``.

        With probability ``skip_reset_prob``, a reset item is emitted first.
        """
        if random.random() < self.skip_reset_prob:
            yield self.reseter.get_reset_item()

        yield from self.processor.process_document(document)
