import argparse
import json
from tqdm import tqdm

def prepare_results(args):
    """Read the model-output JSONL file selected by ``args`` and return its lines.

    The path is ``./output/gpt3.5-<Dataset>_test`` followed by optional suffixes:
    ``-rel`` when ``args.use_relation_retrieval``, ``-goldenTopic`` when
    ``args.use_golden_topic``, and ``-IO`` / ``-CoT`` / ``-SC`` for the matching
    ``args.special_setting``. The name always ends in ``.jsonl``.

    Returns:
        The raw lines of the file (newlines included). If ``args.dataset`` is
        not ``"webqsp"`` or ``"cwq"``, prints an error and returns ``[]``.
        Errors from ``open()`` (e.g. a missing file) propagate.
    """
    dataset_labels = {"webqsp": "WebQSP", "cwq": "CWQ"}
    setting_labels = {"io": "IO", "cot": "CoT", "sc": "SC"}

    label = dataset_labels.get(args.dataset)
    if label is None:
        print("input error: wrong dataset name")
        return []

    suffixes = []
    if args.use_relation_retrieval:
        suffixes.append("-rel")
    if args.use_golden_topic:
        suffixes.append("-goldenTopic")
    if args.special_setting in ("io", "cot", "sc"):
        suffixes.append("-" + setting_labels[args.special_setting])

    result_path = "./output/gpt3.5-{}_test".format(label) + "".join(suffixes) + ".jsonl"
    with open(result_path, "r") as f:
        return list(f)


def _collect_answers(answer_records):
    """Return the de-duplicated gold answer strings from a list of answer records.

    An ``"Entity"`` record contributes its ``EntityName``, or its
    ``AnswerArgument`` when the name is the empty string. A ``"Value"`` record
    contributes its ``AnswerArgument``. Records of any other ``AnswerType`` are
    ignored. If nothing is collected, ``[""]`` is returned so every question
    has at least one gold entry. The order of the returned list is unspecified
    (it comes from a set).
    """
    answers = []
    for answer in answer_records:
        if answer["AnswerType"] == "Entity":
            name = answer["EntityName"]
            answers.append(answer["AnswerArgument"] if name == "" else name)
        elif answer["AnswerType"] == "Value":
            answers.append(answer["AnswerArgument"])
    if not answers:
        answers.append("")
    return list(set(answers))


def prepare_answers(dataset_name):
    """Load gold answers (and, for CWQ, question types) keyed by question id.

    Args:
        dataset_name: ``"webqsp"`` or ``"cwq"``.

    Returns:
        A tuple ``(goldAnswers, qTypes)``.
        - ``goldAnswers`` maps question id to a list of gold answer strings.
          For WebQSP, answers from all ``Parses`` are merged.
        - ``qTypes`` maps question id to ``compositionality_type``. It is only
          populated for CWQ and is empty for WebQSP.
        For an unknown dataset name, an error is printed and ``({}, {})`` is
        returned.
    """
    if dataset_name == "webqsp":
        answer_path = "./data/webqsp/WebQSP.test.json"
    elif dataset_name == "cwq":
        answer_path = "./data/cwq/CWQ_test_preprocessed.json"
    else:
        print("input error: wrong dataset name")
        # Always return a pair so ``goldAnswers, qTypes = prepare_answers(...)``
        # unpacks; a bare ``{}`` here raised ValueError at the call site.
        return {}, {}

    goldAnswers = {}
    qTypes = {}
    # JSON files are UTF-8; don't depend on the platform's default encoding.
    with open(answer_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if dataset_name == "webqsp":
        for example in data["Questions"]:
            questionId = example["QuestionId"]
            records = [answer for parse in example["Parses"] for answer in parse["Answers"]]
            goldAnswers[questionId] = _collect_answers(records)
    else:
        for example in data:
            questionId = example["ID"]
            questionType = example["compositionality_type"]
            goldAnswers[questionId] = _collect_answers(example["answer"])
            qTypes[questionId] = questionType
    return goldAnswers, qTypes

def prepare_answers_topic(dataset_name):
    """Load gold topic-entity ids keyed by question id from a JSONL file.

    Each line of the dataset's ``simple-*_test.jsonl`` file is parsed as JSON.
    Its ``"TopicEntityId"`` value is stored under its ``"Id"``; a later line
    with the same id overwrites an earlier one.

    Returns:
        ``{Id: TopicEntityId}``, or ``{}`` (after printing an error) when
        ``dataset_name`` is neither ``"webqsp"`` nor ``"cwq"``.
    """
    if dataset_name == "webqsp":
        topic_path = "./data/webqsp/simple-WebQSP_test.jsonl"
    elif dataset_name == "cwq":
        topic_path = "./data/cwq/simple-CWQ_test.jsonl"
    else:
        print("input error: wrong dataset name")
        return {}

    with open(topic_path, "r") as f:
        # map() parses lazily, so records are decoded and inserted one line at a time.
        return {record["Id"]: record["TopicEntityId"] for record in map(json.loads, f)}

def match(myAnswer, myResponse, goldAnswerList):
    """Return True if the prediction loosely matches any gold answer.

    All comparisons are case-insensitive.
    - The extracted answer is split on ``";"`` and, separately, on ``","``.
      A stripped piece matches a gold answer if it equals the gold answer,
      contains it, or is non-empty and contained in it.
    - Only when the extracted answer is empty is the raw response checked.
      It matches if it equals, contains, or is contained in a gold answer.

    NOTE(review): the leniency below is part of the metric and is kept as-is:
    - an empty-string gold answer is a substring of every piece, so it matches
      any prediction;
    - an empty answer together with an empty response matches any gold answer.
    """
    golds = [gold.lower() for gold in goldAnswerList]
    answer = myAnswer.lower()

    def _piece_hits(piece):
        return any(
            piece == gold or gold in piece or (piece != "" and piece in gold)
            for gold in golds
        )

    # Semicolon-separated pieces are tried first, then comma-separated ones.
    for separator in (";", ","):
        if any(_piece_hits(raw.strip()) for raw in answer.split(separator)):
            return True

    response = myResponse.lower()
    if answer != "":
        return False
    return any(
        response == gold or gold in response or response in gold
        for gold in golds
    )


def match_topic(myTopicList, goldTopicList):
    """Return True if any predicted topic entity equals (``==``) any gold topic entity.

    Returns False when either collection is empty.
    """
    return any(
        predicted == gold
        for predicted in myTopicList
        for gold in goldTopicList
    )


if __name__ == '__main__':
    # Command-line entry point: score one results file and write a summary to
    # ./output/eval_result-gpt3.5-<dataset>[-rel][-goldenTopic][-<setting>].txt
    parser = argparse.ArgumentParser()
    # `choices` rejects unsupported values with a usage error up front, instead
    # of failing later in the loaders or silently skipping statistics.
    parser.add_argument("--dataset", type=str, default="webqsp", choices=["webqsp", "cwq"], help="webqsp / cwq")
    parser.add_argument("--use_golden_topic", action="store_true", help="use golden topic entities or not")
    parser.add_argument("--use_relation_retrieval", action="store_true", help="retrieval strategy based on relations or fact triples/quadruples")
    parser.add_argument("--special_setting", type=str, default="null", choices=["null", "io", "cot", "sc"], help="null / io / cot / sc")
    args = parser.parse_args()

    # io / cot / sc results carry a single "Response"; the default ("null")
    # setting carries a "ResponseChain" plus per-step topic/reference data.
    singleResponse = args.special_setting in ("io", "cot", "sc")

    myResults = prepare_results(args)
    goldAnswers, qTypes = prepare_answers(args.dataset)
    goldTopics = prepare_answers_topic(args.dataset)
    totalNum = len(myResults)
    if totalNum == 0:
        # Every rate below would otherwise divide by zero.
        raise SystemExit("No results to evaluate.")

    # CWQ additionally reports accuracy per compositionality type.
    if args.dataset == "cwq":
        allType = set(qTypes.values())
        totalNumType = {t: 0 for t in allType}
        correctNumType = {t: 0 for t in allType}

    correctNum = 0
    oneStepReasoningNum = 0
    correctTopicNum = 0
    correctRefNum = 0

    for line in tqdm(myResults):
        example = json.loads(line)
        questionId = example["Id"]
        gold = goldAnswers[questionId]
        if args.dataset == "cwq":
            totalNumType[qTypes[questionId]] += 1

        # overall accuracy
        if singleResponse:
            myResponse = example["Response"]
        else:
            # the final response is element 1 of the last ResponseChain entry
            myResponse = example["ResponseChain"][-1][1]
        if match(example["Answer"], myResponse, gold):
            correctNum += 1
            if args.dataset == "cwq":
                correctNumType[qTypes[questionId]] += 1

        if args.special_setting == "null":
            goldTopicIds = goldTopics[questionId]
            topicChain = example["TopicEntityChain"]
            # topic extracting accuracy: a first AnswerChain flag of "True" is
            # counted as one-step reasoning and only the first step's topic ids
            # are checked; otherwise ids from every non-empty step are pooled.
            if example["AnswerChain"][0][0] == "True":
                oneStepReasoningNum += 1
                if match_topic(topicChain[0][1], goldTopicIds):
                    correctTopicNum += 1
            elif match_topic([eid for step in topicChain if len(step) != 0 for eid in step[1]], goldTopicIds):
                correctTopicNum += 1

            # reference accuracy: collect head/tail entity strings from every
            # retrieved reference and check for an exact gold-answer hit.
            # NOTE(review): the slicing appears to assume strings shaped like
            # "< head > [SEP] < relation > [SEP] < tail >", where relation
            # retrieval may put a bracketed "[ < e1 > < e2 > ]" entity list on
            # the head or tail side -- confirm against the result generator.
            refAnswers = []
            for ref in example["ReferenceTripletsChain"]:
                for triple in ref:
                    if args.use_relation_retrieval:
                        if triple.startswith("["):
                            refAnswers.extend(triple[2:].split(" ]")[0][2:-2].split(" > < "))
                            refAnswers.append(triple.split(" > [SEP] < ")[-1][:-2])
                        else:
                            refAnswers.append(triple.split(" > [SEP] < ")[0][2:])
                            refAnswers.extend(triple[:-2].split("[ ")[1][2:-2].split(" > < "))
                    else:
                        refAnswers.append(triple.split(" > [SEP] < ")[0][2:])
                        refAnswers.append(triple.split(" > [SEP] < ")[-1][:-2])
            goldSet = set(gold)
            if any(refAns in goldSet for refAns in refAnswers):
                correctRefNum += 1

    eval_result = "Result on dataset: {}\n".format(args.dataset)
    eval_result += "\tTotal number of test examples: {}\n".format(totalNum)
    eval_result += "\tAccuracy: {:.3f} ({}/{})".format(correctNum / totalNum, correctNum, totalNum)

    if args.special_setting == "null":
        eval_result += "\n\tOne step reasoning rate: {:.3f} ({}/{})\n".format(oneStepReasoningNum / totalNum, oneStepReasoningNum, totalNum)
        eval_result += "\tTopic extracting accuracy: {:.3f} ({}/{})\n".format(correctTopicNum / totalNum, correctTopicNum, totalNum)
        eval_result += "\tReference accuracy: {:.3f} ({}/{})".format(correctRefNum / totalNum, correctRefNum, totalNum)

    if args.dataset == "cwq":
        for t in sorted(allType):
            # A type present in the gold file may be absent from a partial
            # results file; report 0.000 instead of dividing by zero.
            typeRate = correctNumType[t] / totalNumType[t] if totalNumType[t] else 0.0
            eval_result += "\n{}: {:.3f}({}/{})".format(t, typeRate, correctNumType[t], totalNumType[t])

    print(eval_result)

    eval_result_path = "./output/eval_result-gpt3.5-{}".format(args.dataset)
    if args.use_relation_retrieval:
        eval_result_path += "-rel"
    if args.use_golden_topic:
        eval_result_path += "-goldenTopic"
    if singleResponse:
        eval_result_path += "-" + args.special_setting
    eval_result_path += ".txt"
    with open(eval_result_path, "w", encoding="utf-8") as f:
        f.write(eval_result + "\n")