import torch
import datetime
from data import createLoader
from transformers import AutoTokenizer, AutoModel, AdamW
from models.MainModel import DialogueSummarizationModel
from torch import nn, optim
import os
import pickle
from config import LOG, WAND
import config
from transformers import AutoTokenizer, BartForConditionalGeneration
# from eval import SummarizationEvaluator
import wandb

# Anomaly detection makes autograd raise on NaN-producing backward passes and
# report the forward op responsible. It is a debugging aid and slows training.
torch.autograd.set_detect_anomaly(True)

# Start a Weights & Biases run only when tracking is switched on in config.
if WAND:
    wandb.init(project="TACL", name="mentallamma_multiGPU")

# Target device for the model and every batch tensor, as set in config.
device = config.DEVICE

# Build the model first, then the train / validation / test loaders.
model = DialogueSummarizationModel()
train_dataloader, val_dataloader, test_dataloader = createLoader()

# AdamW over all model parameters; the learning rate comes from config.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)

model.to(device)
if LOG:
    print("Model shifted to device: ", device)

# Training loop: one optimizer step per batch, with gradient-norm clipping,
# followed by a Hugging Face style checkpoint at the end of every epoch.

# Upper bound on the total gradient norm passed to clip_grad_norm_.
clip_value = 5
if LOG: print("= = "*5+"TRAINING STARTED"+" = ="*5)
model.train()
for epoch in range(config.NUM_EPOCHS):
    total_loss = 0
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        input_ids, attention_mask, labels, dialog_ID = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # NOTE(review): `.item()` only works on a single-element tensor, so this
        # assumes the loader yields one dialogue per batch - confirm in createLoader.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            dialogue_IDx=dialog_ID.item(),
        )

        loss = outputs["loss"]
        if LOG: print(f"Batch / {batch_idx+1} - - Loss Returned / {loss}")

        loss.backward()
        # Rescale gradients so their combined norm does not exceed clip_value.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()

        # Read the scalar once: each `.item()` call forces a device-to-host sync.
        loss_value = loss.item()
        total_loss += loss_value
        if WAND: wandb.log({"Step wise Loss": loss_value})

    # Print average loss for the epoch
    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{config.NUM_EPOCHS}, Average Training Loss: {average_loss}")
    if WAND: wandb.log({"Average Loss": average_loss, "Epoch": epoch})

    # Save the per-epoch checkpoint. The path is built once so the log line always
    # matches the directory written, and success is reported only after the save
    # call has returned.
    checkpoint_path = f"{config.MODEL_SAVE_PATH}hf_model_epoch{epoch+1}"
    model.save_pretrained(checkpoint_path, push_to_hub=False)
    print(f"HF Checkpoint Saved | Epoch: {epoch+1} | Path: {checkpoint_path}")
    

# Save the final weights (state_dict) under a timestamped file name and record
# the path on `config.LAST_SAVED`.
now = datetime.datetime.now()
nowStamp = now.strftime("%m_%d__%HH%MM")

# Use the last path component of the model identifier. Indexing with [1] raised
# IndexError for identifiers without a "/" and picked a middle component for
# multi-segment paths; "org/name" style ids still resolve to "name".
model_name = config.MODEL.rstrip("/").split("/")[-1]

config.LAST_SAVED = f'{config.MODEL_SAVE_PATH}final_model_{model_name}_DT_{nowStamp}.pth'
torch.save(model.state_dict(), config.LAST_SAVED)
# Report the exact path that was written instead of rebuilding it.
print(f"Final Model Saved | Path: {config.LAST_SAVED}")



