from torch.utils.data import DataLoader
import pytorch_lightning as pl
from .tokenization_eojeol import *
import pandas as pd
import json
from .Dataset import *
from torch.nn.utils.rnn import pad_sequence

class BackChannelDataModule(pl.LightningDataModule):
    """Lightning data module that builds the train/valid/test datasets of one fold.

    The fold definition is read from ``cross_validation_list.json`` and the
    utterance table from ``Backchannel_text.csv``, both located in ``data_path``.

    Args:
        kfold: Fold identifier; ``str(kfold)`` is used as the key into the
            cross-validation JSON.
        data_path: Directory containing the CSV and the cross-validation JSON.
        bert_path: Directory containing ``vocab.korean.rawtext.list``.
        audio_path: Passed through unchanged to ``BackChannelDataset``.
        **kwargs: Extra hyperparameters stored in ``self.hparams``. The
            dataloaders read ``self.hparams.batch_size``, so ``batch_size``
            has to be supplied here.
    """

    def __init__(self, kfold, data_path, bert_path, audio_path, **kwargs):
        super().__init__()
        self.kfold = kfold
        self.data_path = data_path
        self.bert_path = bert_path
        self.audio_path = audio_path
        # Must be called directly from __init__: it captures the constructor
        # arguments (including **kwargs) into self.hparams.
        self.save_hyperparameters()

    @staticmethod
    def _select_split(text_df, split_files):
        """Return ``(file_index values, rows)`` of ``text_df`` for one split.

        Rows are first matched on the ``file`` column; the result is then
        re-selected through the ``file_index`` values of those rows.
        NOTE(review): if a ``file_index`` value is shared with a row of a file
        outside ``split_files``, that row is selected as well - confirm that
        ``file_index`` never overlaps between splits.
        """
        ids = text_df[text_df["file"].isin(split_files)].file_index
        split_df = text_df[text_df["file_index"].isin(ids)].reset_index(drop=True)
        return ids, split_df

    def setup(self, stage=None):
        """Load the text table and build the three datasets for ``self.kfold``.

        ``stage`` is accepted for Lightning compatibility but not inspected:
        all three splits are built on every call.
        """
        tokenizer = BertTokenizer(
            f"{self.bert_path}/vocab.korean.rawtext.list",
            do_lower_case=False
        )
        print("Load Dataset...")
        text_df = pd.read_csv(f"{self.data_path}/Backchannel_text.csv")

        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when decoding the fold definition.
        with open(f"{self.data_path}/cross_validation_list.json", encoding="utf-8") as f:
            file_names = json.load(f)[str(self.kfold)]
        train_file_name = file_names["train"]
        valid_file_name = file_names["valid"]
        test_file_name = file_names["test"]

        self.train_ids, train_text_df = self._select_split(text_df, train_file_name)
        self.valid_ids, valid_text_df = self._select_split(text_df, valid_file_name)
        self.test_ids, test_text_df = self._select_split(text_df, test_file_name)

        self.train_dataset = BackChannelDataset(train_text_df, tokenizer, self.hparams, self.audio_path)
        self.valid_dataset = BackChannelDataset(valid_text_df, tokenizer, self.hparams, self.audio_path)
        self.test_dataset = BackChannelDataset(test_text_df, tokenizer, self.hparams, self.audio_path)

        # Summary: row count of the whole table, then per split the number of
        # selected rows and the dataset's bc_d_nums attribute.
        print("Total : ", text_df.shape[0])
        print("TRAIN : ", len(self.train_ids), self.train_dataset.bc_d_nums)
        print("VALID : ", len(self.valid_ids), self.valid_dataset.bc_d_nums)
        print("TEST : ", len(self.test_ids), self.test_dataset.bc_d_nums)

    def train_dataloader(self):
        """Return the shuffled training loader, batched with ``collate_fn``."""
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True, collate_fn=collate_fn)

    def val_dataloader(self):
        """Return the validation loader (no shuffling), batched with ``collate_fn``."""
        return DataLoader(self.valid_dataset, batch_size=self.hparams.batch_size, collate_fn=collate_fn)

    def test_dataloader(self):
        """Return the test loader (no shuffling), batched with ``collate_fn``."""
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, collate_fn=collate_fn)
        

def collate_fn(data):
    """
    Collate a list of dataset samples into one dictionary of batched tensors.

    Entries equal to ``-1`` are treated as invalid samples and dropped before
    batching. Whether the optional ``"accustic_feature"`` and ``"his_*"`` keys
    are used is decided from the first *valid* sample only.

    Args:
        data (List[Dict | int]): Samples produced by the dataset. Each valid sample
            is a dict with the keys "cur_input_ids", "cur_input_mask",
            "cur_token_len", "history_size" and "target", and optionally
            "accustic_feature", "his_input_ids", "his_input_mask" and
            "his_token_len". An invalid sample is the integer -1.

    Returns:
        Dict[str, Tensor]: Batched tensors.
            - "input_ids", "input_mask", "token_len" (Tensor): Per-sample tensors
              stacked row-wise with ``torch.vstack``. When history is present, the
              history rows of all samples come first (in sample order), followed
              by the current rows of all samples.
              NOTE(review): the original author documents these as
              [batch_size x (history_size + 1), max_sequence_length] (and width 1
              for "token_len"); this relies on the dataset's per-sample shapes.
            - "target" (Tensor): Stacked targets flattened to 1-D.
            - "history_size" (Tensor): Stacked history sizes flattened to 1-D.
            - "accustic_feature" (Tensor or None): Acoustic features zero-padded
              along their first dimension and stacked batch-first, or None when the
              first valid sample carries none.

    Raises:
        ValueError: If ``data`` contains no valid sample.
    """

    # Drop the -1 placeholders once, up front, so the optional-key probes below
    # look at a real sample even when the first entry of the batch is invalid.
    samples = [d for d in data if d != -1]
    if not samples:
        raise ValueError("collate_fn received no valid samples (empty batch or only -1 entries)")
    first = samples[0]

    input_ids = [d["cur_input_ids"] for d in samples]
    input_mask = [d["cur_input_mask"] for d in samples]
    token_len = [d["cur_token_len"] for d in samples]
    history_size = [d["history_size"] for d in samples]
    target = [d["target"] for d in samples]

    if first.get("accustic_feature", None) is not None:
        # Sequences may differ in length; pad_sequence zero-pads to the longest.
        accustic_feature = pad_sequence(
            [d["accustic_feature"] for d in samples], batch_first=True
        )
    else:
        accustic_feature = None

    if first.get("his_input_ids", None) is not None:
        # History rows are placed before the current-utterance rows.
        input_ids = [d["his_input_ids"] for d in samples] + input_ids
        input_mask = [d["his_input_mask"] for d in samples] + input_mask
        token_len = [d["his_token_len"] for d in samples] + token_len

    return {
        "input_ids": torch.vstack(input_ids),
        "input_mask": torch.vstack(input_mask),
        "token_len": torch.vstack(token_len),
        "target": torch.vstack(target).reshape(-1),
        "history_size": torch.vstack(history_size).reshape(-1),
        "accustic_feature": accustic_feature,
    }