#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: train.py
#Created Time:
############################
from __future__ import print_function
from sklearn import metrics
from config import CONFIG
import sys
#from __future__ import division
from sklearn import metrics
import random
import time
import sys
import os

import torch
import torch.nn as nn

import numpy as np

from utils.utils import *
#from models.gcn import GCN
from models.wgcn import WGCN
from models.wordgcn import WordGCN
#from models.mlp import MLP
from torch.optim import lr_scheduler
from data import MyData
from data import DataPrefetcher, AverageMeter
from config import CONFIG
from torch.utils.data import DataLoader


# Dataset names accepted as the single command-line argument; main() exits
# with an error message when the requested name is not in this list.
datasets = ['20ng', 'R8', 'R52', 'ohsumed', 'mr', 'WebKB','ag_news','dbpedia','yahoo','patent_1sub','patent_f']
# Module-wide configuration object. Functions in this file read settings such
# as cfg.alpha, cfg.learning_rate and cfg.epochs from it, and main() stores
# the chosen dataset name in cfg.dataset.
cfg = CONFIG()

def compute_logits_and_loss(logits,label,criterion,label_label):
    """Combine the per-head outputs into one loss and one score matrix.

    Args:
        logits: mapping with the entries 'g', '1' and '2' (the three
            classification heads) and 'lab' (the label head).
        label: tensor whose argmax along dim 1 is the target class index of
            each sample.
        criterion: callable taking ``(head_logits, target_indices)`` and
            returning a loss value.
        label_label: tensor whose argmax along dim 1 is the target index for
            the 'lab' head.

    Returns:
        A ``(scores, loss)`` tuple. ``scores`` is the ``cfg.alpha``-weighted
        sum of the dim-1 softmax of the three classification heads; ``loss``
        is the ``cfg.alpha``-weighted sum of their losses plus the label-head
        loss with weight 1.0.
    """
    head_keys = ('g', '1', '2')

    # Every classification head is scored against the same target indices.
    doc_target = torch.max(label, 1)[1]
    head_losses = [criterion(logits[key], doc_target) for key in head_keys]
    label_loss = criterion(logits['lab'], torch.max(label_label, 1)[1])

    # Accumulate left to right: alpha-weighted head losses first, then the
    # label-head loss.
    total_loss = cfg.alpha[head_keys[0]] * head_losses[0]
    for key, head_loss in zip(head_keys[1:], head_losses[1:]):
        total_loss = total_loss + cfg.alpha[key] * head_loss
    total_loss = total_loss + 1.0 * label_loss

    # The label head contributes to the loss only, not to the returned scores.
    mixed_scores = cfg.alpha[head_keys[0]] * torch.softmax(logits[head_keys[0]], dim=1)
    for key in head_keys[1:]:
        mixed_scores = mixed_scores + cfg.alpha[key] * torch.softmax(logits[key], dim=1)

    return mixed_scores, total_loss


def evaluate(valid_loader,model,criterion,label_label,mode='valid'):
    """Run the model over every batch of ``valid_loader`` without gradients.

    Args:
        valid_loader: data loader wrapped in a ``DataPrefetcher``; each
            ``next()`` call yields a ``(tfidf, label)`` pair, and the loop
            stops when ``tfidf`` is ``None``.
        model: model exposing ``wgcn.embeds`` and ``clss(embeds, tfidf)``.
        criterion: loss callable forwarded to ``compute_logits_and_loss``.
        label_label: label-head target forwarded to
            ``compute_logits_and_loss``.
        mode: when equal to ``"test"``, per-sample predictions and labels are
            collected and returned; otherwise both lists stay empty.

    Returns:
        ``(avg_loss, avg_acc, preds_list, label_list, elapsed_seconds)``.
        ``preds_list`` holds predicted class indices and ``label_list`` holds
        the label rows exactly as delivered by the loader.
    """
    t_test = time.time()
    acc_t = AverageMeter()
    loss_t = AverageMeter()
    preds_list = []
    label_list = []
    collect_outputs = (mode == "test")

    # switch to evaluate mode
    # NOTE: the model is left in eval mode when this function returns; a
    # caller that continues training has to call model.train() itself.
    model.eval()
    prefetcher = DataPrefetcher(valid_loader)
    tfidf, label = prefetcher.next()
    with torch.no_grad():
        # NOTE(review): the embeddings are read from an attribute instead of
        # being recomputed here; presumably this is the value stored by the
        # most recent model.wgcn(...) call -- confirm against the model code.
        embeds = model.wgcn.embeds
        while tfidf is not None:
            logits = model.clss(embeds,tfidf)
            logits, loss = compute_logits_and_loss(logits,label,criterion,label_label)
            pred = torch.max(logits, 1)[1]
            target = torch.max(label, 1)[1]
            acc = (pred == target).float().sum().item() / label.shape[0]
            # The second argument is the batch size (presumably the weight
            # AverageMeter uses for its running average).
            loss_t.update(loss.item(),logits.shape[0])
            acc_t.update(acc,tfidf.shape[0])

            ## accumulate preds and labels
            if collect_outputs:
                # extend() appends in place; rebuilding the list with `+` on
                # every batch copies everything collected so far.
                preds_list.extend(pred.cpu().numpy().tolist())
                label_list.extend(label.cpu().numpy().tolist())
            tfidf, label = prefetcher.next()

    return loss_t.avg, acc_t.avg, preds_list, label_list, (time.time()-t_test)



def main():
    """Train the model on the dataset named on the command line and report metrics.

    Steps, in order:
      1. Seed NumPy and PyTorch, validate ``sys.argv`` and store the dataset
         name in ``cfg.dataset``.
      2. Load the corpus, print a run description and build the model
         (only ``cfg.model == 'single'`` is supported).
      3. Train for ``cfg.epochs`` epochs with RMSprop, evaluating on the
         validation set every epoch and on the test set every 10th epoch.
      4. Evaluate on the test set and log precision/recall/F1 reports.

    Raises:
        SystemExit: on a wrong argument count or an unknown dataset name.
        ValueError: when ``cfg.model`` is not ``'single'``.
    """

    ## settings
    seed = 666
    np.random.seed(seed)
    torch.manual_seed(seed)
    if len(sys.argv) != 2:
        sys.exit("Usage: python train.py <dataset>")

    dataset = sys.argv[1]
    if dataset not in datasets:
        sys.exit("Wrong Dataset Name!")
    cfg.dataset = dataset

    os.environ["CUDA_VISIBLE_DEVICES"]= str(cfg.gpu)

    # load data
    adj, train_tfidf, valid_tfidf, test_tfidf, y_train, y_valid, y_test, train_size, valid_size, test_size, num_classes, pre_vec = load_corpus(cfg.dataset)

    tplt = "{0:>20}\t{1:<10}"
    print("\n\n")
    print("###################### Model & Dataset Description #################")
    print(tplt.format("Dataset Name :", cfg.dataset))
    print(tplt.format("Train Size :", train_size))
    print(tplt.format("Validation Size :", valid_size))
    print(tplt.format("Test  Size :", test_size))
    print(tplt.format("Class Number :", num_classes))
    print(tplt.format("Node Size :", adj.shape[0]))
    print(tplt.format("Model Name :", cfg.model))
    print(tplt.format("Alpha :", ",".join([k+":"+str(v) for k,v in cfg.alpha.items()])))
    print(tplt.format("Learning Rate:", str(cfg.learning_rate)))
    print(tplt.format("Epoch Number:", str(cfg.epochs)))
    print(tplt.format("Embedding Dim:", str(cfg.hidden1)))
    print("--------------------------------------------------------------------")
    print("\n")

    if cfg.model == 'single':
        support = [torch.Tensor(preprocess_adj(adj)).cuda()]
        model_func = WordGCN
    else:
        raise ValueError("Invalid argument for model: "+ cfg.model)

    # define model
    classes = [num_classes]
    feature_size = adj.shape[0]
    # NOTE(review): this overrides the value returned by load_corpus, so the
    # model is always built with pre=None -- confirm this is intended.
    pre_vec=None
    model = model_func(input_dim=feature_size,hidden_dim=cfg.hidden1,support=support, classes=classes,pre=pre_vec).cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=cfg.learning_rate,alpha=0.9,weight_decay=cfg.weight_decay)
    # NOTE: the scheduler is constructed but never stepped below, so the
    # learning rate stays at cfg.learning_rate for the whole run.
    scheduler = lr_scheduler.StepLR(optimizer,gamma=0.9,step_size=cfg.step_size)

    # create data set
    train_dataset = MyData(train_tfidf,y_train)
    valid_dataset = MyData(valid_tfidf,y_valid)
    test_dataset = MyData(test_tfidf,y_test)

    # define data loader
    train_loader = DataLoader(train_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)
    valid_loader = DataLoader(valid_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)
    test_loader = DataLoader(test_dataset,batch_size=cfg.batch_size,num_workers=0,shuffle=True)

    # define label loss
    label_label = torch.eye(num_classes).cuda()
    for epoch in range(cfg.epochs):
        t = time.time()
        # evaluate() puts the model into eval mode and does not switch back,
        # so training mode has to be re-enabled at the start of every epoch.
        model.train()
        prefetcher = DataPrefetcher(train_loader)

        # run training
        tfidf, label = prefetcher.next()
        # Defaults keep the epoch log line printable when the loader yields
        # no batches.
        loss = None
        acc = float('nan')
        while tfidf is not None:
            # Run the GCN on the word graph to get the word embeddings, then
            # classify the current batch with them.
            embeds = model.wgcn(None)
            logits = model.clss(embeds,tfidf)
            logits, loss = compute_logits_and_loss(logits,label,criterion,label_label)
            acc = ((torch.max(logits, 1)[1] == torch.max(label, 1)[1]).float()).sum().item() / tfidf.shape[0]
            # Fetch the next batch before the backward pass, as the original
            # ordering did.
            tfidf, label = prefetcher.next()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss and accuracy of the last training batch of this epoch.
        train_loss = loss.item() if loss is not None else float('nan')

        if epoch % 10 == 0:
            test_loss, test_acc, test_pred, labels, duration = evaluate(test_loader,model,criterion,label_label,'test')
            print(test_acc)


        val_loss, val_acc, pred, labels, duration = evaluate(valid_loader,model,criterion,label_label)
        print_log("Epoch: {:.0f}, train_loss= {:.5f}, train_acc= {:.5f}, val_loss= {:.5f}, val_acc= {:.5f}, time= {:.5f}".format(epoch + 1, train_loss, acc, val_loss, val_acc, time.time() - t))

    # Begin to test
    test_loss, test_acc, test_pred, labels, duration = evaluate(test_loader,model,criterion,label_label,'test')
    # evaluate() returns the label rows as delivered by the loader; reduce
    # each row to its class index so it is comparable with test_pred.
    test_labels = [np.argmax(np.array(row)) for row in labels]

    print_log("Test Precision, Recall and F1-Score...")
    print_log(metrics.classification_report(test_labels, test_pred, digits=4))
    print_log("Macro average Test Precision, Recall and F1-Score...")
    print_log(metrics.precision_recall_fscore_support(test_labels, test_pred, average='macro'))
    print_log("Micro average Test Precision, Recall and F1-Score...")
    print_log(metrics.precision_recall_fscore_support(test_labels, test_pred, average='micro'))



if __name__=="__main__":
    main()
