import numpy as np
import pickle
import random

from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

from Utils.utils import *
from Utils.TSI import Indexing_network

import argparse

def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code, so ``path`` must point at
    a trusted dataset file.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _top_k_indices(scores, k=20):
    """Return, for each row of the 2-D tensor ``scores``, the column indices of
    its ``k`` highest values (highest first).

    ``k`` is clamped to the number of columns so that a label space smaller
    than ``k`` does not make ``torch.topk`` raise.
    """
    return torch.topk(scores, min(k, scores.shape[1]), dim=1).indices


def _mean_recall_at_k(top_indices, gold_labels):
    """Return the mean per-row recall of predicted label ids against gold ids.

    Row ``i`` of ``top_indices`` is scored against ``gold_labels[i]`` as
    ``|predicted & gold| / |gold|``. Rows whose gold set is empty are skipped
    because recall is undefined for them; 0.0 is returned if every row is
    skipped.
    """
    recalls = []
    # a single .tolist() moves the whole index tensor off the device at once
    for predicted, gold in zip(top_indices.tolist(), gold_labels):
        gold = set(gold)
        if not gold:
            continue
        recalls.append(len(set(predicted) & gold) / len(gold))
    return float(np.mean(recalls)) if recalls else 0.0


def main(args):
    """Train the indexing network (CLF module) and save its weights.

    Pipeline:
      1. Load document/class/phrase embeddings and pickled label metadata from
         ``args.data_path``.
      2. Build a training set from every document that has at least one class
         label, plus one sample per class embedding (labelled with its own
         class id) and one per phrase embedding (labelled with its own phrase
         id).
      3. Optimise ``Indexing_network`` with Adam on the mean of the class and
         phrase cross-entropy losses.
      4. After each epoch, compute mean recall@20 of class and phrase
         predictions on the training documents. Stop once neither metric has
         improved for ``args.early_stop_epochs`` consecutive epochs, or after
         ``args.max_epochs`` epochs (optional attribute, default 500).
      5. Save the final ``state_dict`` to ``args.base_path + 'CLF_module'``.

    Args:
        args: namespace providing ``M``, ``batch_size``, ``learning_rate``,
            ``weight_decay``, ``early_stop_epochs``, ``data_path`` and
            ``base_path``; ``max_epochs`` is optional.

    Raises:
        ValueError: if no document in the cid list has a class label.
    """
    M = args.M
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    early_stop_epochs = args.early_stop_epochs
    max_epochs = getattr(args, 'max_epochs', 500)
    data_path = args.data_path
    indexing_network_path = args.base_path + 'CLF_module'
    device = 'cuda'

    # Data load
    corpus_emb = torch.load(data_path + '/backbone_corpus_emb.pt').cpu()
    class_emb = torch.load(data_path + '/class_emb.pt').cpu()
    phrase_emb = torch.load(data_path + '/phrase_emb.pt').cpu()

    num_class = class_emb.shape[0]
    num_phrase = phrase_emb.shape[0]

    mlb_class = MultiLabelBinarizer(classes=list(range(num_class)), sparse_output=True)
    mlb_phrase = MultiLabelBinarizer(classes=list(range(num_phrase)), sparse_output=True)

    cid_list = _load_pickle(data_path + '/CSFCube_cid_list')
    doc2class_dict = _load_pickle(data_path + '/doc2class_dict')
    doc2phrase_dict = _load_pickle(data_path + '/doc2phrase_dict')
    re_child2parent_dict = _load_pickle(data_path + '/re_child2parent_dict')

    # Add each label to its own parent list before building the adjacency
    # matrix (presumably so get_adj_mat includes self-loops -- confirm in Utils).
    for child, parents in re_child2parent_dict.items():
        parents.append(child)

    A = get_adj_mat(num_class, re_child2parent_dict)

    # Training data preparation: keep only documents with at least one class label.
    # corpus_emb rows are assumed to be aligned with cid_list order.
    doc_rows, train_class_Y, train_phrase_Y = [], [], []
    for index, cid in enumerate(cid_list):
        if len(doc2class_dict[cid]) == 0:
            continue
        doc_rows.append(corpus_emb[index])
        train_class_Y.append(doc2class_dict[cid])
        train_phrase_Y.append(doc2phrase_dict[cid])

    if not doc_rows:
        raise ValueError('No document in CSFCube_cid_list has a class label; nothing to train on.')
    doc_train_X = torch.stack(doc_rows, 0)

    # Class and phrase embeddings are appended as extra samples, each labelled
    # with its own id in its own label space and with nothing in the other.
    train_X = torch.cat([doc_train_X, class_emb, phrase_emb], 0)
    raw_class_Y = train_class_Y + [[i] for i in range(num_class)] + [[] for _ in range(num_phrase)]
    raw_phrase_Y = train_phrase_Y + [[] for _ in range(num_class)] + [[i] for i in range(num_phrase)]
    train_Y = mlb_class.fit_transform(raw_class_Y)
    train_Y2 = mlb_phrase.fit_transform(raw_phrase_Y)

    # train setup
    train_dataset = CLF_dataset(train_X, train_Y, train_Y2)
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    CLF_module = Indexing_network(class_emb, phrase_emb, A, M).to(device)
    # NOTE(review): assumes get_labels returns dense multi-hot rows of
    # train_Y / train_Y2 -- CrossEntropyLoss then treats them as
    # class-probability targets (unnormalised). Confirm in Utils.utils.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(CLF_module.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The evaluation input never changes, so move it to the device once.
    doc_eval_X = doc_train_X.to(device)

    class_acc_best, phrase_acc_best = -1, -1
    early_stop = 0

    for _ in range(max_epochs):
        # The evaluation step below switches the module to eval mode; switch it
        # back every epoch, otherwise all epochs after the first would train
        # with dropout/normalisation layers in inference behaviour.
        CLF_module.train()
        epoch_loss = 0.
        for batch_indices, batch_X in tqdm(train_loader):
            batch_X = batch_X.to(device)
            batch_Y, batch_Y2 = train_dataset.get_labels(batch_indices)
            batch_Y = torch.FloatTensor(batch_Y).to(device)
            batch_Y2 = torch.FloatTensor(batch_Y2).to(device)

            batch_output1, batch_output2, _, _ = CLF_module(batch_X)
            batch_loss1 = criterion(batch_output1, batch_Y)
            batch_loss2 = criterion(batch_output2, batch_Y2)
            batch_loss = (batch_loss1 + batch_loss2) / 2.

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()

        # Evaluate recall@20 of class and phrase predictions on training docs.
        CLF_module.eval()
        with torch.no_grad():
            class_scores, phrase_scores, _, _ = CLF_module(doc_eval_X)
        class_acc = _mean_recall_at_k(_top_k_indices(class_scores), train_class_Y)
        phrase_acc = _mean_recall_at_k(_top_k_indices(phrase_scores), train_phrase_Y)

        is_improved = False
        if class_acc > class_acc_best:
            class_acc_best = class_acc
            is_improved = True
        if phrase_acc > phrase_acc_best:
            phrase_acc_best = phrase_acc
            is_improved = True
        early_stop = 0 if is_improved else early_stop + 1

        # Log before the early-stop check so the final epoch is reported too.
        print('Loss: {:.3f}'.format(epoch_loss))

        # early stop when neither metric improved for early_stop_epochs epochs
        if early_stop >= early_stop_epochs:
            break

    torch.save(CLF_module.state_dict(), indexing_network_path)


if __name__ == '__main__':

    cli = argparse.ArgumentParser()
    # (flag, value type, default) for every command-line option, in the order
    # they are registered with the parser.
    option_specs = (
        ('--data_path', str, './Dataset/CSFCube'),
        ('--base_path', str, './Outputs/'),
        ('--learning_rate', float, 1e-4),
        ('--weight_decay', float, 1e-4),
        ('--batch_size', int, 128),
        ('--M', int, 3),
        ('--early_stop_epochs', int, 10),
        ('--random_seed', int, 1231),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    parsed_args = cli.parse_args()

    # Seed Python's, NumPy's and PyTorch's RNGs with the same value, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(parsed_args.random_seed)

    main(parsed_args)

