import os
import copy
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from memformer_encoder import BartEncoderSpanQAFlyModel
from utils import Processor


class SpanQADataset(Dataset):
    """Thin ``Dataset`` wrapper around pre-tokenized span-QA examples.

    Items are returned exactly as stored in the wrapped ``dataset``; all
    tensor conversion and padding happens in :meth:`collate_fn`.
    """

    def __init__(self, tokenizer: AutoTokenizer, dataset) -> None:
        super().__init__()
        self.dataset = dataset
        # Only ``pad_token_id`` is read from the tokenizer (see ``collate_fn``).
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the stored example at ``index`` without modification."""
        return self.dataset[index]

    def collate_fn(self, batch):
        """Collate a list of examples into a padded batch of tensors.

        Args:
            batch: Sequence of mappings, each providing ``"input_ids"``,
                ``"start_positions"`` and ``"end_positions"``.

        Returns:
            A dict with ``input_ids`` (right-padded with the tokenizer's pad
            id, batch first), ``start_positions``, ``end_positions`` and a
            boolean ``attention_mask`` that is ``True`` wherever ``input_ids``
            differs from the pad id.
        """
        pad_id = self.tokenizer.pad_token_id

        sequences, starts, ends = [], [], []
        for example in batch:
            sequences.append(torch.LongTensor(example["input_ids"]))
            starts.append(example["start_positions"])
            ends.append(example["end_positions"])

        padded = pad_sequence(sequences, batch_first=True, padding_value=pad_id)

        return {
            "input_ids": padded,
            "start_positions": torch.LongTensor(starts),
            "end_positions": torch.LongTensor(ends),
            # The mask is derived from the pad id, so any position holding the
            # pad id is masked out, including ones already present in the input.
            "attention_mask": padded.ne(pad_id),
        }


class DataLoaderHelper:
    """Load and preprocess the QA dataset and build its data loaders.

    On construction the dataset named by ``config.task.dataset`` is loaded,
    mapped through ``Processor.preprocess_function`` (batched, dropping the
    original train-split columns), and the train split is reduced to a fixed
    pseudo-random subset whose size is ``config.task.train_portion`` of the
    preprocessed examples.
    """

    # Seed used only for choosing the training subset, so the same examples
    # are selected regardless of the run's own random seed.
    _SUBSET_SEED = 42

    def __init__(self, config):
        """Prepare tokenizer, preprocessed dataset and the training subset.

        Args:
            config: Experiment config; reads ``task.dataset`` and
                ``task.train_portion`` here, and ``training.batch_size`` /
                ``training.evaluation.batch_size`` in the loader factories.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        self.processor = Processor(self.tokenizer)
        self.dataset = load_dataset(self.config.task.dataset)
        self.dataset = self.dataset.map(
            self.processor.preprocess_function, batched=True, remove_columns=self.dataset["train"].column_names
        )

        keep = self._train_subset_indices(len(self.dataset["train"]), self.config.task.train_portion)
        self.dataset["train"] = self.dataset["train"].select(keep)

    @staticmethod
    def _train_subset_indices(num_examples, portion, seed=_SUBSET_SEED):
        """Return the shuffled indices of the training subset.

        A private ``RandomState`` is used instead of ``np.random.seed`` so the
        process-wide NumPy RNG (seeded earlier from the run configuration) is
        left untouched. ``RandomState(seed).shuffle`` yields the same
        permutation as ``np.random.seed(seed); np.random.shuffle``, so the
        selected subset is unchanged.

        Args:
            num_examples: Size of the full training split.
            portion: Fraction of the split to keep.
            seed: Seed for the private generator.

        Returns:
            A NumPy array holding the first ``int(num_examples * portion)``
            entries of the shuffled index array.
        """
        rng = np.random.RandomState(seed)
        indices = np.arange(num_examples)
        rng.shuffle(indices)
        split = int(len(indices) * portion)
        return indices[:split]

    def train_loader_fn(self):
        """Return a shuffling ``DataLoader`` over the training subset."""
        dataset = SpanQADataset(self.tokenizer, self.dataset["train"])
        return DataLoader(
            dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=dataset.collate_fn
        )

    def valid_loader_fn(self):
        """Return an in-order ``DataLoader`` over the validation split."""
        dataset = SpanQADataset(self.tokenizer, self.dataset["validation"])
        return DataLoader(
            dataset, batch_size=self.config.training.evaluation.batch_size, shuffle=False, collate_fn=dataset.collate_fn
        )


# Path of the YAML experiment config that main() loads via FlyConfig.load.
config_path = "./squad_config/base_repeat_10%.yaml"


def main():
    """Train the BART-encoder span-QA model using the config at ``config_path``.

    Steps, in order: optionally initialize ``torch.distributed`` (when a
    ``RANK`` environment variable is present), load the config and seed the
    RNGs, build the train/validation loaders, construct the model and
    trainer, then dump the config and run training inside a ``FlyLogger``
    context.
    """
    # We recommend initializing the process group before anything else starts.
    launched_distributed = "RANK" in os.environ
    if launched_distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    config = FlyConfig.load(config_path)
    set_random_seed(config.training.random_seed)

    loaders = DataLoaderHelper(config)
    train_dataloader = loaders.train_loader_fn()
    valid_dataloader = loaders.valid_loader_fn()

    model = BartEncoderSpanQAFlyModel(config)
    trainer = Trainer(config.training, model)
    model.configure_metrics()

    with FlyLogger(config.flylogger):
        # NOTE(review): "config.yaml" is a relative path resolved while the
        # FlyLogger context is active -- confirm where FlyLogger places output.
        with open("config.yaml", "w") as config_dump:
            OmegaConf.save(config, config_dump)

        trainer.train(config.training, train_dataloader, valid_dataloader)


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
