from beir.reranking.models import CrossEncoder
from beir.reranking import Rerank
from beir.retrieval.search.sparse import SparseSearch
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
import json
import time


class BeirModels:
    """Container for the BEIR retrieval and reranking models used by this script.

    Instantiation builds three searchers and stores them as attributes:
    ``sparse_model`` (SparseSearch over a SPARTA encoder), ``dense_model``
    (exact dense search over a SentenceBERT encoder) and ``rerank_model``
    (Rerank over a cross-encoder, also kept as ``cross_encoder_model``).
    """

    def __init__(self):
        # Sparse retriever.
        sparta_checkpoint = "BeIR/sparta-msmarco-distilbert-base-v1"
        sparta_encoder = models.SPARTA(sparta_checkpoint)
        self.sparse_model = SparseSearch(sparta_encoder, batch_size=128)

        # Dense retriever.
        sbert_encoder = models.SentenceBERT("msmarco-distilbert-base-v3")
        self.dense_model = DRES(sbert_encoder, batch_size=16)

        # Cross-encoder reranker.
        self.cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-electra-base')
        self.rerank_model = Rerank(self.cross_encoder_model, batch_size=128)

    def sparse_recall(self, corpus, queries, topk=8):
        """Search ``corpus`` for ``queries`` with the sparse model.

        ``topk`` is forwarded as ``top_k``; the score function is passed as an
        empty string. Returns whatever ``sparse_model.search`` returns.
        """
        return self.sparse_model.search(
            corpus,
            queries,
            score_function="",
            top_k=topk,
        )

    def dense_recall(self, corpus, queries, topk=8):
        """Search ``corpus`` for ``queries`` with the dense model (``cos_sim``).

        ``topk`` is forwarded as ``top_k``. Returns whatever
        ``dense_model.search`` returns.
        """
        return self.dense_model.search(
            corpus,
            queries,
            score_function="cos_sim",
            top_k=topk,
        )

    def rerank(self, corpus, queries, recall_result, topk=8):
        """Rerank a previous recall result with the cross-encoder reranker.

        ``recall_result`` is passed through as the third positional argument
        and ``topk`` as ``top_k``. Returns whatever ``rerank_model.rerank``
        returns.
        """
        return self.rerank_model.rerank(corpus, queries, recall_result, top_k=topk)


def load_corpus(corpus_path=r"../knowledge_segmentation/data/segments_of_books_0704.json",
                max_seg_id=500):
    """Load the segment file and convert it to a BEIR-style corpus dict.

    The JSON file is read as a mapping whose keys are the segment texts and
    whose values are dicts containing an integer ``"seg_id"``.

    Args:
        corpus_path: Path of the JSON segment file. Defaults to the path the
            script has always used.
        max_seg_id: Segments whose ``seg_id`` is greater than this value are
            skipped. Pass ``None` `` to keep every segment. Defaults to 500.

    Returns:
        A dict mapping ``"doc<seg_id>"`` to ``{"title": "", "text": <text>}``.
        When several segments share a ``seg_id`` only the first one is kept
        and a notice is printed for each later one.
    """.replace("` `", "`")
    with open(corpus_path, "r", encoding="utf-8") as fi:
        segments = json.load(fi)

    corpus = {}
    for text, meta in segments.items():
        seg_id = meta["seg_id"]
        if max_seg_id is not None and seg_id > max_seg_id:
            continue
        doc_id = "doc" + str(seg_id)
        if doc_id in corpus:
            # Keep the first occurrence; later duplicates are only reported.
            print(doc_id, "doc id repeated")
            continue
        corpus[doc_id] = {"title": "", "text": text}
    return corpus


def load_queries(queries_path=r"data/dev.json", max_queries=11):
    """Load questions from a JSON-lines file into a BEIR-style query dict.

    Each non-blank line is parsed as a JSON object and its ``"question"``
    field becomes the query text. Query ids are ``"q<line index>"`` using the
    zero-based line index in the file.

    Args:
        queries_path: Path of the JSON-lines file. Defaults to the path the
            script has always used.
        max_queries: Number of leading lines to read. Pass ``None`` to read
            the whole file. Defaults to 11 (line indices 0 through 10).

    Returns:
        A dict mapping query id to question text.
    """
    queries = {}
    with open(queries_path, "r", encoding="utf-8") as fi:
        # Iterate the file lazily and stop at the limit instead of loading
        # every line and skipping the tail one by one.
        for i, line in enumerate(fi):
            if max_queries is not None and i >= max_queries:
                break
            if not line.strip():
                # Blank lines carry no record; the index still advances so
                # ids of the surrounding lines are unaffected.
                continue
            record = json.loads(line)
            queries["q" + str(i)] = record["question"]

    return queries


def _write_json(path, payload):
    """Serialize ``payload`` as JSON (non-ASCII kept as-is) and write it to ``path``."""
    with open(path, "w", encoding="utf-8") as fo:
        fo.write(json.dumps(payload, ensure_ascii=False))


def _load_or_compute_json(path, compute):
    """Return the JSON stored at ``path``, or compute, cache and return it.

    If ``path`` exists its parsed content is returned and ``compute`` is not
    called. Otherwise ``compute()`` is called, its result is written to
    ``path`` and then returned.
    """
    import os  # function-scope import, as in the original main()

    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fi:
            return json.load(fi)
    result = compute()
    _write_json(path, result)
    return result


def main():
    """Run sparse and dense recall, rerank both, and write results under ``data/``.

    Recall results are cached in ``data/sparse_result.json`` and
    ``data/dense_result.json``: an existing file is reused without checking
    whether the corpus or queries have changed, so delete it to force a
    recomputation. The reranked results are always recomputed and written to
    ``data/sparse_rerank_result.json`` and ``data/dense_rerank_result.json``.
    """
    model = BeirModels()
    print("models are ready.")
    corpus = load_corpus()
    queries = load_queries()

    sparse_result = _load_or_compute_json(
        "data/sparse_result.json",
        lambda: model.sparse_recall(corpus, queries, topk=8),
    )
    dense_result = _load_or_compute_json(
        "data/dense_result.json",
        lambda: model.dense_recall(corpus, queries, topk=8),
    )

    sparse_rerank_result = model.rerank(corpus, queries, recall_result=sparse_result, topk=8)
    _write_json("data/sparse_rerank_result.json", sparse_rerank_result)

    dense_rerank_result = model.rerank(corpus, queries, recall_result=dense_result, topk=8)
    _write_json("data/dense_rerank_result.json", dense_rerank_result)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()






