import torch
import config
from config import LOG
from data import createLoader
from models.MainModel import DialogueSummarizationModel



# Checkpoint to evaluate. NOTE(review): the training loop saves under
# config.MODEL_SAVE_PATH; confirm it matches this hard-coded directory.
model_save_path = "./checkpoints/model_epoch_16.pth"
print(model_save_path)

# How many test batches to preview before stopping (batch indices 0..2).
NUM_PREVIEW_BATCHES = 3

# Resolve the target device up front so the checkpoint can be mapped onto it
# and the model ends up on the same device the input tensors are sent to.
device = config.DEVICE


def _resolve_tokenizer(summarizer):
    """Return the tokenizer used to decode generated and reference token ids.

    The original script used a bare ``tokenizer`` name that is not defined
    anywhere in this file, so it raised NameError on the first batch.
    NOTE(review): where the project actually exposes its tokenizer is not
    visible here; this looks on ``config`` and then on the model. Replace it
    with an explicit import once the real location is confirmed.

    Raises:
        RuntimeError: if no tokenizer attribute is found.
    """
    for owner in (config, summarizer):
        for attr in ("TOKENIZER", "tokenizer"):
            found = getattr(owner, attr, None)
            if found is not None:
                return found
    raise RuntimeError(
        "No tokenizer available: expected config.TOKENIZER / config.tokenizer "
        "or model.tokenizer. Import the project's tokenizer in this script."
    )


model = DialogueSummarizationModel()
if LOG: print("Model created")

# Load the saved state dictionary. map_location allows a checkpoint that was
# saved from a GPU to be loaded when config.DEVICE is the CPU or another GPU.
model.load_state_dict(torch.load(model_save_path, map_location=device))
# Put the weights on the same device as the batches moved below; otherwise
# generate() would mix CPU parameters with device-resident inputs.
model.to(device)
# Set the model to evaluation mode
model.eval()

# Fail fast, before loading data, if the tokenizer cannot be found.
tokenizer = _resolve_tokenizer(model)

_tra, _val, test_dataloader = createLoader()
if LOG: print("Data loaders created")

with torch.no_grad():
    for batch_idx, batch in enumerate(test_dataloader):
        if LOG: print(f"In Loop x Batch: {batch_idx+1}")
        # The loader yields 4-tuples; dialog_ID is not needed for generation.
        input_ids, attention_mask, labels, dialog_ID = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        if LOG: print("Generating Summary")
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=100,
            num_beams=4,
            early_stopping=True,
        )
        # Only the first example of each batch is shown.
        # NOTE(review): if labels use -100 for ignored positions, swap them for
        # the pad token id before decoding; that can't be confirmed from here.
        print("Generated: ", tokenizer.decode(outputs[0], skip_special_tokens=True))
        print("Original: ", tokenizer.decode(labels[0], skip_special_tokens=True))
        print("= = "*10)
        if batch_idx + 1 >= NUM_PREVIEW_BATCHES:
            break
        

# for epoch in range(config.NUM_EPOCHS):
#     model.train()
#     total_loss = 0
#     for batch_idx, batch in enumerate(train_dataloader):        
#         optimizer.zero_grad()
        
#         input_ids, attention_mask, labels, dialog_ID = batch
#         input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, dialogue_IDx=dialog_ID.item())
        
#         loss = outputs["loss"]
#         if LOG: print(f"Batch / {batch_idx+1} - - Loss Returned / {loss}")
        
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
#         optimizer.step()
        
#         total_loss += loss.item()
#         if WAND: wandb.log({"Step wise Loss": loss.item()})
    
#     # Print average loss for the epoch
#     average_loss = total_loss / len(train_dataloader)
#     print(f"Epoch {epoch + 1}/{config.NUM_EPOCHS}, Average Training Loss: {average_loss}")
#     if WAND: wandb.log({"Average Loss": average_loss, "Epoch": epoch})

#     # Saving epoch checkpoints
#     model_save_path = os.path.join(config.MODEL_SAVE_PATH, f"model_epoch_{epoch+1}.pth")
#     torch.save(model.state_dict(), model_save_path)
#     print(f"Checkpoint Saved | Epoch: {epoch+1} | Path: {model_save_path}")





