# Author's note: "this file is useless" (originally in Hindi: "bekar hai ye file").

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset, DataLoader
import pickle
import os
from config import modelDef
from config import LOG
import config



# Step 1: Prepare your data

# Sample dialogues and summaries (replace these with your actual data)
def prepRawData():
    """Return placeholder (train, val, test) splits of dialogues and summaries.

    Each split is a dict with two parallel lists: ``"dialogues"`` and
    ``"summaries"`` (the i-th summary belongs to the i-th dialogue). All three
    splits currently hold the same two sample pairs; replace them with real
    data. Every split is a freshly built object, so mutating one split never
    affects another.
    """
    return _sample_split(), _sample_split(), _sample_split()


def _sample_split():
    """Build one fresh split dict holding the sample dialogue/summary pairs."""
    # Trailing commas matter here: without them, a newly added string would be
    # silently concatenated onto the previous one (implicit string joining).
    return {
        "dialogues": [
            "User: Hi! How are you doing?\nAssistant: I'm good, thank you!",  # dialog 1
            "User: What's the weather like today?\nAssistant: I think it's going to rain later.",  # dialog 2
            # Add more dialogues...
        ],
        "summaries": [
            "User and assistant exchanged greetings.",
            "User asked about the weather; assistant predicts rain.",
            # Add more summaries...
        ],
    }

def tokenizeData():
    """Tokenize the train/val/test splits from prepRawData() with the project tokenizer.

    Dialogues become the model inputs. The ``input_ids`` of the tokenized
    summaries are used as labels/targets. All returned values are PyTorch
    tensors, because the tokenizer is called with ``return_tensors="pt"``.

    Returns:
        A 9-tuple, in this exact order::

            (input_ids_train, attention_mask_train,
             input_ids_val, attention_mask_val,
             input_ids_test, attention_mask_test,
             labels_train, labels_val, labels_test)
    """
    # modelDef() also builds the model, but only the tokenizer is needed here.
    tokenizer, _model = modelDef()
    train, val, test = prepRawData()

    input_ids_train, attention_mask_train, labels_train = _tokenize_split(tokenizer, train)
    input_ids_val, attention_mask_val, labels_val = _tokenize_split(tokenizer, val)
    input_ids_test, attention_mask_test, labels_test = _tokenize_split(tokenizer, test)

    if LOG: print("= = "*5+"TOKENIZED"+"= = "*5)

    # Print the tensor shapes for verification
    if LOG:
        print("Train set - Input IDs:", input_ids_train.shape)
        print("Validation set - Input IDs:", input_ids_val.shape)
        print("Test set - Input IDs:", input_ids_test.shape)

        print("Train set - Labels:", labels_train.shape)
        print("Validation set - Labels:", labels_val.shape)
        print("Test set - Labels:", labels_test.shape)

    return input_ids_train, attention_mask_train, input_ids_val, attention_mask_val, input_ids_test, attention_mask_test, labels_train, labels_val, labels_test


def _tokenize_split(tokenizer, split):
    """Tokenize one split dict; return (input_ids, attention_mask, labels) tensors."""
    def encode(texts):
        # Identical settings for dialogues and summaries. With padding="max_length"
        # and no explicit max_length, a Hugging Face tokenizer pads to its
        # model_max_length (presumably HF here; confirm against config.modelDef).
        return tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt")

    dialogues = encode(split["dialogues"])
    summaries = encode(split["summaries"])
    # NOTE(review): padded positions in the labels keep the pad token id. Many
    # seq2seq losses expect -100 there so that padding is ignored; confirm how
    # the training loop consumes these labels.
    return dialogues["input_ids"], dialogues["attention_mask"], summaries["input_ids"]



# Step Next: Prepare DataLoader
def createLoader(batch_size = config.BATCH_SIZE):
    """Return the (train, val, test) DataLoaders, reusing pickled copies when present.

    Each split is handled as follows:
    - If the pickle at its configured path exists, it is loaded. The paths are
      ``config.TRAIN_DATALOADER_SAVE_PATH``, ``config.VAL_DATALOADER_SAVE_PATH``
      and ``config.TEST_DATALOADER_SAVE_PATH``.
    - Otherwise a new DataLoader is built from ``tokenizeData()`` output and
      pickled to that path.

    Tokenization runs at most once, and only when at least one split has to be
    rebuilt. Previously it ran only when the *train* pickle was missing, so a
    missing val/test pickle crashed on ``TensorDataset(None, ...)``.

    Args:
        batch_size: Batch size for newly built DataLoaders. The default,
            ``config.BATCH_SIZE``, is read once, when this function is defined.
            NOTE(review): a DataLoader loaded from disk keeps the batch size it
            was saved with; this argument does not affect it.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``. Newly built
        loaders shuffle only the train split.
    """
    # (split key, pickle path, shuffle?, message printed after saving)
    split_specs = (
        ("train", config.TRAIN_DATALOADER_SAVE_PATH, True, "Train dataloader saved at:"),
        ("val", config.VAL_DATALOADER_SAVE_PATH, False, "Validation dataloader saved at:"),
        ("test", config.TEST_DATALOADER_SAVE_PATH, False, "Test dataloader saved at:"),
    )

    if LOG: print("= = "*5+"LOADING DATALOADERS"+" = ="*5)

    split_tensors = None  # filled by tokenizeData() the first time a split must be built
    loaders = []
    for split, save_path, shuffle, saved_msg in split_specs:
        if os.path.exists(save_path):
            # pickle.load can execute arbitrary code: only load files this
            # pipeline wrote itself, never untrusted files.
            with open(save_path, 'rb') as f:
                loader = pickle.load(f)
            if LOG: print("Loaded dataloader from:", save_path)
        else:
            if split_tensors is None:
                split_tensors = _tensors_by_split()
            dataset = TensorDataset(*split_tensors[split])
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
            _save_loader(loader, save_path)
            if LOG: print(saved_msg, save_path)
        loaders.append(loader)

    train_dataloader, val_dataloader, test_dataloader = loaders
    return train_dataloader, val_dataloader, test_dataloader


def _tensors_by_split():
    """Run tokenizeData() and regroup its flat 9-tuple into per-split triples."""
    (input_ids_train, attention_mask_train,
     input_ids_val, attention_mask_val,
     input_ids_test, attention_mask_test,
     labels_train, labels_val, labels_test) = tokenizeData()
    return {
        "train": (input_ids_train, attention_mask_train, labels_train),
        "val": (input_ids_val, attention_mask_val, labels_val),
        "test": (input_ids_test, attention_mask_test, labels_test),
    }


def _save_loader(loader, save_path):
    """Pickle *loader* to *save_path*, creating the parent directory if needed."""
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(loader, f)
