import sys
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import torch
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
import evaluate
from sklearn.metrics import f1_score, confusion_matrix
from tqdm import tqdm
import wandb
from GPT_GNN.config import *
from GPT_GNN.model import GNN, LinkDetector, SmallGraphError
from preprocess import Dataset, FUNSD, CORD, BUDDIE


# Seed PyTorch's global random number generator with a fixed value so that
# torch-driven randomness starts from the same state on every run.
torch.manual_seed(37)
# Load the "seqeval" metric through the `evaluate` package at import time.
# NOTE(review): `metric` is not referenced anywhere in the visible part of
# this file -- confirm it is still needed before keeping this load.
metric = evaluate.load("seqeval")


def print_results(filepath, dataset, preds, y):
    """Print each non-empty annotated word with its predicted and reference label.

    The annotation file is located by rewriting ``filepath``: every
    ``'graphs'`` is replaced with ``'annotations'`` and every ``'.json.json'``
    with ``'.json'``.

    Args:
        filepath: Path of a graph file; the annotation path is derived from it.
        dataset: Object exposing ``idx_to_label``, indexed by label id.
        preds: Predicted label ids, one per printed word, in file order.
        y: Reference label ids, aligned with ``preds``.
    """
    import json

    annotation_path = filepath.replace('graphs', 'annotations').replace('.json.json', '.json')
    with open(annotation_path) as handle:
        annotation = json.load(handle)

    # Words without a 'text' key, or with empty text, are skipped and do not
    # consume an index, so positions in preds / y count printed words only.
    texts = (
        word['text']
        for entry in annotation['form']
        for word in entry['words']
        if 'text' in word and len(word['text']) > 0
    )
    for position, text in enumerate(texts):
        print(text, dataset.idx_to_label[preds[position]], dataset.idx_to_label[y[position]])


def load_model(model_type, path, dataset, train_loader, val_loader, batch_size, device, device_ids, cont=False):
    """Build the ``LinkDetector`` model and its AdamW optimizer.

    Two modes are supported:

    * ``cont=False``: ``path`` is loaded with ``torch.load`` and the result is
      passed to ``LinkDetector`` as the pretrained GNN; a fresh optimizer is
      created.
    * ``cont=True``: the checkpoint at ``path`` with ``'.bin'`` replaced by
      ``'_best.bin'`` is loaded, and both the model and optimizer state dicts
      are restored from it.

    Args:
        model_type: Forwarded to ``LinkDetector``.
        path: Location of the pretrained weights (or base name of the
            fine-tuning checkpoint when ``cont`` is true).
        dataset, train_loader, val_loader, batch_size, device_ids: Accepted
            for interface compatibility; not used by this function.
        device: Device the model is moved to before the optimizer is built.
            ``None`` leaves the model where it was loaded.
        cont: Whether to resume from an existing fine-tuning checkpoint.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    # lr of 1e-5 works best for roberta-only
    base_lr = 1e-5

    # NOTE: torch.load unpickles its input; only point `path` at trusted files.
    if not cont:
        gnn = torch.load(path)
        model = LinkDetector(gnn, model_type)
        checkpoint = None
    else:
        checkpoint = torch.load(path.replace('.bin', '_best.bin'))
        model = LinkDetector(None, model_type)
        model.load_state_dict(checkpoint['model_state_dict'])

    # PyTorch recommends moving a model to its device before constructing the
    # optimizer. Doing it here also means the optimizer state restored below
    # is loaded against parameters that already live on `device`.
    if device is not None:
        model.to(device)

    # The resume branch used to pass a bare `lr` that is not defined in this
    # function (nor anywhere visible in this module). The value given here is
    # only a placeholder on resume: load_state_dict() restores the saved
    # param_groups, including the learning rate.
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer
    

def _build_loader(graph_dir, batch_size, shuffle):
    """Convert every file in ``graph_dir`` with ``Dataset.to_data`` and wrap the results in a ``DataListLoader``."""
    # os.listdir returns entries in arbitrary order; sorting keeps the dataset
    # order stable between runs and machines, which the fixed torch seed
    # needs in order to yield a repeatable shuffle.
    files = sorted(os.listdir(graph_dir))
    graphs = [Dataset.to_data(os.path.join(graph_dir, file)) for file in tqdm(files)]
    return DataListLoader(graphs, batch_size=batch_size, shuffle=shuffle)


def finetune_and_test(dataset, model_type, path, device_ids, batch_size, epochs, cont=False):
    """Fine-tune the link-prediction model and checkpoint the best epoch.

    Each epoch trains on the training graphs, evaluates macro-F1 of the
    thresholded link predictions, logs losses and F1 to Weights & Biases, and
    saves model/optimizer state to ``path`` with ``'.bin'`` replaced by
    ``'_best.bin'`` whenever the F1 improves.

    NOTE(review): evaluation iterates ``test_loader`` although the metrics are
    logged under "val ..." keys, and ``val_loader`` is only handed to
    ``load_model``. The best checkpoint is therefore selected on the test
    graphs -- confirm this is intended.

    Args:
        dataset: Object exposing ``train_graphs``, ``val_graphs`` and
            ``test_graphs`` directory paths.
        model_type: Forwarded to ``load_model``.
        path: Location of the pretrained weights; also the base name of the
            best-checkpoint file.
        device_ids: Forwarded to ``load_model``; not otherwise used here.
        batch_size: Number of graphs per batch for all three loaders.
        epochs: Number of training epochs.
        cont: Whether to resume from an existing fine-tuning checkpoint.
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project="gnn"
    )
    train_loader = _build_loader(dataset.train_graphs, batch_size, shuffle=True)
    val_loader = _build_loader(dataset.val_graphs, batch_size, shuffle=False)
    test_loader = _build_loader(dataset.test_graphs, batch_size, shuffle=False)

    device = torch.device("cuda")
    model, optimizer = load_model(model_type, path, dataset, train_loader, val_loader, batch_size, device, device_ids, cont)
    print('# trainable params: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
    model.to(device)
    wandb.watch(model, log="gradients", log_freq=10)

    best_path = path.replace('.bin', '_best.bin')
    best_performance = 0.0
    for epoch in range(epochs):
        model.train()
        for batch in tqdm(train_loader):
            try:
                preds, trues, edge_loss = model(batch)
                loss = edge_loss.mean()
                optimizer.zero_grad()
                loss.backward()
                wandb.log({"epoch": epoch, "train loss - link pred": loss})
                optimizer.step()
            except torch.cuda.OutOfMemoryError:
                # Best effort: drop the batch and keep training.
                print('Graph too large. SKipping.')
            except SmallGraphError:
                print('Graph too small, or no edges present. SKipping.')

        ts, ps = [], []
        # Loss of the last successfully evaluated batch. It is tracked
        # separately so the checkpoint never stores a stale training loss.
        eval_loss = None
        model.eval()
        with torch.no_grad():
            for batch in tqdm(test_loader):
                try:
                    preds, trues, edge_loss = model(batch)
                    eval_loss = edge_loss.mean()
                    wandb.log({"epoch": epoch, "val loss - link pred": eval_loss})
                    # Threshold the sigmoid of the raw scores at 0.5 to get 0/1 link decisions.
                    ps += (torch.sigmoid(preds) >= 0.5).long().detach().cpu().tolist()
                    ts += trues.detach().cpu().tolist()
                except torch.cuda.OutOfMemoryError:
                    print('Graph too large. SKipping.')
                except SmallGraphError:
                    print('Graph too small, or no edges present. SKipping.')

        gold = [label for row in ts for label in row]
        predicted = [label for row in ps for label in row]
        if gold:
            performance = f1_score(gold, predicted, average='macro')
        else:
            # Every evaluation batch was skipped; there is nothing to score.
            print('No evaluation batch could be scored this epoch.')
            performance = 0.0

        # performance > 0 implies at least one batch was scored, so eval_loss is set.
        if performance > best_performance:
            print('Saving best model with f1 of ...', performance)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': eval_loss
            }, best_path)
            best_performance = performance
        wandb.log({"epoch": epoch, "val f1 - link pred": performance})
    wandb.finish()


def parse_bool_flag(value):
    """Convert a command-line token into a bool for argparse.

    ``type=bool`` is unsuitable for argparse because ``bool('False')`` is
    ``True``; this converter interprets the text instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``'true'``,
            ``'False'``, ``'1'``, ``'no'`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == "__main__":
    import os
    from GPT_GNN.config import funsd_home, docvqa_home, cord_home, sroie_home, buddie_home
    from preprocess import FUNSD, DocVQA, CORD, SROIE, BUDDIE

    parser = argparse.ArgumentParser(description='AliGATr finetuner.')
    parser.add_argument('-d','--dataset', type=str, help='Name of datasett to be used for finetuning', choices=['funsd', 'cord', 'sroie', 'buddie'], default='funsd')
    parser.add_argument('-m','--model_type', type=str, help='Type of GNN model used during pretraining', choices=['gengnn', 'mlmgnn'], default='gengnn')
    parser.add_argument('-g','--gpu_num', type=int, help='Number of available GPUs', default=4)
    parser.add_argument('-e','--epochs', type=int, help='Number of epochs', default=1000)
    parser.add_argument('-b','--batch_size', type=int, help='Batch size', default=32)
    parser.add_argument('-w','--num_workers', type=int, help='Number of workers for DataListLoader', default=8)
    parser.add_argument('-i','--input_path', type=str, help='Where to load the pretrained weights from', default='/path/to/model.bin')
    # Accepts `-c`, `-c True` and `-c False`; with the former type=bool any
    # non-empty value (including "False") enabled the flag.
    parser.add_argument('-c','--continue_ft', type=parse_bool_flag, nargs='?', const=True, help='Whether to continue finetuning from an existing checkpoint', default=False)
    args = vars(parser.parse_args())

    device_ids = list(range(args['gpu_num']))  # [0, 1, 2, 3]
    epochs = args['epochs']
    batch_size = args['batch_size']
    # Parsed for the CLI, but not passed on to finetune_and_test below.
    num_workers = args['num_workers']
    path = args['input_path']
    cont = args['continue_ft']

    # Map each dataset name to (class, home directory) and build only the one
    # requested, instead of constructing all four up front.
    dataset_factories = {'funsd': (FUNSD, funsd_home),
                         'cord': (CORD, cord_home),
                         'sroie': (SROIE, sroie_home),
                         'buddie': (BUDDIE, buddie_home)}
    dataset_cls, dataset_home = dataset_factories[args['dataset']]
    dataset = dataset_cls(dataset_home)

    model_type = args['model_type']

    finetune_and_test(dataset, model_type, path, device_ids, batch_size, epochs, cont=cont)