import os
import copy
import pickle
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BartTokenizer
import logging

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from memformer import MemformerClassificationFlyModel

# Logger named after this module; used by main() to echo the loaded config.
logger = logging.getLogger(name=__name__)

class ClassificationDataset(Dataset):
    """Map-style dataset of pre-tokenized classification examples.

    Each example is a dict; the only keys read here are ``"input_ids"`` (a
    sequence of token ids) and ``"label"`` (presumably an integer class id,
    since it is packed into a LongTensor).
    """

    def __init__(self, tokenizer: "BartTokenizer", dataset, max_length=128) -> None:
        """
        Args:
            tokenizer: Tokenizer whose ``pad_token_id`` is used when batching.
            dataset: Iterable of example dicts (see class docstring).
            max_length: Examples with more than this many input ids are
                dropped (not truncated).
        """
        super().__init__()
        self.tokenizer = tokenizer
        # Filter over-long examples once, up front.
        self.dataset = [
            item for item in dataset if len(item["input_ids"]) <= max_length
        ]

    def __getitem__(self, index):
        # Hand out a copy so downstream code cannot mutate the stored example.
        # (Previously this also split item["src_text"] on "</s>" into unused
        # variables, which crashed on any example without exactly one "</s>".)
        return copy.deepcopy(self.dataset[index])

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        """Pad a list of examples into a dict of batch tensors.

        Returns:
            ``input_ids``: LongTensor of shape (batch, longest), right-padded
            with the tokenizer's pad token id.
            ``labels``: LongTensor built from each example's ``"label"``.
            ``attention_mask``: bool tensor, True where ``input_ids`` is not
            the pad token id.
        """
        pad_id = self.tokenizer.pad_token_id
        input_ids = pad_sequence(
            [torch.LongTensor(item["input_ids"]) for item in batch],
            batch_first=True,
            padding_value=pad_id,
        )
        labels = torch.LongTensor([item["label"] for item in batch])
        # The comparison already yields a bool tensor; note that a genuine
        # token equal to pad_id would be masked out as well.
        attention_mask = input_ids != pad_id
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }


class DataLoaderHelper:
    """Factory for the training and validation DataLoaders.

    File paths and ``train_portion`` come from ``config.task``; batch sizes
    come from ``config.training``.
    """

    def __init__(self, config):
        self.config = config
        # Handed to ClassificationDataset, whose collate_fn pads with the
        # tokenizer's pad token id.
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    @staticmethod
    def _load_pickle(path):
        """Return the object unpickled from ``path``."""
        # NOTE(review): unpickling can execute arbitrary code; the configured
        # files must come from a trusted source.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def train_loader_fn(self):
        """Build the shuffled training DataLoader.

        ``train_indices_file`` holds an index list; only its leading
        ``train_portion`` fraction (count truncated by ``int``) selects which
        examples of ``train_file`` are used.
        """
        task_cfg = self.config.task
        examples = self._load_pickle(task_cfg.train_file)
        order = self._load_pickle(task_cfg.train_indices_file)

        cutoff = int(len(order) * task_cfg.train_portion)
        subset = [examples[i] for i in order[:cutoff]]

        train_set = ClassificationDataset(self.tokenizer, subset)
        return DataLoader(
            train_set,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            collate_fn=train_set.collate_fn,
        )

    def valid_loader_fn(self):
        """Build the validation DataLoader (fixed order, no shuffling)."""
        # NOTE(review): ClassificationDataset's default max_length also drops
        # over-long examples from validation; confirm that is intended.
        valid_set = ClassificationDataset(
            self.tokenizer, self._load_pickle(self.config.task.valid_file)
        )
        return DataLoader(
            valid_set,
            batch_size=self.config.training.evaluation.batch_size,
            shuffle=False,
            collate_fn=valid_set.collate_fn,
        )


# Default experiment config loaded by main(). Relative path: presumably
# resolved against the current working directory; confirm FlyConfig.load.
config_path = "./mnli_config/memformerA4_base_10%.yaml"


def main(config_file=None):
    """Load the config, build data loaders and model, and run training.

    Args:
        config_file: Path of the config to load. Defaults to the module-level
            ``config_path``, so existing ``main()`` calls behave as before.
    """
    # Initialise the process group before anything else runs. RANK is set in
    # the environment by torch.distributed launchers; "env://" reads the rest
    # of the rendezvous information (MASTER_ADDR, WORLD_SIZE, ...) from there.
    if "RANK" in os.environ:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    config = FlyConfig.load(config_path if config_file is None else config_file)
    set_random_seed(config.training.random_seed)

    data_helper = DataLoaderHelper(config)
    train_dataloader = data_helper.train_loader_fn()
    valid_dataloader = data_helper.valid_loader_fn()

    model = MemformerClassificationFlyModel(config)
    trainer = Trainer(config.training, model)

    model.configure_metrics()

    with FlyLogger(config.flylogger) as flylogger:
        logger.info(OmegaConf.to_yaml(config))

        # Save the resolved config next to the run's outputs for reproducibility.
        with open("config.yaml", "w") as f:
            OmegaConf.save(config, f)
        trainer.train(config.training, train_dataloader, valid_dataloader)


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
