from prompts import *

import argparse
import json
from tqdm import tqdm
import time
import re
import random
import threading
import queue
import math

import openai
from SPARQLWrapper import SPARQLWrapper, JSON
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import CrossEncoder

import model.BLINK.elq.main_dense as elq_main_dense
from qwikidata.linked_data_interface import get_entity_dict_from_api


def prepare_data(dataset_name, use_golden_topic):
    """Load the test split of a dataset from its JSON-lines file.

    Args:
        dataset_name: "webqsp" selects the WebQSP file; any other value
            selects the CWQ file.
        use_golden_topic: when true, the annotated topic entity names and
            ids are read from each record; otherwise every question gets
            an empty list for both.

    Returns:
        Four parallel lists: question ids, question texts, topic entity
        names and topic entity ids.
    """
    dataset_path = (
        "./data/webqsp/simple-WebQSP_test.jsonl"
        if dataset_name == "webqsp"
        else "./data/cwq/simple-CWQ_test.jsonl"
    )
    question_ids, questions = [], []
    topic_entity_names, topic_entity_ids = [], []
    with open(dataset_path, "r", encoding="utf-8") as dataset_file:
        for raw_line in dataset_file:
            record = json.loads(raw_line)
            question_ids.append(record["Id"])
            questions.append(record["Question"])
            # A distinct empty list is created per question so that the
            # entries never share storage with one another.
            names = record["TopicEntityName"] if use_golden_topic else []
            ids = record["TopicEntityId"] if use_golden_topic else []
            topic_entity_names.append(names)
            topic_entity_ids.append(ids)
    return question_ids, questions, topic_entity_names, topic_entity_ids


def prepare_entity_linker(threshold):
    """Build the ELQ configuration and load the ELQ models.

    Args:
        threshold: value stored as the "threshold" option of the ELQ
            configuration (used together with threshold_type "joint").

    Returns:
        A tuple (models, args): whatever `elq_main_dense.load_models`
        returns, and the `argparse.Namespace` holding the configuration.
    """
    # ELQ
    models_path = "./model/BLINK/models/" # the path where you stored the ELQ models
    args = argparse.Namespace(
        interactive=False,
        biencoder_model=models_path + "elq_webqsp_large.bin",
        biencoder_config=models_path + "elq_large_params.txt",
        cand_token_ids_path=models_path + "entity_token_ids_128.t7",
        entity_catalogue=models_path + "entity.jsonl",
        entity_encoding=models_path + "all_entities_large.t7",
        output_path="ELQ_logs/", # logging directory
        faiss_index="hnsw",
        index_path=models_path + "faiss_hnsw_index.pkl",
        num_cand_mentions=10,
        num_cand_entities=10,
        threshold_type="joint",
        threshold=threshold,
    )
    models = elq_main_dense.load_models(args, logger=None)
    return models, args


def reason_with_llm(prompt, api_key, input_token_num, output_token_num,
                    model="gpt-3.5-turbo-0125", max_tokens=100, temperature=0):
    """Send one chat request to the LLM and return its reply and token counts.

    The request is retried without limit, sleeping between two and four
    seconds after each failure, until the API call succeeds. Every
    exception is printed.

    Args:
        prompt: text sent as the user message.
        api_key: OpenAI API key; assigned to `openai.api_key`.
        input_token_num: running total of prompt tokens so far.
        output_token_num: running total of completion tokens so far.
        model: chat model name (defaults to the previously hard-coded one).
        max_tokens: completion length limit passed to the API.
        temperature: sampling temperature passed to the API.

    Returns:
        (response_message, input_token_num, output_token_num), where the
        message is the string "null" if the reply carries no content and
        the two totals include this call's usage.
    """
    openai.api_key = api_key
    input_message = [
        {"role": "system", "content": "You are an AI assistant with the ability to provide insightful responses and make informed judgments."},
        {"role": "user", "content": prompt}
    ]
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=input_message,
                max_tokens=max_tokens,
                temperature=temperature
            )
            break
        except Exception as e:
            print("Exception in LLM:")
            print(e)
            time.sleep(2 + 2 * random.random())
    message = response["choices"][0]["message"]
    response_message = message["content"] if "content" in message else None
    # A present-but-None content would otherwise reach callers that run
    # string operations on the reply, so it is mapped to "null" as well.
    if response_message is None:
        response_message = "null"
    input_token_num += response.usage.prompt_tokens
    output_token_num += response.usage.completion_tokens
    return response_message, input_token_num, output_token_num


def determine(question, llm_api_key, llm_input_len, llm_output_len):
    """Ask the LLM to judge a question and report whether it was labelled "simple".

    The label is the text between the first '{' and the first '}' of the
    reply; when either brace is missing the label defaults to "simple".

    Returns:
        (response, answer_flag, llm_input_len, llm_output_len), where
        answer_flag is True exactly when the label equals "simple"
        ignoring case.
    """
    prompt = prompt_judge.format(question)
    response, llm_input_len, llm_output_len = reason_with_llm(prompt, llm_api_key, llm_input_len, llm_output_len)

    # locate the braces that delimit the label in the reply
    open_idx = response.find('{')
    close_idx = response.find('}')
    if -1 in (open_idx, close_idx):
        label = "simple"
    else:
        label = response[open_idx+1:close_idx]

    answer_flag = label.lower() == "simple"
    return response, answer_flag, llm_input_len, llm_output_len


def generate_answer(question, previous_subQAs, reference_triplets, llm_api_key, llm_input_len, llm_output_len, use_relation_retrieval):
    """Answer a question with the LLM from retrieved references.

    Args:
        question: the question text.
        previous_subQAs: text of earlier sub-question answers; when it is
            non-empty it is appended to the references after a
            "References: " marker, with trailing newlines removed.
        reference_triplets: reference strings, joined one per line.
        llm_api_key, llm_input_len, llm_output_len: forwarded to
            `reason_with_llm`; the two lengths are running token totals.
        use_relation_retrieval: selects `prompt_answer_rel` instead of
            `prompt_answer`.

    Returns:
        (response, answer_entity, llm_input_len, llm_output_len), where
        answer_entity joins the distinct brace-delimited answers with
        "; " in the order they first appear.
    """
    prompt_for_answer = prompt_answer_rel if use_relation_retrieval else prompt_answer
    references_text = '\n'.join(reference_triplets)
    if len(previous_subQAs) != 0:
        references_text = references_text + "\nReferences: " + previous_subQAs.rstrip("\n")
    prompt = prompt_for_answer.format(question, references_text)
    response, llm_input_len, llm_output_len = reason_with_llm(prompt, llm_api_key, llm_input_len, llm_output_len)

    # extract the {answer_entity} in llm's response
    answer_pattern = r'\{([^}]*)\}'
    answers = re.findall(answer_pattern, response)
    if (len(answers) == 0) or ((len(answers) == 1) and (answers[0].lower() in ["none", "unknown"])):
        # No usable answer: ask the LLM to extract one from its own reply.
        prompt = prompt_answer_extract.format(question, response)
        response_2, llm_input_len, llm_output_len = reason_with_llm(prompt, llm_api_key, llm_input_len, llm_output_len)
        answers = re.findall(answer_pattern, response_2)
    # De-duplicate while keeping first-occurrence order; going through a
    # set would make the joined answer order differ from run to run.
    answers = list(dict.fromkeys(answers))
    answer_entity = "; ".join(answers)

    return response, answer_entity, llm_input_len, llm_output_len


def decompose(question, llm_api_key, llm_input_len, llm_output_len):
    """Ask the LLM to split a question into sub-questions.

    Each line of the reply contributes at most one sub-question: the
    non-empty text between its first '{' and first '}'. References to
    earlier sub-questions written in prose are rewritten to the
    ([#n]) placeholder form.

    Returns:
        (response, subquestions, llm_input_len, llm_output_len); the
        response is returned after the placeholder repair below.
    """
    prompt = prompt_decompose.format(question)
    response, llm_input_len, llm_output_len = reason_with_llm(prompt, llm_api_key, llm_input_len, llm_output_len)

    # replace wrong pattern {[#x]} with ([#x])
    response = re.sub(r'\{\[\#(\d+)\]\}', r'([#\1])', response)

    # Format repairs, applied in this order so that the two-reference
    # forms are handled before the single-reference forms.
    rewrites = (
        (r'in subquestion-(\d+) and subquestion-(\d+)', r'in ([#\1]) and ([#\2])'),
        (r'in subquestion-(\d+)', r'in ([#\1])'),
        (r'the answers? to subquestion-(\d+) and subquestion-(\d+)', r'([#\1]) and ([#\2])'),
        (r'the answers? to subquestion-(\d+)', r'([#\1])'),
        (r'from subquestion-(\d+)', r'from ([#\1])'),
    )

    subquestions = []
    for line in response.split('\n'):
        open_idx = line.find('{')
        close_idx = line.find('}')
        if open_idx == -1 or close_idx == -1 or open_idx + 1 >= close_idx:
            continue
        subquestion = line[open_idx+1:close_idx]
        for pattern, replacement in rewrites:
            subquestion = re.sub(pattern, replacement, subquestion)
        subquestions.append(subquestion)

    return response, subquestions, llm_input_len, llm_output_len


def _wikidata_to_freebase_id(wikiId):
    """Map a Wikidata id to a Freebase id in "m.xxx" / "g.xxx" form.

    Returns 'null' when wikiId is not a Wikidata Q-id, and "none" when
    the entity cannot be fetched ("Not Found") or has no usable P646
    value. Any other fetch error is retried every two seconds.
    """
    if not isinstance(wikiId, str) or not wikiId.startswith('Q'):
        return 'null'
    while True:
        try:
            qdict = get_entity_dict_from_api(wikiId)
            break
        except Exception as e:
            print("Exception in wikiId -> freebaseId:")
            print(e)
            if "Not Found" in str(e):
                return "none"
            time.sleep(2)
    claims = qdict.get("claims", {}).get("P646")
    if not claims:
        return "none"
    # A snak may carry no "datavalue" at all, so every level is optional.
    fbId = claims[0].get("mainsnak", {}).get("datavalue", {}).get("value")
    if isinstance(fbId, str) and fbId.startswith(("/m/", "/g/")):
        # "/m/0abc" -> "m.0abc"
        return fbId[1] + "." + fbId[3:]
    return "none"


def entity_linking(question, entity_linker_model, entity_linker_args, elqId2wikiId):
    """Link a question to a topic entity with ELQ and resolve its Freebase id.

    Args:
        question: question text; it is lower-cased before linking.
        entity_linker_model: model objects unpacked into `elq_main_dense.run`.
        entity_linker_args: ELQ argument namespace.
        elqId2wikiId: mapping from ELQ entity ids to Wikidata ids.

    Returns:
        (topic_entity_name, topic_entity_id) for the first predicted
        mention that resolves to a Freebase id, or ("", "null") when no
        mention does.
    """
    # entity linker: ELQ
    data_to_link = [{
        "id": 0,
        "text": question.lower()
    }]
    predictions = elq_main_dense.run(entity_linker_args, None, *entity_linker_model, test_data=data_to_link)
    pred = predictions[0]
    # ELQ id -> wikidata id -> freebase id; only the first usable mention
    # is returned, so later mentions are not resolved over the network.
    for i, triple in enumerate(pred["pred_triples"]):
        fbId = _wikidata_to_freebase_id(elqId2wikiId.get(triple[0], 'null'))
        if fbId != "none" and fbId != "null":
            return pred["pred_tuples_string"][i][0], fbId
    return "", "null"


SPARQLPATH = "http://localhost:8890/sparql"
def sparql_search(sparql_txt):
    """Run a SPARQL query against SPARQLPATH and return its result bindings.

    A failure whose message contains "Error HTTP/1.1 404 File not found"
    is retried immediately; any other failure is printed and an empty
    list is returned.
    """
    while True:
        try:
            endpoint = SPARQLWrapper(SPARQLPATH)
            endpoint.setQuery(sparql_txt)
            endpoint.setReturnFormat(JSON)
            results = endpoint.query().convert()
        except Exception as e:
            print("Exception in SPARQLsearch:")
            print(e)
            if "Error HTTP/1.1 404 File not found" not in str(e):
                return []
        else:
            # Indexing happens outside the try so that a malformed
            # result still raises instead of being treated as a failure.
            return results["results"]["bindings"]


def process_relation(sparql_results):
    """Return the ?relation value of each binding with the Freebase namespace removed."""
    freebase_ns = "http://rdf.freebase.com/ns/"
    return [row["relation"]["value"].replace(freebase_ns, "") for row in sparql_results]


sparql_txt_find_name = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT ?name
WHERE {
    {
        ns:%s ns:type.object.name ?name .
    }
    UNION
    {
        ns:%s <http://www.w3.org/2002/07/owl#sameAs> ?name .
    }
}"""
wikiPrefix = "http://www.wikidata.org/entity/Q"


def _lookup_entity_name(entityId):
    """Return a display name for a Freebase m./g. id, or the id itself if none is found.

    The knowledge base is queried first. If the first hit is a Wikidata
    link, the second hit is used when there is one; otherwise the English
    label is fetched from Wikidata.
    """
    results = sparql_search(sparql_txt_find_name % (entityId, entityId))
    if len(results) == 0:
        return entityId
    entityName = results[0]["name"]["value"]
    if wikiPrefix not in entityName:
        return entityName
    if len(results) > 1:
        return results[1]["name"]["value"]

    # Strip the URL prefix but keep the leading "Q".
    wikiId = entityName.replace(wikiPrefix[:-1], "")
    qdict = {}
    while True:
        try:
            qdict = get_entity_dict_from_api(wikiId)
            break
        except Exception as e:
            print("Exception in wikiId -> entityName:")
            print(e)
            # A missing Wikidata entity will never appear on retry, so
            # stop instead of looping forever and fall back to the id.
            if "Not Found" in str(e):
                break
            time.sleep(2)
    labels = qdict.get("labels", {})
    if "en" in labels:
        return labels["en"]["value"]
    return entityId


def process_entity(sparql_results):
    """Turn ?entity bindings into [entityId, entityName] pairs.

    The id is the binding value with the Freebase namespace removed. Ids
    starting with "m." or "g." are resolved to a name when one can be
    found; every other value is used as its own name. Newline characters
    in the name are replaced by an escaped two-character form.
    """
    entities = []
    for ent in sparql_results:
        entityId = ent["entity"]["value"].replace("http://rdf.freebase.com/ns/","")
        entityName = entityId
        if entityId.startswith(("m.", "g.")):
            entityName = _lookup_entity_name(entityId)
        entityName = entityName.replace("\n", "\\n")
        entities.append([entityId, entityName])
    return entities


def construct_candidates(candidates, rel2ents, topic_entity, relation, entities, ishead, istwohop, relations1hop=None):
    """Add 1-hop triples to, or extend 1-hop triples in, the candidate set.

    Candidates are strings of "< x >" fields joined by " [SEP] ".

    Args:
        candidates: set of candidate strings; modified in place.
        rel2ents: dict from relation to the [id, name] pairs whose name is
            still an m./g. id; modified in place in 1-hop mode only.
        topic_entity: name of the entity the relation is attached to.
        relation: relation being expanded.
        entities: [id, name] pairs found on the other side of `relation`.
        ishead: True if topic_entity is the subject of `relation`.
        istwohop: False adds 1-hop triples; True replaces matching 1-hop
            candidates that touch topic_entity by 2-hop paths.
        relations1hop: 1-hop relations whose candidates may be extended
            (2-hop mode only); defaults to none.

    Returns:
        (candidates, rel2ents), the same objects that were passed in.
    """
    if relations1hop is None:
        relations1hop = []

    if not istwohop:
        entFor2hop = [ent for ent in entities if (ent[0].startswith('m.') or ent[0].startswith('g.')) and (ent[0] == ent[1])]
        if relation in rel2ents:
            rel2ents[relation].extend(entFor2hop)
        else:
            rel2ents[relation] = entFor2hop
        # add to candidates
        for entity in entities:
            if topic_entity == entity[1]:
                continue
            if ishead:
                candidate = "< {} > [SEP] < {} > [SEP] < {} >".format(topic_entity, relation, entity[1])
            else:
                candidate = "< {} > [SEP] < {} > [SEP] < {} >".format(entity[1], relation, topic_entity)
            candidates.add(candidate)
        return candidates, rel2ents

    # 2-hop: topic_entity is the intermediate node. It is the last field of
    # a candidate when it is the subject of the new relation, else the first.
    anchor = 2 if ishead else 0
    matched = []
    for cand in candidates:
        parts = [field[2:-2] for field in cand.split(" [SEP] ")]
        if parts[1] in relations1hop and parts[anchor] == topic_entity:
            matched.append((cand, parts))

    # Matches are collected first because the set is mutated below.
    for cand, parts in matched:
        for entity in entities:
            if ishead:
                if parts[0] != entity[1]:
                    candidates.add("< {} > [SEP] < {} > [SEP] < {} > [SEP] < {} >".format(parts[0], parts[1], relation, entity[1]))
            else:
                if entity[1] != parts[2]:
                    candidates.add("< {} > [SEP] < {} > [SEP] < {} > [SEP] < {} >".format(entity[1], relation, parts[1], parts[2]))
        candidates.remove(cand)
    return candidates, rel2ents


def construct_candidates_rel_1hop(candidates, rel2ents, candidate2isTail, topic_entity, relation, entities, ishead):
    """Register a 1-hop relation candidate (no entity field) for a topic entity.

    The candidate is "< topic > [SEP] < relation >" when ishead is true
    and "< relation > [SEP] < topic >" otherwise. It is added only if at
    least one entity has a name different from topic_entity. The ids of
    entities whose name is still an m./g. id are recorded under the
    relation in rel2ents.

    Returns:
        (candidates, rel2ents, candidate2isTail), the objects passed in.
    """
    nameless_ids = [ent[0] for ent in entities if ent[0].startswith(('m.', 'g.')) and ent[0] == ent[1]]
    rel2ents.setdefault(relation, []).extend(nameless_ids)

    # add to candidates
    if any(topic_entity != ent[1] for ent in entities):
        if ishead:
            candidate = "< {} > [SEP] < {} >".format(topic_entity, relation)
        else:
            candidate = "< {} > [SEP] < {} >".format(relation, topic_entity)
        candidates.add(candidate)
        # keep the direction recorded the first time the candidate was seen
        candidate2isTail.setdefault(candidate, ishead)
    return candidates, rel2ents, candidate2isTail


def construct_candidates_rel_2hop(candidates, candidate2isTail, relation, relations1hop, ishead):
    """Extend two-field relation candidates with a second-hop relation.

    When ishead is true, candidates whose second field is in
    relations1hop get `relation` appended; otherwise candidates whose
    first field is in relations1hop get `relation` prepended. The
    original two-field candidates are kept.

    Returns:
        (candidates, candidate2isTail), the objects passed in.
    """
    # field holding the 1-hop relation: second for tail form, first for head form
    relation_slot = 1 if ishead else 0

    # Collect first: the set grows while new candidates are added.
    extendable = []
    for cand in candidates:
        fields = cand.split(" [SEP] ")
        if len(fields) == 2 and fields[relation_slot][2:-2] in relations1hop:
            extendable.append((fields[0][2:-2], fields[1][2:-2]))

    for first, second in extendable:
        if ishead:
            candidate = "< {} > [SEP] < {} > [SEP] < {} >".format(first, second, relation)
        else:
            candidate = "< {} > [SEP] < {} > [SEP] < {} >".format(relation, first, second)
        candidates.add(candidate)
        candidate2isTail.setdefault(candidate, ishead)
    return candidates, candidate2isTail


sparql_txt_find_tail_rel = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?relation
WHERE {
    ns:%s ?relation ?entity .
    FILTER(?relation != ns:common.topic.description)
}"""
sparql_txt_find_head_rel = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?relation
WHERE {
    ?entity ?relation %s .
}"""
sparql_txt_find_tail = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?entity
WHERE {{
    ns:{0} ns:{1} ?entity .
}}"""
sparql_txt_find_head = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?entity
WHERE {{
    ?entity ns:{1} {0} .
}}"""


def _is_literal_node(node_id):
    """Return True if the id is a language-tagged string or a dateTime literal."""
    return node_id.endswith('@en') or node_id.endswith('^^xsd:dateTime')


def _direction_map(tail_relations, head_relations):
    """Map each relation to "True" (outgoing), "False" (incoming) or "both"."""
    directions = {rel: "True" for rel in tail_relations}
    for rel in head_relations:
        directions[rel] = "both" if directions.get(rel) == "True" else "False"
    return directions


def _as_id_name_pair(ent):
    """Normalise a rel2ents entry (bare id string or [id, name]) to an (id, name) tuple."""
    if isinstance(ent, str):
        # Bare ids are only stored for entities whose name is their id.
        return (ent, ent)
    return (ent[0], ent[1])


def find_candidates(question, topic_entity_name, topic_entity_id, use_relation_retrieval):
    """Collect retrieval candidates within two hops of a topic entity.

    Args:
        question: question text, used to rank relations with
            `dense_retrieve`.
        topic_entity_name: name of the topic entity, used in candidates.
        topic_entity_id: Freebase id (or literal) of the topic entity.
        use_relation_retrieval: False builds entity-level triples and
            quadruples; True builds relation paths without the far entity.

    Returns:
        (candidates, candidate2isTail): the list of candidate strings and,
        in relation mode, a dict from candidate to True when the topic
        entity is the subject of the path (empty otherwise).
    """
    width_1hop_rel = 10
    rel_for_2hop = 5
    width_2hop_rel = 5

    # A literal cannot be a subject and is used as-is in object position;
    # a Freebase id needs the "ns:" prefix there.
    topic_is_literal = _is_literal_node(topic_entity_id)
    topic_as_object = topic_entity_id if topic_is_literal else "ns:" + topic_entity_id

    # 1 hop relations
    if topic_is_literal:
        relations_tail = process_relation([])
    else:
        relations_tail = process_relation(sparql_search(sparql_txt_find_tail_rel % topic_entity_id))
    relations_head = process_relation(sparql_search(sparql_txt_find_head_rel % topic_as_object))
    rel2isTail = _direction_map(relations_tail, relations_head)
    relations = list(set(relations_tail + relations_head))

    # prune 1 hop relations
    if len(relations) > rel_for_2hop:
        relations_1hop = dense_retrieve(list(relations), question, min(width_1hop_rel, len(relations)))
    else:
        relations_1hop = relations

    # 1 hop entities
    candidates = set()
    rel2ents = dict()
    candidate2isTail = dict()
    for relation in relations_1hop:
        direction = rel2isTail[relation]
        lookups = []
        if direction != "False":  # "True" or "both": topic is the subject
            lookups.append((True, sparql_txt_find_tail.format(topic_entity_id, relation)))
        if direction != "True":  # "False" or "both": topic is the object
            lookups.append((False, sparql_txt_find_head.format(topic_as_object, relation)))
        for ishead, query in lookups:
            entities = process_entity(sparql_search(query))
            if not use_relation_retrieval:
                candidates, rel2ents = construct_candidates(candidates, rel2ents, topic_entity_name, relation, entities, ishead=ishead, istwohop=False)
            else:
                candidates, rel2ents, candidate2isTail = construct_candidates_rel_1hop(candidates, rel2ents, candidate2isTail, topic_entity_name, relation, entities, ishead=ishead)

    # entities for 2 hop searching
    entities_1hop = []
    ent2isTail = dict()
    ent2rels = dict()
    num_ent_for_2hop = 10 if use_relation_retrieval else 200
    for relation in relations_1hop[:rel_for_2hop]:
        if relation not in rel2ents:
            continue
        # rel2ents holds [id, name] pairs in triple mode but bare id
        # strings in relation mode. Indexing a bare string would yield
        # single characters, so entries are normalised first.
        rel_entities = [_as_id_name_pair(ent) for ent in rel2ents[relation]]
        if len(rel_entities) > num_ent_for_2hop: # limit number of entity for each 1hop relation under num_ent_for_2hop
            entities_1hop.extend(random.sample(rel_entities, num_ent_for_2hop))
        else:
            entities_1hop.extend(rel_entities)
        for ent_id, _ in rel_entities:
            if (ent_id in ent2isTail) and (ent2isTail[ent_id] != rel2isTail[relation]):
                ent2isTail[ent_id] = "both"
            else:
                ent2isTail[ent_id] = rel2isTail[relation]
            if use_relation_retrieval:
                ent2rels.setdefault(ent_id, []).append(relation)
    entities_1hop = set(entities_1hop)

    # 2 hop fact triples/quadruples or relations
    for ent_id, ent_name in entities_1hop:
        ent_direction = ent2isTail[ent_id]
        # 2 hop relations, explored in the direction(s) the entity was reached
        tail_rels, head_rels = [], []
        if ent_direction != "False":
            tail_rels = process_relation(sparql_search(sparql_txt_find_tail_rel % ent_id))
        if ent_direction != "True":
            head_rels = process_relation(sparql_search(sparql_txt_find_head_rel % ("ns:" + ent_id)))
        rel2isTail_2hoprel = _direction_map(tail_rels, head_rels)
        if ent_direction == "True":
            relations_2hop_cand = tail_rels
        elif ent_direction == "False":
            relations_2hop_cand = head_rels
        else: # "both"
            relations_2hop_cand = list(set(tail_rels + head_rels))

        # prune 2 hop relations
        if len(relations_2hop_cand) > width_2hop_rel:
            relations_2hop = dense_retrieve(list(relations_2hop_cand), question, min(width_2hop_rel, len(relations_2hop_cand)))
        else:
            relations_2hop = relations_2hop_cand

        for relation in relations_2hop:
            direction = rel2isTail_2hoprel[relation]
            if not use_relation_retrieval:
                # 2 hop entities
                if direction != "False":
                    entities = process_entity(sparql_search(sparql_txt_find_tail.format(ent_id, relation)))
                    candidates, _ = construct_candidates(candidates, dict(), ent_name, relation, entities, ishead=True, istwohop=True, relations1hop=relations_1hop[:rel_for_2hop])
                if direction != "True":
                    entities = process_entity(sparql_search(sparql_txt_find_head.format("ns:" + ent_id, relation)))
                    candidates, _ = construct_candidates(candidates, dict(), ent_name, relation, entities, ishead=False, istwohop=True, relations1hop=relations_1hop[:rel_for_2hop])
            else:
                # construct candidates with 2 hop relation; the 1-hop
                # relations that reached this entity are keyed by its id
                if direction != "False":
                    candidates, candidate2isTail = construct_candidates_rel_2hop(candidates, candidate2isTail, relation, ent2rels[ent_id], ishead=True)
                if direction != "True":
                    candidates, candidate2isTail = construct_candidates_rel_2hop(candidates, candidate2isTail, relation, ent2rels[ent_id], ishead=False)

    # candidates in 2 hop from topic entity
    return list(candidates), candidate2isTail


def sparse_retrieve(candidates, question, width):
    """Return the `width` candidates that best match the question under BM25.

    Candidates and question are tokenised by splitting on single spaces.
    A width of 0 returns an empty list without building the index.
    """
    # BM25
    if width == 0:
        return []
    tokenized_candidates = [cand.split(" ") for cand in candidates]
    bm25 = BM25Okapi(tokenized_candidates)
    tokenized_question = question.split(" ")
    # get_top_n scores the corpus itself, so no separate scoring pass is needed.
    return bm25.get_top_n(tokenized_question, candidates, n=width)


dense_retriever = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v3', device="cuda:6")
def dense_retrieve(candidates, question, width):
    """Return the `width` candidates with the highest dot-product score against the question.

    Both the question and the candidates are embedded with the
    module-level `dense_retriever`. A width of 0 returns an empty list
    without encoding anything.
    """
    # Distilbert
    if width == 0:
        return []
    query_embedding = dense_retriever.encode(question)
    candidate_embeddings = dense_retriever.encode(candidates)
    scores = util.dot_score(query_embedding, candidate_embeddings)[0].tolist()
    # sorted() is stable, so equally scored candidates keep their input order.
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return [cand for cand, _ in ranked[:width]]


reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cuda:6")
def rerank(candidates, question):
    """Return all candidates ordered by descending cross-encoder score for the question.

    An empty collection of candidates yields an empty list without
    calling the model.
    """
    # MiniLM
    if len(candidates) == 0:
        return []
    pairs = [[question, candidate] for candidate in candidates]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(pairs, scores), key=lambda item: item[1], reverse=True)
    return [pair[1] for pair, _ in ranked]


def prune_entity(origin_entities):
    """Limit a list of [id, name] entities to at most 20 named ones.

    An entity is "nameless" when its name equals its id and the id starts
    with "m." or "g.". The input list itself is returned when it has at
    most 20 entries and at most one nameless entry. Otherwise the result
    is a random sample of 20 named entities, or, if fewer than 20 are
    named, all named entities followed by one random nameless entity.
    """
    nameless, has_name = [], []
    for ent in origin_entities:
        is_nameless = ent[0] == ent[1] and (ent[0].startswith("m.") or ent[0].startswith("g."))
        (nameless if is_nameless else has_name).append(ent)

    if len(origin_entities) <= 20 and len(nameless) <= 1:
        return origin_entities
    if len(has_name) >= 20:
        return random.sample(has_name, 20)
    return has_name + random.sample(nameless, 1)


sparql_txt_find_tail_2hop = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?entity
WHERE {{
    ns:{0} ns:{1} ?c .
    ?c ns:{2} ?entity .
}}"""
sparql_txt_find_head_2hop = """PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?entity
WHERE {{
    ?entity ns:{2} ?c .
    ?c ns:{1} {0} .
}}"""
def retrieve_triplets(question, topic_entity_name_gold, topic_entity_id_gold, entity_linker_model, entity_linker_args, elqId2wikiId, args, decompose_depth):
    """Retrieve and rerank knowledge-base references for a question.

    Args:
        question: question text.
        topic_entity_name_gold, topic_entity_id_gold: parallel lists of
            topic entity names and ids. NOTE: both lists are extended in
            place with any newly linked entity, so the caller sees the
            additions.
        entity_linker_model, entity_linker_args, elqId2wikiId: forwarded
            to `entity_linking`.
        args: options; reads use_golden_topic, use_relation_retrieval,
            retrieve_width and reference_num.
        decompose_depth: with golden topics, linking is only run when
            this is greater than 0.

    Returns:
        ([topic_entity_name, topic_entity_id], references). In relation
        mode each reference is a relation path with the entities found at
        its far end in a bracketed list; otherwise it is a candidate
        string as produced by `find_candidates`.
    """
    # get topic entity
    topic_entity_name = topic_entity_name_gold
    topic_entity_id = topic_entity_id_gold
    if not args.use_golden_topic: # not use golden topic
        topic_entity_name_q, topic_entity_id_q = entity_linking(question, entity_linker_model, entity_linker_args, elqId2wikiId)
        topic_entity_name.append(topic_entity_name_q)
        topic_entity_id.append(topic_entity_id_q)
    elif decompose_depth > 0: # use golden topic
        topic_entity_name_q, topic_entity_id_q = entity_linking(question, entity_linker_model, entity_linker_args, elqId2wikiId)
        # A failed link comes back as name "" with id "null", so the id
        # must be tested; checking only the name let unresolved links in.
        linked = (topic_entity_id_q != "null") and (topic_entity_name_q != "null")
        if linked and (topic_entity_name_q not in topic_entity_name) and (topic_entity_id_q not in topic_entity_id):
            topic_entity_name.append(topic_entity_name_q)
            topic_entity_id.append(topic_entity_id_q)

    # prepare candidates
    candidates = []
    candidate2isTail = {}
    for name, eid in zip(topic_entity_name, topic_entity_id):
        candidates_1topic, candidate2isTail_1topic = find_candidates(question, name, eid, args.use_relation_retrieval)
        candidates.extend(candidates_1topic)
        candidate2isTail.update(candidate2isTail_1topic)
    candidates = list(set(candidates))

    # retrieve
    dense_result = dense_retrieve(candidates, question, min(args.retrieve_width, len(candidates)))
    # De-duplicate in retrieval order; a set here would make the order of
    # equally scored references differ between runs.
    result_1 = list(dict.fromkeys(dense_result))
    # rerank
    result = rerank(result_1, question)
    top_results = result[:min(args.reference_num, len(result))]
    if not args.use_relation_retrieval:
        return [topic_entity_name, topic_entity_id], top_results

    # add tail / head entity to the retrieved relations
    references = []
    for ref in top_results:
        elements = [e[2:-2] for e in ref.split(" [SEP] ")]
        if len(elements) not in (2, 3): # only 1hop and 1hop & 2hop relation paths
            continue
        is_tail = candidate2isTail[ref]
        if is_tail:
            # topic entity is the first field: follow the path outwards
            eid = topic_entity_id[topic_entity_name.index(elements[0])]
            if len(elements) == 2:
                query = sparql_txt_find_tail.format(eid, elements[1])
            else:
                query = sparql_txt_find_tail_2hop.format(eid, elements[1], elements[2])
        else:
            # topic entity is the last field: follow the path backwards
            eid = topic_entity_id[topic_entity_name.index(elements[-1])]
            if not (eid.endswith('@en') or eid.endswith('^^xsd:dateTime')):
                eid = "ns:" + eid
            if len(elements) == 2:
                query = sparql_txt_find_head.format(eid, elements[0])
            else:
                query = sparql_txt_find_head_2hop.format(eid, elements[1], elements[0])
        entities = process_entity(sparql_search(query))
        entities = prune_entity(entities)
        entity_list = "[ {} ]".format(" ".join(["< " + ent[1] + " >" for ent in entities]))
        path = " [SEP] ".join(["< {} >".format(e) for e in elements])
        if is_tail:
            references.append(path + " [SEP] " + entity_list)
        else:
            references.append(entity_list + " [SEP] " + path)

    return [topic_entity_name, topic_entity_id], references


def find_nums(string):
    """Collect the distinct ``[#N]`` placeholders that occur in ``string``.

    Returns a list of ``(placeholder_text, number)`` tuples, e.g.
    ``("[#12]", 12)``, ordered by first appearance. A pair that repeats
    is reported only once.
    """
    seen_pairs = []
    for hit in re.finditer(r'\[#(\d+)\]', string):  # [#123]
        pair = (hit.group(0), int(hit.group(1)))
        if pair in seen_pairs:
            continue
        seen_pairs.append(pair)
    return seen_pairs


def integrate(question, previous_subQAs, llm_api_key, llm_input_len, llm_output_len):
    """Combine the answered subquestions into an answer for ``question``.

    Formats ``prompt_integrate`` with the question and the sub-QA text
    (trailing newlines stripped), sends it through ``reason_with_llm`` and
    pulls every ``{...}`` span out of the reply. When no span is found, or the
    only span is "none"/"unknown" (case-insensitive), the reply is sent back
    through ``prompt_answer_extract`` and the spans of that second reply are
    used instead.

    Returns:
        ``(response, answer_entity, llm_input_len, llm_output_len)`` where
        ``response`` is the first LLM reply, ``answer_entity`` is the
        extracted spans joined with "; " (possibly empty), and the two
        counters are the values handed back by ``reason_with_llm``.
    """
    brace_spans = re.compile(r'\{([^}]*)\}')

    first_prompt = prompt_integrate.format(question, previous_subQAs.rstrip("\n"))
    response, llm_input_len, llm_output_len = reason_with_llm(first_prompt, llm_api_key, llm_input_len, llm_output_len)

    # extract the {answer_entity} in llm's response
    extracted = brace_spans.findall(response)
    only_placeholder = len(extracted) == 1 and extracted[0].lower() in ("none", "unknown")
    if not extracted or only_placeholder:
        # nothing usable in the first reply: ask for an explicit extraction
        retry_prompt = prompt_answer_extract.format(question, response)
        retry_response, llm_input_len, llm_output_len = reason_with_llm(retry_prompt, llm_api_key, llm_input_len, llm_output_len)
        extracted = brace_spans.findall(retry_response)

    return response, "; ".join(extracted), llm_input_len, llm_output_len


def handle_question(question, previous_subQAs, depth, topic_entity_name, topic_entity_id, entity_linker_model, entity_linker_args, elqId2wikiId, response_chain, answer_chain, topic_entity_chain, reference_triplets_chain, args, llm_input_len, llm_output_len):
    """Answer ``question`` directly or by recursively decomposing it.

    ``determine`` is asked first. If it flags the question as one-step, or
    ``depth`` equals ``args.max_decompose_depth``, triplets are retrieved with
    ``retrieve_triplets`` and ``generate_answer`` produces the answer.
    Otherwise ``decompose`` splits the question into subquestions, which are
    solved in order through recursive calls (``depth + 1``) and merged with
    ``integrate``. An empty decomposition falls back to the direct
    retrieve-and-answer path.

    Each step appends one entry to every ``*_chain`` list; the lists are
    mutated in place and also returned.

    Args:
        question: Question text to answer.
        previous_subQAs: Newline-separated "subquestion -- answer" context,
            forwarded to ``generate_answer`` on the direct path. When the
            question is decomposed, it is rebuilt from scratch from this
            question's own subquestions.
        depth: Current recursion depth.
        topic_entity_name: Forwarded unchanged to ``retrieve_triplets``.
        topic_entity_id: Forwarded unchanged to ``retrieve_triplets``.
        entity_linker_model: Forwarded unchanged to ``retrieve_triplets``.
        entity_linker_args: Forwarded unchanged to ``retrieve_triplets``.
        elqId2wikiId: Forwarded unchanged to ``retrieve_triplets``.
        response_chain: List receiving ``[label_or_response, response]`` pairs.
        answer_chain: List receiving ``[label, answer_or_subquestions]`` pairs.
        topic_entity_chain: List receiving the topic entities of each step
            (``[]`` for steps without retrieval).
        reference_triplets_chain: List receiving the reference triplets of
            each step (``[]`` for steps without retrieval).
        args: Parsed options; ``llm_api_key``, ``max_decompose_depth`` and
            ``use_relation_retrieval`` are read here.
        llm_input_len: Running counter threaded through every LLM helper call.
        llm_output_len: Running counter threaded through every LLM helper call.

    Returns:
        ``(answer_entity, response_chain, answer_chain, topic_entity_chain,
        reference_triplets_chain, llm_input_len, llm_output_len)``.
    """
    # determine whether the question can be answered through one-step reasoning
    response_phase1, flag_onestep, llm_input_len, llm_output_len = determine(question, args.llm_api_key, llm_input_len, llm_output_len)
    answer_entity = ""
    if flag_onestep or depth == args.max_decompose_depth:
        topic_entity, reference_triplets = retrieve_triplets(question, topic_entity_name, topic_entity_id, entity_linker_model, entity_linker_args, elqId2wikiId, args, depth)
        response_phase2, answer_entity, llm_input_len, llm_output_len = generate_answer(question, previous_subQAs, reference_triplets, args.llm_api_key, llm_input_len, llm_output_len, args.use_relation_retrieval)
        response_chain.append([response_phase1, response_phase2])
        answer_chain.append([str(flag_onestep), answer_entity])
        topic_entity_chain.append(topic_entity)
        reference_triplets_chain.append(reference_triplets)
    else:
        response_phase3, subquestions, llm_input_len, llm_output_len = decompose(question, args.llm_api_key, llm_input_len, llm_output_len)
        response_chain.append([response_phase1, response_phase3])
        answer_chain.append([str(flag_onestep), ';'.join(subquestions)])
        topic_entity_chain.append([])
        reference_triplets_chain.append([])

        if len(subquestions) == 0:  # decompose error
            # no subquestions came back: answer the original question directly
            topic_entity, reference_triplets = retrieve_triplets(question, topic_entity_name, topic_entity_id, entity_linker_model, entity_linker_args, elqId2wikiId, args, depth)
            response_phase2, answer_entity, llm_input_len, llm_output_len = generate_answer(question, previous_subQAs, reference_triplets, args.llm_api_key, llm_input_len, llm_output_len, args.use_relation_retrieval)
            response_chain.append(["decompose error", response_phase2])
            answer_chain.append(["decompose error", answer_entity])
            topic_entity_chain.append(topic_entity)
            reference_triplets_chain.append(reference_triplets)
        else:
            # solve subquestions
            previous_subQAs = ""
            subanswers = []
            for i, subquestion in enumerate(subquestions):
                # replace the [#xxx] in subquestion with its answer entity
                for num_str, num in find_nums(subquestion):
                    # Placeholders are 1-based and may only point at an
                    # already-answered subquestion. Without the lower bound,
                    # "[#0]" would index subanswers[-1]: an IndexError on the
                    # first subquestion, a wrong substitution afterwards.
                    if 1 <= num <= i:
                        subquestion = subquestion.replace(num_str, subanswers[num-1])
                # answer the subquestion
                answer_entity, response_chain, answer_chain, topic_entity_chain, reference_triplets_chain, llm_input_len, llm_output_len = handle_question(subquestion, previous_subQAs, depth+1, topic_entity_name, topic_entity_id, entity_linker_model, entity_linker_args, elqId2wikiId, response_chain, answer_chain, topic_entity_chain, reference_triplets_chain, args, llm_input_len, llm_output_len)
                subanswers.append(answer_entity)
                if len(answer_entity) == 0:
                    # no extracted entity: fall back to the raw response of the last step
                    previous_subQAs += subquestion + " -- " + response_chain[-1][1] + "\n"
                else:
                    previous_subQAs += subquestion + " -- " + answer_entity + "\n"
            # integrate the subQAs to solve the question
            response_phase2, answer_entity, llm_input_len, llm_output_len = integrate(question, previous_subQAs, args.llm_api_key, llm_input_len, llm_output_len)
            response_chain.append(["integrate", response_phase2])
            # NOTE: the "integrage" spelling is kept unchanged because this
            # label ends up in the answer chain that is written to the output.
            answer_chain.append(["integrage", answer_entity])
            topic_entity_chain.append([])
            reference_triplets_chain.append([])

    return answer_entity, response_chain, answer_chain, topic_entity_chain, reference_triplets_chain, llm_input_len, llm_output_len


def write_to_file(args, question_id, question, answer, topic_entity_chain, reference_triplets_chain, answer_chain, response_chain):
    """Append the result of one question as a JSON line to the output file.

    The file lives under ``./output/`` and its name depends on
    ``args.dataset`` ("webqsp" selects the WebQSP file, anything else the CWQ
    file), with "-rel" added when ``args.use_relation_retrieval`` is set and
    "-goldenTopic" added when ``args.use_golden_topic`` is set. The file is
    opened in append mode, so earlier records are kept.
    """
    # build the optional file-name suffix
    suffix = "-rel" if args.use_relation_retrieval else ""
    if args.use_golden_topic:
        suffix += "-goldenTopic"

    # pick the file-name template for the dataset
    if args.dataset == "webqsp":
        path_template = "./output/gpt3.5-WebQSP_test{}.jsonl"
    else:
        path_template = "./output/gpt3.5-CWQ_test{}.jsonl"
    target_path = path_template.format(suffix)

    record = {
        "Id": question_id,
        "Question": question,
        "Answer": answer,
        "TopicEntityChain": topic_entity_chain,
        "ReferenceTripletsChain": reference_triplets_chain,
        "AnswerChain": answer_chain,
        "ResponseChain": response_chain
    }
    serialized = json.dumps(record)
    with open(target_path, "a", encoding="utf-8") as sink:
        sink.write(serialized + "\n")


# Held by process_batch around its write_to_file call.
file_write_lock = threading.Lock()
# Held by process_batch while it puts its counters on the shared result queue.
queue_lock = threading.Lock()

def process_batch(batch_index, questions, question_ids, topic_entity_names, topic_entity_ids, entity_linker_model, entity_linker_args, elqId2wikiId, args, llm_io_lens):
    """Answer every question of one batch and record the results.

    Each question is solved with ``handle_question`` (starting at depth 0 with
    empty chains and no previous sub-QAs) and its result is written with
    ``write_to_file`` while ``file_write_lock`` is held. After the last
    question, ``[llm_input_len, llm_output_len]`` accumulated over the batch
    is put on ``llm_io_lens`` while ``queue_lock`` is held.

    ``batch_index`` is accepted but not used in the body.
    """
    tokens_in, tokens_out = 0, 0

    for pos in tqdm(range(len(questions))):
        qid = question_ids[pos]
        text = questions[pos]
        entity_names = topic_entity_names[pos]
        entity_ids = topic_entity_ids[pos]

        # fresh, empty trace lists for this question
        outcome = handle_question(
            text, "", 0, entity_names, entity_ids,
            entity_linker_model, entity_linker_args, elqId2wikiId,
            [], [], [], [],
            args, tokens_in, tokens_out,
        )
        answer, responses, answers, topics, triplets, tokens_in, tokens_out = outcome

        # write to file
        with file_write_lock:
            write_to_file(args, qid, text, answer, topics, triplets, answers, responses)

    with queue_lock:
        llm_io_lens.put([tokens_in, tokens_out])


if __name__ == '__main__':
    # ---- command-line options ----
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--dataset", type=str, default="webqsp", help="webqsp / cwq")
    arg_parser.add_argument("--use_golden_topic", action="store_true", help="use golden topic entities or not")
    arg_parser.add_argument("--use_relation_retrieval", action="store_true", help="retrieval strategy based on relations or fact triples/quadruples")
    arg_parser.add_argument("--llm_api_key", type=str)
    arg_parser.add_argument("--retrieve_width", type=int, default=10, help="number of facts retained by the retriever and as candidates for reranker")
    arg_parser.add_argument("--reference_num", type=int, default=3, help="number of references for simple question answering")
    arg_parser.add_argument("--max_decompose_depth", type=int, default=1)
    arg_parser.add_argument("--elq_threshold", type=float, default=-4.5, help="threshold for entity linker ELQ")
    arg_parser.add_argument("--batch_size", type=int, default=200, help="batch size for multithreading")
    args = arg_parser.parse_args()

    # ---- load the dataset ----
    print("Preparing data...")
    question_ids, questions, topic_entity_names, topic_entity_ids = prepare_data(args.dataset, args.use_golden_topic)

    # ---- load the entity linker and report how long it took ----
    print("Preparing entity linker...")
    elq_start_time = time.perf_counter()
    entity_linker_model, entity_linker_args = prepare_entity_linker(args.elq_threshold)
    elqId2wikiId = entity_linker_model[6]
    print("\telq_load_time = {}s".format(time.perf_counter() - elq_start_time))

    print("Start reasoning...")

    # ---- one worker thread per batch of questions ----
    threads = []
    llm_io_lens = queue.Queue()
    total_questions = len(questions)
    batch_count = math.ceil(total_questions / args.batch_size)

    for index in tqdm(range(batch_count)):
        window = slice(index * args.batch_size, min((index + 1) * args.batch_size, total_questions))
        worker_args = (
            index,
            questions[window],
            question_ids[window],
            topic_entity_names[window],
            topic_entity_ids[window],
            entity_linker_model,
            entity_linker_args,
            elqId2wikiId,
            args,
            llm_io_lens,
        )
        worker = threading.Thread(target=process_batch, args=worker_args)
        threads.append(worker)
        worker.start()

    # wait until every batch has finished
    for worker in threads:
        worker.join()

    # ---- sum the per-batch counters reported through the queue ----
    llm_input_len = 0
    llm_output_len = 0
    while not llm_io_lens.empty():
        batch_in, batch_out = llm_io_lens.get()
        llm_input_len += batch_in
        llm_output_len += batch_out

    print("llm_input_tokens_num = {}, llm_output_tokens_num = {}".format(llm_input_len, llm_output_len))
    print("Finish reasoning")