
import sys
sys.path.append( './' )

import torch
import argparse
from models.Transformers import IRLCBert
from learners.cluster import ClusterLearner
from dataloader.dataloader import augment_loader
from training import training
from utils.randomness import set_global_random_seed
import prettytable as pt
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def run(args):
    """Train the IRLCBert clustering model configured by ``args``.

    Seeds all RNGs, builds the augmented training loader, places the model on
    the GPU given by ``args.gpuid[0]``, and hands an Adam optimizer plus a
    StepLR schedule to ``ClusterLearner`` before calling ``training``.

    Args:
        args: Parsed command-line namespace (as produced by ``get_args``).

    Returns:
        None.
    """
    # Seed before anything stochastic (data loading, model init) happens.
    set_global_random_seed(args.seed)
    loader = augment_loader(args)

    # Tables handed to ``training``, presumably for reporting clustering
    # metrics and per-component losses.
    metric_columns = [
        "B3_Prec", "B3_Rec", "B3_F1",
        "V_Hom.", "V_Comp", "V_F1",
        "ARI", "NMI", "MEAN",
    ]
    metric_table = pt.PrettyTable(metric_columns)
    loss_columns = ["Instance Loss", "Cluster Loss", "Total Loss"]
    loss_table = pt.PrettyTable(loss_columns)

    # Only the first listed GPU id is used to select the CUDA device.
    torch.cuda.set_device(args.gpuid[0])
    model = IRLCBert(args=args).cuda()

    # The transformer backbone trains at the base rate; both heads train at
    # lr * lr_scale.
    head_lr = args.lr * args.lr_scale
    param_groups = [
        {'params': model.transformer.parameters()},
        {'params': model.head.parameters(), 'lr': head_lr},
        {'params': model.cluster_head.parameters(), 'lr': head_lr},
    ]
    optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=3e-6)
    # Multiply every group's LR by 0.85 once per 1000 scheduler.step() calls.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.85)

    learner = ClusterLearner(model, optimizer, scheduler)
    training(loader, learner, metric_table, loss_table, args)

def get_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Sequence of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``; callers that pass
            ``sys.argv[1:]`` explicitly behave exactly as before.

    Returns:
        argparse.Namespace holding every option defined below, plus a derived
        ``use_gpu`` flag that is True when the first ``--gpuid`` entry is >= 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpuid', nargs="+", type=int, default=[1],
                        help='GPU ids; the first one selects the device '
                             '(a negative value sets use_gpu=False)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Global random seed')
    # NOTE(review): a float type for a print frequency is unusual; kept as-is
    # because its consumer (training) is not visible here.
    parser.add_argument('--print_freq', type=float, default=250)
    parser.add_argument('--bert', type=str, default='bert-base-cased',
                        help='Transformer backbone identifier')
    # Dataset
    parser.add_argument('--data_path', type=str,
                        help='Location of the training data')
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--max_length', type=int, default=96)
    # Learning parameters
    parser.add_argument('--lr', type=float, default=1e-5,
                        help='Base learning rate')
    # float rather than int: the multiplier applied to --lr for the head
    # parameter groups may legitimately be fractional (e.g. 0.5), which an int
    # type rejected. The default is still the int 100, so the default product
    # lr * lr_scale is unchanged.
    parser.add_argument('--lr_scale', type=float, default=100,
                        help='Multiplier on --lr for the head parameter groups')
    parser.add_argument('--max_iter', type=int, default=50000)
    parser.add_argument('--batch_size', type=int, default=32)

    args = parser.parse_args(argv)
    # Derived flag: a negative first GPU id means "no GPU".
    args.use_gpu = args.gpuid[0] >= 0

    return args

if __name__ == '__main__':
    # Parse everything after the script name, then launch training.
    cli_args = get_args(sys.argv[1:])
    run(cli_args)



    
