from typing import List
import math
import os
import contextlib
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pyarrow as pa
import random
from transformers import AutoTokenizer
from torchfly.distributed import get_rank, get_world_size
import logging

from data_processors.pretrain_processor import PretrainProcessor


# Module-level logger, registered under the fixed name "dataloader".
# NOTE(review): no code visible in this file logs through it yet (file reads use print).
logger = logging.getLogger("dataloader")

# pylint:disable=no-member


def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Build decoder inputs by shifting ``input_ids`` one position to the right.

    Column 0 of the result is ``decoder_start_token_id``, and column ``j``
    (``j >= 1``) holds ``input_ids[:, j - 1]``. The last input column is
    dropped, so the output has the same shape as the input. Any ``-100``
    (the label "ignore" value) in the result is replaced by ``pad_token_id``.

    ``input_ids`` is indexed as ``[:, ...]``, so it needs at least two
    dimensions. It is never modified; a fresh tensor is returned.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    # The two writes below touch disjoint columns.
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1].clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # -100 is not a real token id; the decoder must see padding there instead.
    ignore_positions = shifted.eq(-100)
    shifted.masked_fill_(ignore_positions, pad_token_id)
    return shifted


@contextlib.contextmanager
def local_seed(seed):
    """Temporarily seed NumPy's global random generator.

    Inside the ``with`` body, ``np.random`` is seeded with ``seed``. On exit,
    whether normal or via an exception, the generator state captured on entry
    is put back. Code outside the block therefore continues the random stream
    it had before.
    """
    previous_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        np.random.set_state(previous_state)


class DocumentAgent:
    """Endless stream of processed items drawn from one subset of documents.

    At the start of every epoch the agent shuffles its own copy of the document
    indices in place, using the seed ``123 + epoch``. The shuffle is applied to
    the previous epoch's order, so permutations compound across epochs but stay
    deterministic. It then looks up each document, converts it with
    ``.as_py()``, runs it through ``processor``, and yields every item the
    processor produces. Epochs repeat forever.

    Args:
        config: stored on the instance; not read by this class.
        document_table: indexable by document index; each element must provide
            ``.as_py()`` (the loader in this file passes an Arrow column).
        indices: document indices owned by this agent. A private copy is kept.
        processor: callable mapping one document's tokens to an iterable of
            items.

    Raises:
        ValueError: if ``indices`` is empty.
    """

    def __init__(self, config, document_table, indices, processor):
        self.epoch = 0
        self.config = config
        self.document_table = document_table
        # Keep a private copy. The caller typically passes a NumPy slice, which
        # is a view of its own array, and the in-place shuffle in __iter__
        # would otherwise write through to that array.
        self.indices = np.array(indices, copy=True)
        if len(self.indices) == 0:
            # An agent without documents can never yield, so the caller's
            # next() would hang forever inside __iter__'s infinite loop.
            raise ValueError("DocumentAgent needs at least one document index")
        self.processor = processor

    def __iter__(self):
        while True:
            with local_seed(123 + self.epoch):
                np.random.shuffle(self.indices)

            produced_any = False
            for idx in self.indices:
                tokens = self.document_table[idx].as_py()
                for item in self.processor(tokens):
                    produced_any = True
                    yield item

            if not produced_any:
                # Without this guard, a processor that yields nothing for every
                # document turns this generator into a silent infinite loop.
                raise RuntimeError(
                    f"DocumentAgent produced no items during epoch {self.epoch} "
                    f"over {len(self.indices)} documents"
                )
            self.epoch += 1


class DocumentDataLoader:
    """Endless iterator of padded training batches built from Arrow documents.

    How it works:

    - The Arrow files are read and concatenated into one table. Its
      ``token_ids`` column holds one document per row.
    - The document indices are shuffled with a fixed seed.
    - The indices are split into ``batch_size`` disjoint, non-empty chunks,
      with one ``DocumentAgent`` per chunk.
    - Each batch takes exactly one item from every agent, so batch row ``i``
      always comes from agent ``i``'s stream.

    Args:
        config: passed to the processor and agents; ``config.tokenizer`` names
            the pretrained tokenizer to load.
        arrow_files: paths of Arrow IPC files. The caller's list is not
            modified.
        batch_size: number of rows per batch, which is also the number of
            agents.

    Raises:
        ValueError: if ``arrow_files`` is empty, ``batch_size`` is not
            positive, or there are fewer documents than ``batch_size``.
    """

    def __init__(self, config, arrow_files, batch_size):
        self.agents = []
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
        self.processor = PretrainProcessor(self.tokenizer, config)

        pa_table = self.get_arrow_table(arrow_files)

        document_table = pa_table["token_ids"]
        indices = np.arange(len(document_table))

        with local_seed(1203):
            np.random.shuffle(indices)

        for sub_indices in self._split_indices(indices, batch_size):
            agent = DocumentAgent(config, document_table, sub_indices, self.processor)
            self.agents.append(iter(agent))

        self.pad_token_id = self.tokenizer.pad_token_id

    @staticmethod
    def _split_indices(indices, num_splits):
        """Partition ``indices`` into ``num_splits`` contiguous, non-empty chunks.

        Chunk sizes differ by at most one. The previous sizing,
        ``len // num_splits + 1``, over-allocated. For example, 100 documents
        over 32 rows gave chunks of 4, which left rows 25-31 with nothing, so
        their agents could never yield.

        Raises:
            ValueError: if ``num_splits`` is not positive or exceeds the number
                of indices.
        """
        if num_splits < 1:
            raise ValueError(f"batch_size must be positive, got {num_splits}")
        if len(indices) < num_splits:
            raise ValueError(
                f"Need at least {num_splits} documents to fill every batch row, "
                f"found {len(indices)}"
            )
        return np.array_split(indices, num_splits)

    def get_arrow_table(self, arrow_files):
        """Memory-map every Arrow IPC file and concatenate them into one table.

        The files are read in an order shuffled with a fixed seed. Only a copy
        of ``arrow_files`` is shuffled, so the caller's list is left untouched.

        Raises:
            ValueError: if ``arrow_files`` is empty.
        """
        if len(arrow_files) == 0:
            raise ValueError("get_arrow_table needs at least one Arrow file")

        arrow_files = list(arrow_files)
        with local_seed(1203):
            np.random.shuffle(arrow_files)

        tables = []
        for path in arrow_files:
            print(f"Reading file {path}")
            mmap = pa.memory_map(path)
            tables.append(pa.ipc.RecordBatchFileReader(mmap).read_all())
        # A single concat keeps the same row order as pairwise concatenation.
        return pa.concat_tables(tables)

    def __iter__(self):
        while True:
            items = [next(agent) for agent in self.agents]
            yield self._collate(items)

    def _collate(self, items):
        """Pad one item per agent into a batch dict of tensors.

        Each item is read as ``item[0] == (source, target)`` and
        ``item[1] == reset``. Source and target are presumably 1-D token-id
        tensors, as ``pad_sequence`` requires; confirm against
        ``PretrainProcessor``.
        """
        source = [item[0][0] for item in items]
        target = [item[0][1] for item in items]

        batch = {}
        batch["reset"] = torch.FloatTensor([item[1] for item in items])
        batch["encoder_input_ids"] = pad_sequence(
            source, batch_first=True, padding_value=self.pad_token_id,
        )
        # Pad targets with -100 so the loss ignores padded positions.
        batch["target"] = pad_sequence(
            target, batch_first=True, padding_value=-100,
        )
        # The EOS token doubles as the decoder start token.
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["target"],
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.tokenizer.eos_token_id,
        )

        batch["encoder_attention_mask"] = batch["encoder_input_ids"] != self.pad_token_id
        batch["decoder_attention_mask"] = batch["decoder_input_ids"] != self.pad_token_id
        return batch


class DataLoaderHelper:
    """Builds this rank's training dataloader from a list of Arrow files.

    Args:
        config: forwarded to ``DocumentDataLoader``.
        training_batch_size: rows per batch for ``DocumentDataLoader``.
        data_files_path: text file listing one Arrow file path per line.
            Blank lines are ignored.
    """

    def __init__(self, config, training_batch_size, data_files_path="./data_files.txt"):
        self.config = config
        self.training_batch_size = training_batch_size
        self.data_files_path = data_files_path

    @staticmethod
    def _shard_files(arrow_files, rank, world_size):
        """Return the contiguous slice of ``arrow_files`` assigned to ``rank``.

        Shards are balanced so their sizes differ by at most one. When the file
        count divides evenly by ``world_size``, this matches a ceil-sized
        slice. A ceil-sized slice can otherwise leave trailing ranks empty; for
        example, 5 files over 4 ranks gives sizes 2, 2, 1, 0.

        Raises:
            ValueError: if ``rank`` receives no files (fewer files than ranks).
        """
        shard_positions = np.array_split(np.arange(len(arrow_files)), world_size)
        shard = [arrow_files[i] for i in shard_positions[rank]]
        if not shard:
            raise ValueError(
                f"Rank {rank} received no Arrow files: {len(arrow_files)} files "
                f"for {world_size} ranks"
            )
        return shard

    def train_dataloader_fn(self):
        """Read the file list, take this rank's shard and wrap it in a loader."""
        with open(self.data_files_path, encoding="utf-8") as f:
            # Skip blank lines; an empty path would fail when memory-mapped.
            arrow_files = [line.strip() for line in f if line.strip()]

        rank = get_rank()
        world_size = get_world_size()

        # Every rank shuffles with the same seed, so all ranks agree on the
        # order and their shards are disjoint.
        with local_seed(1203):
            np.random.shuffle(arrow_files)
        arrow_files = self._shard_files(arrow_files, rank, world_size)

        dataloader = DocumentDataLoader(
            config=self.config,
            arrow_files=arrow_files,
            batch_size=self.training_batch_size,
        )
        return dataloader
