#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: train.py
#Created Time:
############################
from __future__ import print_function
from sklearn import metrics
from config import CONFIG
import sys
#from __future__ import division
from sklearn import metrics
import random
import time
import sys
import os

import torch
import torch.nn as nn

import numpy as np

from utils.utils import *
#from models.gcn import GCN
from models.wgcn import WGCN
from models.wordgcn import WordGCN
#from models.mlp import MLP
from torch.optim import lr_scheduler
from data import MyData
from data import DataPrefetcher, AverageMeter
from hmc_config import CONFIG
from torch.utils.data import DataLoader
from losses import losses

from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from utils.eval import *
from torch.optim import lr_scheduler
import torch.nn.functional as F

# Single-label corpora that were supported earlier (kept for reference):
# '20ng', 'R8', 'R52', 'ohsumed', 'mr', 'WebKB', 'ag_news', 'dbpedia', 'yahoo'

# Dataset names accepted as the single command-line argument of main().
datasets = [
    'patent_1sub',
    'patent_f',
    'patent_1sub_abs',
    'patent_2gro',
    'aapd_h',
    'wos',
    'eurlex',
    'patent_fm_10',
    'patent_fm_20',
    'patent_fm_50',
]

# Placeholder list of single-label model names (currently just '').
single_models = ['']

# Model types that compute_loss_and_logits handles with the per-level
# (hierarchical) loss instead of the single global head.
hmc_models = ['hmc', 'supercls', 'mhmc']

# Global run configuration shared by every function in this module.
# NOTE: CONFIG is imported from both `config` and `hmc_config` at the top of
# the file; the later `hmc_config` import is the one bound when this runs.
cfg = CONFIG()

def evaluate(compute_loss_and_logits,valid_loader,model,criterion,label_label,classes,mode='valid',cls='base'):
    """Run ``model`` over ``valid_loader`` without gradients and average loss/accuracy.

    For every batch the classifier head ``model.module.clss`` is applied to
    the word embeddings read from ``model.module.wgcn.embeds`` (not
    recomputed here). ``compute_loss_and_logits`` turns the raw outputs into
    scores and a loss. Accuracy is the fraction of samples whose arg-max score
    equals the arg-max of the label row.

    Args:
        compute_loss_and_logits: callable
            ``(logits, label, criterion, label_label, classes, cls=...)``
            returning ``(scores, loss)``.
        valid_loader: loader consumed through ``DataPrefetcher``, whose
            ``next()`` yields ``(tfidf, label)`` and ``None`` when exhausted.
        model: object exposing ``module.wgcn.embeds`` and ``module.clss``
            (``main`` passes an ``nn.DataParallel`` wrapper).
        criterion, label_label, classes, cls: forwarded to
            ``compute_loss_and_logits``.
        mode: when ``"test"``, per-sample arg-max predictions and the raw
            label rows are also collected.

    Returns:
        ``(avg_loss, avg_acc, preds, labels, seconds)``. ``preds`` and
        ``labels`` are empty lists unless ``mode == "test"``.
    """
    t_test = time.time()
    acc_t = AverageMeter()
    loss_t = AverageMeter()
    preds_list = []
    label_list = []

    # switch to evaluate mode
    model.eval()
    prefetcher = DataPrefetcher(valid_loader)
    tfidf, label = prefetcher.next()
    with torch.no_grad():
        embeds = model.module.wgcn.embeds
        while tfidf is not None:
            logits = model.module.clss(embeds,tfidf)
            logits, loss = compute_loss_and_logits(logits,label,criterion,label_label,classes,cls=cls)

            pred = torch.max(logits, 1)[1]
            acc = ((pred == torch.max(label, 1)[1]).float()).sum().item() / label.shape[0]
            loss_t.update(loss.item(),logits.shape[0])
            acc_t.update(acc,tfidf.shape[0])

            if mode == "test":
                # Extend in place; rebuilding the list every batch
                # (``lst = lst + new``) was quadratic in the dataset size.
                preds_list.extend(pred.cpu().numpy().tolist())
                label_list.extend(label.cpu().numpy().tolist())
            tfidf, label = prefetcher.next()

    return loss_t.avg, acc_t.avg, preds_list, label_list, (time.time()-t_test)


def evaluate_hmc(compute_loss_and_logits,valid_loader,model,criterion,label_label,classes,mode='test',thres=0.5,cls='base'):
    """Print micro-averaged multi-label metrics for ``model`` on ``valid_loader``.

    Scores come batch by batch from ``compute_loss_and_logits`` under
    ``torch.no_grad()``. Two decodings are compared with the labels:

    * threshold: ``get_label_by_threshold(scores, threshold=thres)``. The
      precision/recall/F1 are printed and appended tab-separated to
      ``./outs/patent_drop10`` (the directory is created if missing);
    * top-k: ``get_label_by_topk(scores, top_num=k)`` for k = 1..5, printed only.

    Args:
        compute_loss_and_logits: callable returning ``(scores, loss)``; only
            the scores are used here.
        valid_loader: loader consumed through ``DataPrefetcher``.
        model: object exposing ``module.wgcn.embeds`` and ``module.clss``.
        criterion, label_label, classes, cls: forwarded to
            ``compute_loss_and_logits``.
        mode: scores are only collected when ``mode == "test"``.
        thres: decision threshold for the threshold decoding.

    Returns:
        None. If no scores were collected (``mode != "test"`` or an empty
        loader), a note is logged and no metrics are computed.
    """
    preds_list = []
    label_list = []

    # switch to evaluate mode
    model.eval()
    prefetcher = DataPrefetcher(valid_loader)
    tfidf, label = prefetcher.next()
    with torch.no_grad():
        embeds = model.module.wgcn.embeds
        while tfidf is not None:
            logits = model.module.clss(embeds,tfidf)
            scores, _ = compute_loss_and_logits(logits,label,criterion,label_label,classes=classes,cls=cls)

            if mode == "test":
                # Extend in place instead of the quadratic ``lst = lst + new``.
                preds_list.extend(scores.cpu().numpy().tolist())
                label_list.extend(label.cpu().numpy().tolist())
            tfidf, label = prefetcher.next()

    # The sklearn metrics below fail on empty inputs, so bail out explicitly.
    if not preds_list:
        print_log("evaluate_hmc: no predictions collected (mode={}), skipping metrics".format(mode))
        return

    # evaluate results through a fixed threshold
    label_list = np.array(label_list)
    pred_labels_thes = np.array(get_label_by_threshold(scores=preds_list,threshold=thres))
    precision_ts = precision_score(y_true=label_list,y_pred=pred_labels_thes,average='micro')
    recall_ts    = recall_score(   y_true=label_list,y_pred=pred_labels_thes,average='micro')
    fscore_ts    = f1_score(       y_true=label_list,y_pred=pred_labels_thes,average='micro')
    print("\n-----------------------------Test---------------------------------------")
    print_log("Evaluate through Threshold: Test Presion={:,.5f}, Recall={:,.5f}, F1_score={:,.5f}".format(precision_ts,recall_ts,fscore_ts))

    # Opening the results file in "a+" mode fails if ./outs does not exist yet.
    os.makedirs("./outs", exist_ok=True)
    with open("./outs/patent_drop10","a+") as f:
        f.write("\t".join(map(str, (precision_ts, recall_ts, fscore_ts))) + '\n')

    # evaluate results through top-K (K = 1..5)
    for top_num in range(1, 6):
        lab_tk = get_label_by_topk(scores=preds_list,top_num=top_num)
        pre = precision_score(y_true=label_list,y_pred=lab_tk,average='micro')
        rec = recall_score(y_true=label_list,y_pred=lab_tk,average='micro')
        fsc = f1_score(y_true=label_list,y_pred=lab_tk,average='micro')
        print_log("Top " + str(top_num) + ": Test Presion={:,.5f}, Recall={:,.5f}, F1_score={:,.5f}".format(pre,rec,fsc))
    print("------------------------------------------------------------------------\n")


def compute_loss_and_logits(logits,label,criterion,label_label,classes,cls="supercls",step=10):
    """Compute the training loss and the fused sigmoid scores for one batch.

    Args:
        logits: dict of raw model outputs. Every mode reads ``'g'`` (the
            global head). Hierarchical modes (``cls in hmc_models``) also read
            ``'0'``..``'3'`` (per-level heads), ``'c_0'``..``'c_3'`` (matrices
            that the per-level targets are multiplied by) and ``'l_0'``..``'l_3'``
            (heads trained against identity targets).
        label: multi-hot label tensor. Its columns are split into levels at
            the cumulative boundaries in ``classes``.
        criterion: loss callable ``criterion(input, target)`` (``main`` passes
            ``nn.BCEWithLogitsLoss``).
        label_label: unused here; kept for call compatibility.
        classes: cumulative per-level class counts. Level k spans columns
            ``classes[k-1]:classes[k]`` (level 0 starts at column 0).
            Hierarchical modes read ``classes[0]``..``classes[3]``.
        cls: ``'base'``/``'labelattention'`` use only the global head; names
            in ``hmc_models`` use the four level heads as well.
        step: unused here; kept for call compatibility.

    Returns:
        ``(scores, loss)``. ``scores`` are sigmoid outputs: for the base
        modes the global head alone, for hierarchical modes the
        ``cfg.alpha``-weighted concatenation of the level heads plus the
        weighted global head.

    Raises:
        ValueError: if ``cls`` is not a supported model type.
    """
    if cls == 'base' or cls == 'labelattention':
        # wgcn model: only the global head contributes.
        lossg = criterion(logits['g'], label.float())
        return torch.sigmoid(logits['g']), lossg
    if cls not in hmc_models:
        # Previously this fell through and returned None, which only surfaced
        # later as an opaque unpacking error at the call site.
        raise ValueError("Unsupported model type for loss computation: {}".format(cls))

    target = label.float()
    # Per-level losses: each level's slice of the label is mapped through the
    # model-provided matrix logits['c_k'] before being used as the target.
    loss0 = criterion(logits['0'], target[:, :classes[0]].matmul(logits['c_0']))
    loss1 = criterion(logits['1'], target[:, classes[0]:classes[1]].matmul(logits['c_1']))
    loss2 = criterion(logits['2'], target[:, classes[1]:classes[2]].matmul(logits['c_2']))
    loss3 = criterion(logits['3'], target[:, classes[2]:classes[3]].matmul(logits['c_3']))
    lossg = criterion(logits['g'], target)

    # Identity targets for the l_k heads. They are built only on this path
    # (the base path never needed them) and on the outputs' device instead
    # of a hard-coded .cuda().
    device = logits['l_0'].device
    lab_loss0 = criterion(logits['l_0'], torch.eye(classes[0], dtype=torch.float32, device=device))
    lab_loss1 = criterion(logits['l_1'], torch.eye(classes[1] - classes[0], dtype=torch.float32, device=device))
    lab_loss2 = criterion(logits['l_2'], torch.eye(classes[2] - classes[1], dtype=torch.float32, device=device))
    lab_loss3 = criterion(logits['l_3'], torch.eye(classes[3] - classes[2], dtype=torch.float32, device=device))

    loss = (cfg.alpha['0'] * loss0
            + cfg.alpha['1'] * loss1
            + cfg.alpha['2'] * loss2
            + cfg.alpha['3'] * loss3
            + cfg.alpha['g'] * lossg
            + 0.1 * (lab_loss0 + lab_loss1 + lab_loss2 + lab_loss3))

    # Fused scores: weighted level heads side by side, plus the weighted
    # global head (which covers all columns).
    scores = torch.cat((
        cfg.alpha['0'] * torch.sigmoid(logits['0']),
        cfg.alpha['1'] * torch.sigmoid(logits['1']),
        cfg.alpha['2'] * torch.sigmoid(logits['2']),
        cfg.alpha['3'] * torch.sigmoid(logits['3'])), dim=1) + \
        cfg.alpha['g'] * torch.sigmoid(logits['g'])
    return scores, loss


def _compute_loss_and_logits(logits,label,criterion,label_label,classes,cls="supercls",step=10):
    """Three-level experimental variant of the hierarchical loss.

    NOTE(review): a later definition with the same name in this module
    replaces this one at import time, and ``main`` calls
    ``compute_loss_and_logits`` instead, so this body is effectively unused.

    Hierarchical modes read ``logits['0'..'2']``, ``logits['g']`` and
    ``logits['lab_loss0'..'lab_loss2']``. When ``step > 1000``, the level-0/1
    targets are also mapped through ``logits['corm0']`` / ``logits['corm1']``.
    ``label_label`` is unused.

    Returns:
        ``(scores, loss)`` where ``scores`` are sigmoid outputs.

    Raises:
        ValueError: if ``cls`` is not a supported model type.
    """
    if cls == 'base' or cls == 'labelattention':
        # wgcn model: only the global head contributes.
        lossg = criterion(logits['g'], label.float())
        return torch.sigmoid(logits['g']), lossg
    if cls not in hmc_models:
        raise ValueError("Unsupported model type for loss computation: {}".format(cls))

    # wgcn + hmc
    level0 = label[:, :classes[0]].float()
    level1 = label[:, classes[0]:classes[1]].float()
    if step > 1000:
        loss0 = criterion(logits['0'], level0.matmul(logits['corm0']))
        loss1 = criterion(logits['1'], level1.matmul(logits['corm1']))
    else:
        loss0 = criterion(logits['0'], level0)
        loss1 = criterion(logits['1'], level1)
    # Level 2 was previously only assigned in the early phase, so any
    # step > 1000 raised NameError when building `loss`.
    # NOTE(review): whether level 2 should also use a 'corm2' mapping after
    # step 1000 cannot be determined from here.
    loss2 = criterion(logits['2'], label[:, classes[1]:classes[2]].float())
    lossg = criterion(logits['g'], label.float())

    # Identity targets for the label-loss heads, built lazily on the outputs'
    # device. The former unused classes[-1] x classes[-1] identity is gone.
    device = logits['lab_loss0'].device
    lab_loss0 = criterion(logits['lab_loss0'], torch.eye(classes[0], dtype=torch.float32, device=device))
    lab_loss1 = criterion(logits['lab_loss1'], torch.eye(classes[1] - classes[0], dtype=torch.float32, device=device))
    lab_loss2 = criterion(logits['lab_loss2'], torch.eye(classes[2] - classes[1], dtype=torch.float32, device=device))

    # loss2 is weighted by alpha['2'] (it previously reused alpha['1']),
    # matching compute_loss_and_logits.
    loss = (cfg.alpha['0'] * loss0
            + cfg.alpha['1'] * loss1
            + cfg.alpha['2'] * loss2
            + cfg.alpha['g'] * lossg
            + lab_loss0
            + lab_loss1
            + lab_loss2)

    scores = torch.cat((torch.sigmoid(logits['0']),
                        torch.sigmoid(logits['1']),
                        torch.sigmoid(logits['2'])), dim=1) + torch.sigmoid(logits['g'])
    return scores, loss


def _compute_loss_and_logits(logits,label,criterion,label_label,classes,cls="supercls",step=10):
    """Two-level experimental variant of the hierarchical loss.

    NOTE(review): nothing visible in this module calls this function; ``main``
    uses ``compute_loss_and_logits``.

    Hierarchical modes read ``logits['0']``, ``logits['1']``, ``logits['g']``,
    ``logits['corm0']`` / ``logits['corm1']`` (matrices the level targets are
    multiplied by) and ``logits['lab_loss0']`` / ``logits['lab_loss1']``.
    ``label_label`` is unused.

    Side effect: while ``step <= 100000``, every call writes
    ``logits['corm0']`` to ``cor.txt`` in the working directory.

    Returns:
        ``(scores, loss)`` where ``scores`` are sigmoid outputs.

    Raises:
        ValueError: if ``cls`` is not a supported model type.
    """
    if cls == 'base' or cls == 'labelattention':
        lossg = criterion(logits['g'], label.float())
        return torch.sigmoid(logits['g']), lossg
    if cls not in hmc_models:
        raise ValueError("Unsupported model type for loss computation: {}".format(cls))

    # Both former step branches computed exactly these two losses.
    loss0 = criterion(logits['0'], label[:, :classes[0]].float().matmul(logits['corm0']))
    loss1 = criterion(logits['1'], label[:, classes[0]:classes[1]].float().matmul(logits['corm1']))
    if step <= 100000:
        # NOTE(review): debug dump on every call (forces a device->host copy
        # and a file write per batch); consider removing.
        np.savetxt("cor.txt", logits['corm0'].detach().cpu().numpy())

    lossg = criterion(logits['g'], label.float())

    # Identity targets for the label-loss heads, built lazily on the outputs'
    # device. The former unused classes[-1] x classes[-1] identity is gone.
    device = logits['lab_loss0'].device
    lab_loss0 = criterion(logits['lab_loss0'], torch.eye(classes[0], dtype=torch.float32, device=device))
    lab_loss1 = criterion(logits['lab_loss1'], torch.eye(classes[1] - classes[0], dtype=torch.float32, device=device))

    loss = (cfg.alpha['0'] * loss0
            + cfg.alpha['1'] * loss1
            + cfg.alpha['g'] * lossg
            + lab_loss0
            + lab_loss1)

    scores = torch.cat((torch.sigmoid(logits['0']),
                        torch.sigmoid(logits['1'])), dim=1) + torch.sigmoid(logits['g'])
    return scores, loss





def adjust_learning_rate(optimizer,lr,factor=0.8):
    """Multiply ``lr`` by ``factor`` and apply it to every param group of ``optimizer``.

    Args:
        optimizer: object with a ``param_groups`` iterable of dicts (e.g. a
            ``torch.optim`` optimizer); each group's ``'lr'`` is overwritten.
        lr: current learning rate.
        factor: multiplicative decay; defaults to the previous hard-coded 0.8.

    Returns:
        The new learning rate ``lr * factor``.
    """
    print("adjust learning rate {}".format(lr))
    lr *= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print("after adjusting learning rate {}".format(lr))
    return lr


def _build_layer_masks(node_size, vocab_size, classes):
    """Build the two layer-wise node masks that WordGCN's ``masks`` argument can take.

    ``mask_1`` zeroes rows/columns ``vocab_size+classes[0] : vocab_size+classes[1]``.
    ``mask_2`` zeroes rows/columns ``vocab_size : vocab_size+classes[0]``.
    NOTE(review): these index ranges presumably correspond to the level-1 and
    level-0 label nodes that follow the vocabulary nodes — confirm against
    ``load_corpus``. Each mask is a dense ``(node_size, node_size)`` float
    tensor, so this is expensive for large graphs.
    """
    mask_1 = torch.ones((node_size, node_size))
    mask_1[vocab_size + classes[0]:vocab_size + classes[1], :] = 0
    mask_1[:, vocab_size + classes[0]:vocab_size + classes[1]] = 0

    mask_2 = torch.ones((node_size, node_size))
    mask_2[vocab_size:vocab_size + classes[0], :] = 0
    mask_2[:, vocab_size:vocab_size + classes[0]] = 0
    return [mask_1, mask_2]


def main():
    """Train and evaluate a WordGCN model on the dataset named on the command line.

    Usage: ``python train.py <dataset>`` where ``<dataset>`` is one of
    ``datasets``. After each epoch the validation split is scored with
    ``evaluate``. Whenever ``epoch % 10 == 0``, and once more after training,
    the test split is scored with ``evaluate_hmc`` using ``cfg.model``.
    """
    # Restrict visible GPUs before anything can initialise CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)

    ## settings
    seed = 666
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if len(sys.argv) != 2:
        sys.exit("User: python train.py <dataset>")

    dataset = sys.argv[1]
    if dataset not in datasets:
        sys.exit("Wrong Dataset Name!")
    cfg.dataset = dataset

    # load data
    adj, train_tfidf, valid_tfidf, test_tfidf, y_train, y_valid, y_test, train_size, valid_size, test_size, num_classes, pre_vec = load_corpus(cfg.dataset)

    # cumulative per-level class counts
    classes = cfg.classes

    tplt = "{0:>30}\t{1:<10}"
    print("\n\n")
    print("###################### Model & Dataset Description #################")
    print(tplt.format("Dataset Name :", cfg.dataset))
    print(tplt.format("Train Size :", train_size))
    print(tplt.format("Validation Size :", valid_size))
    print(tplt.format("Test  Size :", test_size))
    print(tplt.format("Class Number :", num_classes))
    print(tplt.format("Level Class Number :", ",".join(map(str,classes))))
    print(tplt.format("Node Size :", adj.shape[0]))
    print(tplt.format("Model Name :", cfg.model))
    print(tplt.format("Learning Rate :", str(cfg.learning_rate)))
    print(tplt.format("Epoch Number :", str(cfg.epochs)))
    print(tplt.format("Embedding Dim :", str(cfg.hidden1)))
    print("--------------------------------------------------------------------")
    print("\n")

    support = [torch.Tensor(preprocess_adj(adj)).cuda()]
    model_func = WordGCN

    # Layer-wise masks are disabled. Building them here only to discard them
    # allocated two dense node_size x node_size tensors for nothing; pass
    # _build_layer_masks(adj.shape[0], adj.shape[0] - classes[-1], classes)
    # instead of None to enable them.
    masks = None

    # define model
    feature_size = adj.shape[0]
    model = model_func(input_dim=feature_size,hidden_dim=cfg.hidden1,support=support, classes=classes,cls=cfg.model, pre=pre_vec,masks=masks)
    model = nn.DataParallel(model)
    model = model.cuda()

    criterion = nn.BCEWithLogitsLoss()
    # Constant learning rate: no scheduler is stepped during training.
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)

    # create data sets and loaders
    train_dataset = MyData(train_tfidf,y_train)
    valid_dataset = MyData(valid_tfidf,y_valid)
    test_dataset = MyData(test_tfidf,y_test)

    train_loader = DataLoader(train_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)
    valid_loader = DataLoader(valid_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)
    test_loader = DataLoader(test_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)

    # forwarded to compute_loss_and_logits (which currently ignores it)
    label_label = torch.eye(num_classes).cuda()
    for epoch in range(cfg.epochs):
        model.train()
        t = time.time()
        prefetcher = DataPrefetcher(train_loader)

        # Last-batch statistics for the epoch log; NaN if the loader is empty
        # (previously an empty loader raised NameError at the log line).
        train_loss = float('nan')
        train_acc = float('nan')

        tfidf, label = prefetcher.next()
        while tfidf is not None:
            # compute the embeddings of the words
            if epoch < 10000:
                embeds = model.module.wgcn(None)
            else:
                embeds = [v.detach() for v in model.module.wgcn.embeds]
            logits = model.module.clss(embeds,tfidf,label,'test')

            logits, loss = compute_loss_and_logits(logits,label,criterion,label_label,classes=classes,cls=cfg.model,step=epoch)

            train_acc = ((torch.max(logits, 1)[1] == torch.max(label, 1)[1]).float()).sum().item() / tfidf.shape[0]

            tfidf, label = prefetcher.next()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss = loss.item()

        val_loss, val_acc, _, _, _ = evaluate(compute_loss_and_logits,valid_loader,model,criterion,label_label,classes,cls=cfg.model)
        print_log("Epoch: {:.0f}, train_loss= {:.5f}, train_acc= {:.5f}, val_loss= {:.5f}, val_acc= {:.5f}, time= {:.5f}".format(epoch + 1, train_loss, train_acc, val_loss, val_acc, time.time() - t))

        if epoch % 10 == 0:
            evaluate_hmc(compute_loss_and_logits,test_loader,model,criterion,label_label,classes,'test',cfg.threshold,cfg.model)

    # Final test pass. It must use the same cls as the periodic evaluations;
    # omitting it fell back to 'base' and scored only the global head.
    evaluate_hmc(compute_loss_and_logits,test_loader,model,criterion,label_label,classes,'test',cfg.threshold,cfg.model)



if __name__ == "__main__":
    # Script entry point: train only when executed directly, not on import.
    main()
