# TODO
# ColBERT text retrieval: score model outputs against reference texts with RAGatouille.
import ragatouille
import os
import json

def RAG_score(gts, res, index_name='cartoon_con',
              index_root='./visual_argument_experiments/.ragatouille/'):
    """Return the mean ColBERT score of each prediction against its own reference.

    Each reference text in ``gts`` is indexed as a ColBERT document whose id is
    its key. Each prediction in ``res`` is then used as a query. The search is
    restricted (``doc_ids=[key]``) to the reference with the same key, and the
    top result's score is kept. The return value is the plain average of those
    per-example scores.

    If the index directory already exists, it is loaded as-is and ``gts`` is
    not used. Delete the directory (or pass a new ``index_name``) whenever the
    reference texts change, otherwise the scores come from a stale index.

    Args:
        gts: Mapping of example id -> sequence whose first element is the
            reference text.
        res: Mapping of example id -> sequence whose first element is the
            predicted text. Each id must also be present in the index.
        index_name: Name of the ColBERT index to load or build.
        index_root: Directory under which RAGatouille stores indexes. The index
            lives at ``<index_root>/colbert/indexes/<index_name>``.

    Returns:
        The mean of the raw ``score`` values returned by
        ``RAGPretrainedModel.search`` (no normalization is applied here).

    Raises:
        ValueError: If ``res`` is empty (the mean would be undefined).
        KeyError: If the search restricted to an id returns no results.
    """
    if not res:
        raise ValueError("RAG_score needs at least one prediction in `res`.")

    idx_path = os.path.join(index_root, 'colbert', 'indexes', index_name)
    if os.path.exists(idx_path):
        rag = ragatouille.RAGPretrainedModel.from_index(idx_path)
    else:
        # Build under the same root that the existence check above inspects.
        # Without index_root, RAGatouille writes to its own default location,
        # so the cached index would never be found and would be rebuilt every run.
        rag = ragatouille.RAGPretrainedModel.from_pretrained(
            "colbert-ir/colbertv2.0", index_root=index_root)
        doc_ids = list(gts)
        index_path = rag.index(
            index_name=index_name,
            collection=[gts[doc_id][0] for doc_id in doc_ids],
            document_ids=doc_ids,
        )
        rag = ragatouille.RAGPretrainedModel.from_index(index_path)

    scores = []
    for doc_id, prediction in res.items():
        # Restrict the search to this example's own reference, so the score
        # measures prediction-vs-reference similarity rather than retrieval rank.
        results = rag.search(query=prediction[0], doc_ids=[doc_id])
        if not results:
            raise KeyError(f"No indexed reference found for id {doc_id!r}")
        scores.append(results[0]['score'])

    return sum(scores) / len(scores)
    
if __name__ == "__main__":
    # Load the reference texts (gts) and model outputs (res) for the task-4
    # InstructBLIP run. Both are JSON objects whose keys are matched between
    # the two files by RAG_score. Read them explicitly as UTF-8 so decoding
    # does not depend on the platform's default locale encoding.
    with open("./visual_argument_experiments/results/task4/instructblip_0_gts.json",
              encoding="utf-8") as f:
        gts = json.load(f)
    with open("./visual_argument_experiments/results/task4/instructblip_0_res.json",
              encoding="utf-8") as f:
        res = json.load(f)
    print(RAG_score(gts, res))