import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import json
from tqdm import tqdm
import torch
import wandb
import numpy as np
from torch.utils.data import ConcatDataset
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
from GPT_GNN.config import *
from preprocess import Dataset, AligNet
from GPT_GNN.model import GNN, MLMGNN, SmallGraphError


# Fix the PyTorch RNG seed at import time so runs start from the same random state.
torch.manual_seed(37)
# Module-level gradient scaler; pretrain() uses it to scale the loss before
# backward() and to drive the optimizer step.
scaler = torch.cuda.amp.GradScaler()


def load_graph(dataset, file, collection):
    """Load one graph file from the train or test graph directory of ``dataset``.

    ``collection == 'train'`` selects ``dataset.train_graphs``; any other
    value selects ``dataset.test_graphs``. The joined path is handed to
    ``Dataset.to_data`` and its result is returned.

    Returns ``None`` (after printing a notice) when loading raises
    ``RuntimeError`` or ``json.decoder.JSONDecodeError``.
    """
    try:
        if collection == 'train':
            graph_dir = dataset.train_graphs
        else:
            graph_dir = dataset.test_graphs
        return Dataset.to_data(os.path.join(graph_dir, file))
    except json.decoder.JSONDecodeError:
        print("Json error. Skipping ", file)
    except RuntimeError:
        print("Skipping ", file)
    # Reached only when one of the handled errors was reported above.
    return None


def pretrain(dataset, model_type, path, device_ids, batch_size, epochs, num_workers):
    """Pretrain a GNN on ``dataset`` across GPUs, logging to wandb and checkpointing.

    Args:
        dataset: Object exposing ``train_home`` and ``name``; both are handed
            to ``AligNet`` and ``name`` is also embedded in checkpoint names.
        model_type: ``"gengnn"`` builds ``GNN`` and trains on the sum of the
            node, edge and block losses; any other value builds ``MLMGNN``
            and trains on the block loss only.
        path: Prefix of every checkpoint file written by this function.
        device_ids: GPU indices forwarded to ``DataParallel``.
        batch_size: Batch size forwarded to ``DataListLoader``.
        epochs: Number of passes over the training loader.
        num_workers: Worker count forwarded to ``DataListLoader``.

    Returns:
        ``path``, unchanged.

    Checkpoints (each a ``torch.save`` of the unwrapped ``gnn.module``):
        ``<path>_<name>_<epoch>_sofar.bin`` every 1000 steps,
        ``<path>_<name>_best.bin`` whenever the epoch's median loss improves,
        ``<path>_<name>_final.bin`` once training ends.

    Batches that raise ``torch.cuda.OutOfMemoryError``, ``SmallGraphError``,
    ``AttributeError``, ``ValueError`` or ``RuntimeError`` are reported and
    skipped; training continues with the next batch.
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project="gnn"
    )
    device = torch.device("cuda")
    is_gengnn = model_type == "gengnn"
    gnn = GNN() if is_gengnn else MLMGNN()
    gnn = DataParallel(gnn, device_ids=device_ids)
    print('# trainable params: ', sum(p.numel() for p in gnn.parameters() if p.requires_grad))
    gnn.to(device)
    wandb.watch(gnn, log='all', log_freq=100)
    optimizer = torch.optim.AdamW(gnn.parameters(), lr=5e-6)
    gnn.train()
    dataset_train = AligNet(dataset.train_home, dataset.name)
    train_loader = DataListLoader(dataset_train, shuffle=True, batch_size=batch_size, num_workers=num_workers)
    checkpoint_prefix = f"{path}_{dataset.name}"
    # Start from +inf (not a magic number) so the first epoch that produces a
    # loss always yields a "best" checkpoint, however large that loss is.
    min_loss = float('inf')
    for epoch in range(epochs):
        step = 0
        epoch_losses = []
        for batch in tqdm(train_loader):
            step += 1
            try:
                # NOTE(review): the forward pass runs outside autocast, so only
                # the loss summation below is under mixed precision. Left as-is
                # because moving it changes training numerics - confirm intent.
                node_loss, edge_loss, comm_loss, segment_loss, block_loss, mulco_loss = gnn(batch)
                with torch.cuda.amp.autocast():
                    if is_gengnn:
                        loss = node_loss.mean() + edge_loss.mean() + block_loss.mean()
                    else:
                        loss = block_loss.mean()
                    epoch_losses.append(loss.item())
                optimizer.zero_grad()
                scaler.scale(loss).backward()

                # Log every loss component, including those not part of the objective.
                metrics = {"epoch": epoch}
                if is_gengnn:
                    metrics["train loss - node"] = node_loss.mean()
                    metrics["train loss - edge"] = edge_loss.mean()
                    metrics["train loss - comm"] = comm_loss.mean()
                metrics["train loss - block"] = block_loss.mean()
                metrics["train loss - segment"] = segment_loss.mean()
                metrics["train loss - mulco"] = mulco_loss.mean()
                metrics['train loss - total'] = loss
                wandb.log(metrics)

                scaler.step(optimizer)
                # Updates the scale for next iteration
                scaler.update()
            except torch.cuda.OutOfMemoryError:
                # Must precede the RuntimeError handler: the first matching
                # clause wins and this one carries the more specific message.
                del batch
                print('Graph too large. Skipping.')
            except SmallGraphError:
                # No `continue` here: a skipped batch must still reach the
                # periodic checkpoint below, like every other skipped batch.
                del batch
                print('Graph with no edge. Skipping.')
            except AttributeError:
                del batch
                print('Empty graph. Skipping.')
            except ValueError:
                del batch
                print('An error occurred (ValueError). Skipping the batch.')
            except RuntimeError:
                del batch
                print('An error occurred (RuntimeError). Skipping the batch.')
            if step % 1000 == 0:
                torch.save(gnn.module, f"{checkpoint_prefix}_{epoch}_sofar.bin")
        if not epoch_losses:
            # Every batch was skipped (or the loader was empty): there is no
            # loss to rank, and np.median([]) would only yield nan plus a warning.
            print('No batch completed in epoch', epoch, '- no loss to compare, skipping best checkpoint.')
            continue
        epoch_loss = np.median(np.array(epoch_losses))
        print(min_loss, np.array(epoch_losses))
        if epoch_loss < min_loss:
            print('Checkpointing at epoch', epoch, 'with current minimum loss of', epoch_loss)
            torch.save(gnn.module, f"{checkpoint_prefix}_best.bin")
            min_loss = epoch_loss
    wandb.unwatch()
    wandb.finish()
    torch.save(gnn.module, f"{checkpoint_prefix}_final.bin")
    return path


# Command-line entry point: parse options, build the selected dataset and
# launch pretraining.
if __name__ == "__main__":
    import os
    from GPT_GNN.config import funsd_home, docvqa_home, cord_home, sroie_home, buddie_home
    from preprocess import IDL, RVL, FUNSD, DocVQA, CORD, SROIE, BUDDIE

    # CLI name -> (dataset class, data root). Only the selected entry is
    # instantiated further down, so an unused dataset's constructor can
    # neither slow down nor break the run.
    # idl_home / rvl_home are not imported explicitly above; presumably they
    # come from the file-level `from GPT_GNN.config import *` - TODO confirm.
    dataset_specs = {
        'idl': (IDL, idl_home),
        'rvl-cdip': (RVL, rvl_home),
        'funsd': (FUNSD, funsd_home),
        'cord': (CORD, cord_home),
        'sroie': (SROIE, sroie_home),
        'buddie': (BUDDIE, buddie_home),
    }

    parser = argparse.ArgumentParser(description='AliGATr pretrainer.')
    # Choices are derived from the table so the two can never drift apart.
    parser.add_argument('-d','--dataset', type=str, help='Name of datasett to be used for pretraining', choices=list(dataset_specs), default='funsd')
    parser.add_argument('-m','--model', type=str, help='Type of GNN model to be used', choices=['gengnn', 'mlmgnn'], default='gengnn')
    parser.add_argument('-g','--gpu_num', type=int, help='Number of available GPUs', default=4)
    parser.add_argument('-e','--epochs', type=int, help='Number of epochs', default=50)
    parser.add_argument('-b','--batch_size', type=int, help='Batch size', default=16)
    parser.add_argument('-w','--num_workers', type=int, help='Number of workers for DataListLoader', default=8)
    parser.add_argument('-o','--output_path', type=str, help='Where to save the model weights', default='/path/to/model')
    args = vars(parser.parse_args())

    # GPU ids 0..gpu_num-1 are handed to DataParallel inside pretrain().
    device_ids = list(range(args['gpu_num']))
    epochs = args['epochs']
    batch_size = args['batch_size']
    num_workers = args['num_workers']
    path = args['output_path']

    dataset_cls, dataset_home = dataset_specs[args['dataset']]
    dataset = dataset_cls(dataset_home)

    model_type = args['model']
    pretrain(dataset, model_type, path, device_ids, batch_size, epochs, num_workers)