import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AdamW
# from transformers import LlamaTokenizer, LlamaForCausalLM
from torch.utils.data import TensorDataset, DataLoader
import pickle
import os
# from config import modelDef
from config import LOG
import config


def _tokenize(tokenizer, texts):
    """Tokenize *texts* with max-length padding and truncation, as PyTorch tensors."""
    return tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt")


def _build_dataloader(dialogue_enc, summary_enc, ids, batch_size, shuffle):
    """Bundle one split's encodings and ids into a ``TensorDataset`` and wrap it in a ``DataLoader``."""
    dataset = torch.utils.data.TensorDataset(
        dialogue_enc["input_ids"],
        dialogue_enc["attention_mask"],
        summary_enc["input_ids"],
        torch.tensor(ids),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def createLoader():
    """Build the train, validation and test ``DataLoader`` objects.

    The raw splits come from ``prepRawData()``. For every split the
    ``"dialogues"`` and ``"summaries"`` entries are tokenized and combined
    with the split's ``"id"`` entry, so each batch is a 4-tuple of
    ``(dialogue input_ids, dialogue attention_mask, summary input_ids, id)``.

    Only the training loader is shuffled; the validation and test loaders
    keep dataset order so that evaluation runs are reproducible.

    Returns:
        tuple: ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    # NOTE(review): despite its name, config.tokenizerCheckpoint is called
    # directly below, so it must be a tokenizer object rather than a
    # checkpoint name string -- confirm against config.
    tokenizer = config.tokenizerCheckpoint
    train, val, test = prepRawData()

    print("Data Loaded")
    print(len(train["dialogues"]))

    # Tokenize dialogues and summaries for each split.
    train_dialogues = _tokenize(tokenizer, train["dialogues"])
    train_summaries = _tokenize(tokenizer, train["summaries"])
    print("Tokenized Train")

    val_dialogues = _tokenize(tokenizer, val["dialogues"])
    val_summaries = _tokenize(tokenizer, val["summaries"])

    test_dialogues = _tokenize(tokenizer, test["dialogues"])
    test_summaries = _tokenize(tokenizer, test["summaries"])

    if LOG:
        print("= = " * 5 + "TOKENIZED" + "= = " * 5)

    batch_size = config.BATCH_SIZE

    # Shuffle the training data only; a fixed order for validation/test
    # keeps evaluation deterministic from run to run.
    train_dataloader = _build_dataloader(
        train_dialogues, train_summaries, train["id"], batch_size, shuffle=True
    )
    val_dataloader = _build_dataloader(
        val_dialogues, val_summaries, val["id"], batch_size, shuffle=False
    )
    test_dataloader = _build_dataloader(
        test_dialogues, test_summaries, test["id"], batch_size, shuffle=False
    )

    if LOG:
        print("= = " * 5 + "DATALOADERS CREATED" + "= = " * 5)
    return train_dataloader, val_dataloader, test_dataloader


# Load the pre-split dialogue/summary data (train, val, test) from the pickled dataset file.
def prepRawData():
    """Load the pickled dataset splits from ``dataset/data/dataframes.pkl``.

    The path is relative to the current working directory. The unpickled
    object is indexed with the keys ``'train'``, ``'test'`` and ``'val'``.

    Returns:
        tuple: ``(train, val, test)`` -- note that validation comes before test.

    Raises:
        FileNotFoundError: If the pickle file does not exist.
    """
    # SECURITY: pickle can execute arbitrary code while loading; only use
    # this on a dataset file from a trusted source.
    # The context manager closes the file handle even if unpickling fails.
    with open('dataset/data/dataframes.pkl', 'rb') as fh:
        df = pickle.load(fh)
    train = df['train']
    test = df['test']
    val = df['val']
    return train, val, test


# NOTE(review): this call runs at import time, so importing this module loads
# and tokenizes the whole dataset; the returned loaders are discarded here.
createLoader()