import numpy as np
import pickle
import random

from tqdm import tqdm

import torch
import torch.utils.data as data
import torch.optim as optim

from Utils.utils import *
from Utils.TSI import TSI_module

from beir.retrieval.evaluation import EvaluateRetrieval

import argparse

def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE: unpickling can execute arbitrary code; only use with trusted data files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _positive_cid(qid):
    """Return the part of ``qid`` before the first underscore, used as its relevant document id."""
    return qid.split('_')[0]


def _build_triplets(query_ids, negatives):
    """Assemble ``{qid: {'query', 'pos', 'hard_neg'}}`` training records.

    Args:
        query_ids: iterable of query ids to build records for.
        negatives: mapping ``qid -> {pid: score}`` of candidate negatives.

    Returns:
        A dict keyed by query id. Queries that have no entry in ``negatives``,
        or whose candidates are all equal to the positive document, are skipped.
    """
    triplets = {}
    for qid in query_ids:
        if qid not in negatives:
            continue
        pos_cid = _positive_cid(qid)
        # Keep the mapping's own key order rather than going through a set:
        # set order of str ids varies with hash randomisation between runs.
        neg_pids = [pid for pid in negatives[qid] if pid != pos_cid]
        if not neg_pids:
            continue
        triplets[qid] = {'query': qid, 'pos': [pos_cid], 'hard_neg': neg_pids}
    return triplets


def _batch_loss(model, dataset, mini_batch):
    """Compute ``model.get_fine_tune_loss`` for one mini-batch produced from ``dataset``."""
    batch_qid, batch_pos_id, batch_neg_id, batch_class_label_vec, batch_phrase_label_vec = mini_batch
    q_emb = dataset.qid2emb(batch_qid).to('cuda')
    d_emb = dataset.cid2emb(batch_pos_id).to('cuda')
    neg_d_emb = dataset.cid2emb(batch_neg_id).to('cuda')
    batch_Y = torch.FloatTensor(batch_class_label_vec).to('cuda')
    batch_Y2 = torch.FloatTensor(batch_phrase_label_vec).to('cuda')
    return model.get_fine_tune_loss(q_emb, d_emb, neg_d_emb, batch_Y, batch_Y2)


def main(args):
    """Train the TSI module with mined hard negatives, then evaluate retrieval.

    Steps: load embeddings and label dictionaries, hold out the last tenth of
    the training queries for validation, mine negatives, train with early
    stopping on the validation loss, reload the best checkpoint and print
    retrieval metrics for the test queries.

    Args:
        args: namespace providing ``M``, ``batch_size``, ``learning_rate``,
            ``weight_decay``, ``early_stop_epochs``, ``data_path``,
            ``base_path``, ``lambda_IL``, ``D_d_size`` and ``neg_numbers``.
            An optional ``max_epochs`` attribute caps the number of epochs
            (50 when absent).
    """
    M = args.M
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    early_stop_epochs = args.early_stop_epochs
    data_path = args.data_path
    base_path = args.base_path
    lambda_IL = args.lambda_IL
    D_d_size = args.D_d_size
    neg_numbers = args.neg_numbers
    max_epochs = getattr(args, 'max_epochs', 50)
    indexing_network_path = base_path + 'CLF_module'
    TSI_module_path = base_path + 'TSI_module'

    # Data load
    corpus_emb = torch.load(data_path + '/backbone_corpus_emb.pt').cpu()
    class_emb = torch.load(data_path + '/class_emb.pt').cpu()
    phrase_emb = torch.load(data_path + '/phrase_emb.pt').cpu()

    num_class = class_emb.shape[0]
    num_phrase = phrase_emb.shape[0]

    qid_list = _load_pickle(data_path + '/CSFCube_qid_list')
    cid_list = _load_pickle(data_path + '/CSFCube_cid_list')
    doc2class_dict = _load_pickle(data_path + '/doc2class_dict')
    doc2phrase_dict = _load_pickle(data_path + '/doc2phrase_dict')
    re_child2parent_dict = _load_pickle(data_path + '/re_child2parent_dict')

    # Every class is added to its own parent list before the adjacency matrix is built.
    for child in re_child2parent_dict:
        re_child2parent_dict[child].append(child)

    A = get_adj_mat(num_class, re_child2parent_dict)

    # Training data preparation
    total_queries_list = _load_pickle(data_path + '/CSFCube_training_qid_list')
    total_queries_emb = torch.load(data_path + '/backbone_training_queries_emb.pt').cpu()

    # Hold out the last tenth for validation. Explicit split points are used
    # because `[:-0]` would yield an empty training set when fewer than ten
    # queries are available.
    num_valid = len(total_queries_list) // 10
    list_split = len(total_queries_list) - num_valid
    emb_split = total_queries_emb.shape[0] - num_valid
    training_queries_list, valid_queries_list = total_queries_list[:list_split], total_queries_list[list_split:]
    training_queries_emb, valid_queries_emb = total_queries_emb[:emb_split], total_queries_emb[emb_split:]

    total_queries = _load_pickle(data_path + '/Promptgator_training_queries')

    # NOTE(review): this path ignores `base_path`; it only matches when the
    # default './Outputs/' is used - confirm whether that is intended.
    BM25_top500_results = _load_pickle('./Outputs/CSFCube_training_BM25_results')

    training_queries = {qid: total_queries[qid] for qid in training_queries_list}
    valid_queries = {qid: total_queries[qid] for qid in valid_queries_list}

    # Core topic aware negative mining
    # One row per entry of cid_list (rows are addressed by position in that list).
    doc_class_mat = torch.zeros((len(cid_list), num_class))

    for idx, cid in enumerate(cid_list):
        for topic_id in doc2class_dict[cid]:
            doc_class_mat[idx][topic_id] = 1

    # Pairwise count of shared classes; keep the D_d_size highest-scoring documents per row.
    doc_score_mat = torch.matmul(doc_class_mat, doc_class_mat.T)
    doc_hardneg_mat = torch.argsort(-doc_score_mat, axis=-1)[:,:D_d_size]

    doc2hardnegdoc_dic = {
        cid: [cid_list[x] for x in neighbour_rows]
        for cid, neighbour_rows in zip(cid_list, doc_hardneg_mat.tolist())
    }

    topic_results = {qid: doc2hardnegdoc_dic[_positive_cid(qid)] for qid in total_queries_list}

    # Score each topic-based candidate with its BM25 score, or -1 when BM25 did not return it.
    diff_dict = {}
    for qid, bm25_scores in BM25_top500_results.items():
        diff_dict[qid] = {pid: bm25_scores.get(pid, -1) for pid in topic_results[qid]}

    diff_dict = return_topK_result(diff_dict, neg_numbers)

    ## train set preparation
    train_triplets = _build_triplets(training_queries, diff_dict)
    print("Train triplets: {}".format(len(train_triplets)))

    ## valid set preparation
    valid_triplets = _build_triplets(valid_queries, diff_dict)
    print("Valid triplets: {}".format(len(valid_triplets)))

    # Training preparation
    model = TSI_module(class_emb, phrase_emb, A, M, indexing_network_path, lambda_IL).to('cuda')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Only the corpus embeddings are moved here; query embeddings are moved per batch.
    corpus_emb = corpus_emb.to('cuda')

    train_dataset = TripletDataset(train_triplets, doc2class_dict, doc2phrase_dict, \
                                training_queries_emb, training_queries_list, corpus_emb, cid_list, num_class, num_phrase)
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    valid_dataset = TripletDataset(valid_triplets, doc2class_dict, doc2phrase_dict, \
                                valid_queries_emb, valid_queries_list, corpus_emb, cid_list, num_class, num_phrase)
    valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

    best_valid_loss = float('inf')
    early_stop = 0

    for epoch in range(max_epochs):
        print('Epoch:', epoch)
        epoch_loss = 0
        for mini_batch in tqdm(train_loader):
            batch_loss = _batch_loss(model, train_dataset, mini_batch)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()

        with torch.no_grad():
            valid_loss = 0.
            for mini_batch in tqdm(valid_loader):
                valid_loss += _batch_loss(model, valid_dataset, mini_batch).item()

        # Report before the early-stop check so the final epoch is logged too.
        print('Train Loss: {:.3f}'.format(epoch_loss))
        print('Valid Loss: {:.3f}'.format(valid_loss))

        if best_valid_loss > valid_loss:
            # New best validation loss: checkpoint and reset the patience counter.
            best_valid_loss = valid_loss
            early_stop = 0
            torch.save(model.state_dict(), TSI_module_path)
        else:
            early_stop += 1

        if early_stop >= early_stop_epochs:
            break

    # Test
    retriever = EvaluateRetrieval()
    qrels = _load_pickle(data_path + '/qrels')

    # The model and corpus embeddings live on the GPU, so the test queries must too.
    test_queries_emb = torch.load(data_path + '/backbone_query_emb.pt').to('cuda')
    with torch.no_grad():
        # Evaluate with the best checkpoint saved during training.
        model.load_state_dict(torch.load(TSI_module_path))
        corpus_emb_TSI, _, _ = model(corpus_emb)
        queries_emb_TSI, _, _ = model(test_queries_emb)

        print("Title results")
        score_mat = torch.matmul(queries_emb_TSI, corpus_emb_TSI.T)
        backbone_results = eval_full_score_mat(score_mat, qid_list, cid_list)
        # NOTE(review): `topks` is not defined in this function; presumably it is
        # exported by the star import from Utils.utils - confirm.
        metrics = retriever.evaluate(qrels, backbone_results, topks)
        print_metrics(metrics)


if __name__ == '__main__':

    # Command-line options as (flag, type, default), registered in this order.
    cli_options = (
        ('--data_path', str, './Dataset/CSFCube'),
        ('--base_path', str, './Outputs/'),
        ('--learning_rate', float, 1e-4),
        ('--weight_decay', float, 1e-4),
        ('--lambda_IL', float, 1e-1),
        ('--D_d_size', int, 100),
        ('--neg_numbers', int, 50),
        ('--batch_size', int, 128),
        ('--M', int, 3),
        ('--early_stop_epochs', int, 10),
        ('--random_seed', int, 1231),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()

    # Seed the Python, NumPy and PyTorch random generators with the same value.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.random_seed)

    main(args)

