# Our training code is based on https://github.com/beir-cellar/beir/blob/main/examples/retrieval/training/train_sbert_BM25_hardnegs.py

import argparse
import numpy as np
import pickle
import random
import torch
from tqdm import tqdm
import os
import logging

from sentence_transformers import SentenceTransformer, models, losses, InputExample
from beir import util, LoggingHandler
from beir.retrieval.train import TrainRetriever
from beir import util, LoggingHandler
from beir.retrieval.evaluation import EvaluateRetrieval

from Utils.utils import *

def main(args):
    """Fine-tune a SentenceTransformer retriever on query/positive/hard-negative triplets.

    Loads the pickled corpus, training queries and pre-computed retrieval
    results, holds out the last tenth of the query-id list for validation,
    builds one triplet record per training query and trains with
    ``MultipleNegativesRankingLoss`` through BEIR's ``TrainRetriever``.

    Args:
        args: Parsed command-line namespace. Reads ``data_path``, ``base_path``,
            ``learning_rate``, ``weight_decay``, ``batch_size``, ``model_name``
            and ``neg_numbers``.

    Raises:
        ValueError: If ``args.model_name`` is not one of the checkpoints this
            script knows how to configure pooling for.
    """
    data_path = args.data_path
    base_path = args.base_path
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    batch_size = args.batch_size
    model_name = args.model_name
    neg_numbers = args.neg_numbers

    # Checkpoint name -> whether to pool with the [CLS] token (True) or with
    # the mean of the token embeddings (False).
    cls_pooling_by_model = {
        'allenai/specter2_base': True,
        'facebook/contriever-msmarco': False,
    }
    # Validate before any data is loaded so an unsupported name fails fast
    # with a clear message instead of an UnboundLocalError later on.
    if model_name not in cls_pooling_by_model:
        raise ValueError(
            "Unsupported model_name {!r}; expected one of {}".format(
                model_name, sorted(cls_pooling_by_model)))

    logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S',level=logging.INFO, handlers=[LoggingHandler()])

    # NOTE(security): pickle.load can execute arbitrary code; only point these
    # paths at files produced by a trusted pipeline.
    with open(os.path.join(data_path, 'corpus'), 'rb') as f:
        corpus = pickle.load(f)
    with open(os.path.join(data_path, 'Promptgator_training_queries'), 'rb') as f:
        total_queries = pickle.load(f)

    with open('./Outputs/CSFCube_training_BM25_results', 'rb') as f:
        BM25_top500_results = pickle.load(f)

    # Query ids have the form "<cid>_<idx>" (exactly one underscore); the
    # document named by <cid> is the single positive for that query.
    total_qrels = {}
    for qid in total_queries:
        cid, _ = qid.split('_')
        total_qrels[qid] = {cid: 1}

    with open(os.path.join(data_path, 'CSFCube_training_qid_list'), 'rb') as f:
        total_queries_list = pickle.load(f)

    # Hold out the last 10% of the list for validation. Slice at an explicit
    # index: with fewer than 10 queries `lst[:-0]` would be empty and
    # `lst[-0:]` would be the whole list, inverting the split.
    num_valid = len(total_queries_list) // 10
    split_at = len(total_queries_list) - num_valid
    training_queries_list = total_queries_list[:split_at]
    valid_queries_list = total_queries_list[split_at:]
    if not valid_queries_list:
        logging.warning("Fewer than 10 queries available; the validation split is empty.")

    training_qrels = {qid: total_qrels[qid] for qid in training_queries_list}
    training_queries = {qid: total_queries[qid] for qid in training_queries_list}
    valid_qrels = {qid: total_qrels[qid] for qid in valid_queries_list}
    valid_queries = {qid: total_queries[qid] for qid in valid_queries_list}

    # The evaluator receives every document with an empty title and the
    # original text.
    valid_corpus = {
        cid: {'title': '', 'text': doc['text']} for cid, doc in corpus.items()
    }

    # return_topK_result comes from Utils.utils; presumably it keeps the top
    # `neg_numbers` retrieved documents per query -- confirm against the helper.
    diff_dict = return_topK_result(BM25_top500_results, neg_numbers)
    train_triplets = {}

    for qid, qrels in training_qrels.items():
        if qid not in diff_dict:
            continue
        pos_pids = set(qrels.keys())
        # Never use a known positive as a hard negative.
        neg_pids = set(diff_dict[qid].keys()) - pos_pids
        if not pos_pids or not neg_pids:
            continue

        # Sort so the lists do not depend on set iteration order, which varies
        # between runs for string ids and would defeat the fixed random seeds.
        train_triplets[qid] = {
            'query': training_queries[qid],
            'pos': sorted(pos_pids),
            'hard_neg': sorted(neg_pids),
        }

    model = SentenceTransformer(model_name)
    use_cls_pooling = cls_pooling_by_model[model_name]
    # Module "1" is the second module of the loaded model; its pooling flags
    # are overridden here to select CLS or mean pooling.
    pooling_module = model._modules["1"]
    pooling_module.pooling_mode_mean_tokens = not use_cls_pooling
    pooling_module.pooling_mode_cls_token = use_cls_pooling

    retriever = TrainRetriever(model=model, batch_size=batch_size)
    train_dataset = Standard_TripletDataset(train_triplets, corpus=corpus)
    train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
    train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score, scale=1)
    ir_evaluator = retriever.load_ir_evaluator(valid_corpus, valid_queries, valid_qrels)

    ### Provide model save path
    # normpath drops a trailing separator so "./Dataset/CSFCube/" still yields
    # "CSFCube" rather than an empty dataset name.
    dataset = os.path.basename(os.path.normpath(data_path))
    model_save_path = os.path.join(base_path, "output", "{}-v3-{}".format(model_name, dataset))
    os.makedirs(model_save_path, exist_ok=True)

    num_epochs, evaluation_steps = 30, 10000
    warmup_steps = 1000

    optimizer_params = {'lr': learning_rate, 'correct_bias': False}

    retriever.fit(train_objectives=[(train_dataloader, train_loss)],
                    evaluator=ir_evaluator,
                    epochs=num_epochs,
                    output_path=model_save_path,
                    warmup_steps=warmup_steps,
                    evaluation_steps=evaluation_steps,
                    optimizer_params = optimizer_params,
                    weight_decay=weight_decay,
                    use_amp=True)

if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this order.
    option_specs = (
        ('--data_path', str, './Dataset/CSFCube'),
        ('--base_path', str, './Outputs/'),
        ('--model_name', str, 'facebook/contriever-msmarco'),
        ('--learning_rate', float, 1e-6),
        ('--weight_decay', float, 1e-4),
        ('--batch_size', int, 128),
        ('--neg_numbers', int, 50),
        ('--random_seed', int, 1231),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type)

    args = parser.parse_args()

    # Seed the Python, NumPy and PyTorch generators with the same value
    # before any training code runs.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.random_seed)

    main(args)

