import os
import sys
import argparse

import numpy as np
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.file_utils import TorchFileModule
from utils.training_utils import replace_unicode_punct, logger


class Seq2SeqDataset(Dataset):
    """Map-style dataset turning (content, question, answer) rows into padded
    seq2seq token-id arrays for BART- or T5-style models.

    Each item is a dict with ``input_ids``, ``label_ids`` and
    ``dec_input_ids``, each produced by ``add_padding_data`` (1-D ``np.int_``
    arrays of length ``args.max_len``).

    Args:
        tokenizer: Object exposing ``encode`` plus the ``pad/eos/bos/mask``
            ``*_token_id`` attributes read in ``_declare`` (presumably a
            Hugging Face tokenizer — confirm against caller).
        args: Namespace read for ``max_len``, ``model_type`` and
            ``training_type``.
        filename: Path handed to ``TorchFileModule.read_csv``; rows must be
            indexable by ``'content'``, ``'question'`` and ``'answer'``.
    """

    def __init__(self, tokenizer, args, filename=None):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self._declare()
        self.fileutils = TorchFileModule()
        self.docs = self.fileutils.read_csv(filename)
        self.len = len(self.docs)

    def _declare(self):
        """Cache ``max_len`` and the tokenizer's special-token ids.

        The tokenizer's pad token is forced to ``'<pad>'`` *before*
        ``pad_token_id`` is read, so ``pad_index`` always matches the
        tokenizer's final pad token. Reading it first could leave
        ``pad_index`` stale (e.g. ``None`` when no pad token was set), which
        then breaks integer padding in ``add_padding_data``.
        """
        self.max_len = self.args.max_len
        self.tokenizer.pad_token = '<pad>'
        self.pad_index = self.tokenizer.pad_token_id
        self.eos_index = self.tokenizer.eos_token_id
        self.bos_index = self.tokenizer.bos_token_id
        self.mask_index = self.tokenizer.mask_token_id

    def add_padding_data(self, inputs, left=False):
        """Pad ``inputs`` with ``pad_index`` up to ``max_len``, then truncate.

        Args:
            inputs: Sequence of token ids (list or 1-D array).
            left: Pad on the left instead of the right. Truncation always
                keeps the first ``max_len`` ids.

        Returns:
            A new 1-D ``np.int_`` array holding at most ``max_len`` ids
            (exactly ``max_len`` for positive ``max_len``).
        """
        # np.array (not asarray) so the result never aliases the caller's data.
        ids = np.array(inputs, dtype=np.int_)
        n_pad = self.max_len - len(ids)
        if n_pad > 0:
            pad = np.full(n_pad, self.pad_index, dtype=np.int_)
            ids = np.concatenate([pad, ids] if left else [ids, pad])
        return ids[:self.max_len]

    def prepare(self, instance):
        """Tokenize one row and build padded encoder/decoder id arrays.

        ``args.model_type`` selects the layout (any string containing
        ``'bart'`` or ``'t5'``). ``args.training_type`` selects which
        segments feed the encoder and which are predicted:

        * 1: encoder ``[q, psg]`` -> labels ``[a]``
        * 2-6: encoder ``[first (training_type - 1) question tokens + EOS,
          psg]`` -> labels ``[q, a]``
        * 7: encoder ``[q, psg]`` -> labels ``[q, a]``
        * 8: encoder ``[psg]`` -> labels ``[q, a]``

        Args:
            instance: Row indexable by ``'content'``, ``'question'`` and
                ``'answer'``.

        Returns:
            Dict with ``input_ids``, ``label_ids`` and ``dec_input_ids``,
            each passed through ``add_padding_data``.

        Raises:
            ValueError: For an unsupported ``model_type`` or
                ``training_type``.
        """
        model_type = self.args.model_type
        training_type = self.args.training_type

        psg = self.tokenizer.encode(replace_unicode_punct(instance['content']))
        q = self.tokenizer.encode(replace_unicode_punct(instance['question']))
        a = self.tokenizer.encode(replace_unicode_punct(instance['answer']))

        if 'bart' in model_type:
            is_bart = True
            # Drop the first token of each encoding (presumably the BOS that
            # encode() prepends); a single BOS is re-added during assembly.
            psg, q, a = psg[1:], q[1:], a[1:]
        elif 't5' in model_type:
            is_bart = False
            # Textual task prefixes; [:-1] drops the last token of each
            # encoding (presumably the EOS appended by encode()).
            psg_prefix = self.tokenizer.encode(replace_unicode_punct('context: '))[:-1]
            q_prefix = self.tokenizer.encode(replace_unicode_punct('question: '))[:-1]
            a_prefix = self.tokenizer.encode(replace_unicode_punct('answer: '))[:-1]
        else:
            raise ValueError('model type error')

        if training_type == 1:
            # psg + q ==> a
            input_segments, label_segments = [q, psg], [a]
        elif training_type in (2, 3, 4, 5, 6):
            # Only the first (training_type - 1) question tokens reach the
            # encoder, terminated by EOS; the full question is predicted.
            q_input = q[:training_type - 1]
            if not q_input or q_input[-1] != self.eos_index:
                q_input = q_input + [self.eos_index]
            input_segments, label_segments = [q_input, psg], [q, a]
        elif training_type == 7:
            input_segments, label_segments = [q, psg], [q, a]
        elif training_type == 8:
            # psg ==> q + a
            input_segments, label_segments = [psg], [q, a]
        else:
            raise ValueError('training type error')

        if is_bart:
            # BOS + concatenated encoder segments; the decoder input is the
            # label sequence shifted right by one with BOS in front.
            input_ids = [self.bos_index] + [tok for seg in input_segments for tok in seg]
            label_ids = [tok for seg in label_segments for tok in seg]
            dec_input_ids = [self.bos_index] + label_ids[:-1]
        else:
            if training_type == 8:
                input_ids = psg_prefix + input_segments[0]
            else:
                input_ids = q_prefix + input_segments[0] + psg_prefix + input_segments[1]
            # Each label segment loses its last token (presumably EOS) and a
            # single EOS closes the whole target.
            # NOTE(review): the target always starts with the 'answer: '
            # prefix, even when the question is predicted first (training
            # types 2-8) — confirm this is intended.
            label_ids = (
                a_prefix
                + [tok for seg in label_segments for tok in seg[:-1]]
                + [self.eos_index]
            )
            # Decoder inputs and labels are the same sequence offset by one:
            # its first token is only ever fed to the decoder, never predicted.
            dec_input_ids = label_ids[:-1]
            label_ids = label_ids[1:]

        return {
            'input_ids': self.add_padding_data(input_ids, left=False),
            'label_ids': self.add_padding_data(label_ids, left=False),
            'dec_input_ids': self.add_padding_data(dec_input_ids, left=False),
        }

    def __getitem__(self, idx):
        """Return the prepared example for row ``idx`` of ``self.docs``."""
        instance = self.docs.iloc[idx]
        return self.prepare(instance)

    def __len__(self):
        """Return the number of rows loaded from ``filename``."""
        return self.len


class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping train/validation ``Seq2SeqDataset`` sets.

    Both datasets are built eagerly in ``__init__`` from ``args.train_file``
    and ``args.valid_file``. ``args`` must also provide ``batch_size``,
    ``max_len`` and ``num_workers``, plus whatever ``Seq2SeqDataset`` reads.
    """

    def __init__(self, tokenizer, args):
        super().__init__()
        self.tokenizer = tokenizer
        self.hparam_args = args
        self.batch_size = args.batch_size
        self.max_len = args.max_len
        self.train_file = args.train_file
        self.valid_file = args.valid_file

        self.num_workers = args.num_workers
        self.train = Seq2SeqDataset(
            tokenizer=self.tokenizer,
            args=self.hparam_args,
            filename=self.train_file
        )
        self.valid = Seq2SeqDataset(
            tokenizer=self.tokenizer,
            args=self.hparam_args,
            filename=self.valid_file
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a child parser inheriting ``parent_parser``'s arguments.

        No data-module-specific arguments are currently added.
        """
        parser = argparse.ArgumentParser(
            parents=[parent_parser], add_help=False)
        return parser

    def _build_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        # Memory pinning is enabled only when loading in the main process
        # (num_workers <= 0); this preserves the existing configuration.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.num_workers <= 0,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return self._build_dataloader(self.train, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation set.

        Not shuffled: validation should see a deterministic order, and
        Lightning warns when a val/test sampler has shuffling enabled.
        """
        return self._build_dataloader(self.valid, shuffle=False)


