import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer, AdamW, T5Config
from dataprocess.data_preprocess import T5Preprocessor
from T5_IP import CustomT5Model


class CustomDataset(Dataset):
    """Map-style dataset that tokenizes (input_text, target_text) pairs.

    Args:
        data: Indexable collection whose items unpack into
            ``(input_text, target_text)``.
        tokenizer: Object exposing ``encode_plus(...)`` whose result has an
            ``input_ids`` tensor; its ``pad_token_id`` attribute (if any) is
            used to locate padding in the labels.
        max_length: Every sequence is padded/truncated to this many tokens.
        label_pad_token_id: Value written over padding positions in the
            labels. The default ``-100`` is the default ``ignore_index`` of
            PyTorch's cross-entropy loss, so padding does not contribute to
            the loss. Pass ``None`` to leave the label ids untouched.
    """

    def __init__(self, data, tokenizer, max_length=512, label_pad_token_id=-100):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        self.label_pad_token_id = label_pad_token_id

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize *text* into a 1-D tensor of ``max_length`` token ids."""
        encoded = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the leading batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return encoded.input_ids.squeeze(0)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` for the pair stored at *idx*."""
        input_text, target_text = self.data[idx]
        input_ids = self._encode(input_text)
        labels = self._encode(target_text)

        # Targets are padded up to max_length; without masking, the loss would
        # be computed mostly over padding positions.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if self.label_pad_token_id is not None and pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, self.label_pad_token_id)

        return input_ids, labels


# Load the raw examples through the project preprocessor. CustomDataset
# unpacks every item as (input_text, target_text).
preprocessor = T5Preprocessor()
data = preprocessor.load_and_preprocess('your_dataset_path')

# The tokenizer is taken from the 't5-small' checkpoint.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Shuffled mini-batches of 32 examples each.
dataset = CustomDataset(data, tokenizer)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Only the configuration is read from 't5-small' here; it is handed to the
# CustomT5Model constructor, and the model is then moved to the chosen device.
config = T5Config.from_pretrained('t5-small')
model = CustomT5Model(config)
model.to(device)

# NOTE(review): AdamW here is the one imported from transformers; recent
# transformers releases deprecate it in favour of torch.optim.AdamW — confirm
# against the pinned version.
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epoch = 10

# Training loop: one optimizer step per mini-batch, loss printed after each step.
# NOTE(review): no attention_mask is passed to the model although the dataset
# pads every sequence to max_length — confirm whether CustomT5Model.forward
# accepts one.
model.train()
for epoch in range(num_epoch):
    for batch in dataloader:
        # Each batch unpacks into exactly two tensors; both go to the target device.
        input_ids, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item()}")


