import argparse
import os
import re
import json
import random
import evaluate
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
from eval.clutrr.examplars import EXAMPLARS 
import openai


# `exact_match` metric object from the `evaluate` library, loaded at import time.
# NOTE(review): no active code in the visible part of this file calls it.
exact_match = evaluate.load("exact_match")
# Regex for the sentence that states the model's conclusion: a span that starts
# with "Hence, ", "Therefore ", "Based on " or "we can conclude that " and runs
# lazily up to (and including) the next period. The single outer group makes
# re.findall return the whole matched span as a string.
# NOTE(review): the "Therefore" alternative requires a space right after the
# word, so "Therefore, ..." (with a comma) is not matched — confirm intended.
pattern = r"(Hence, .*?\.|Therefore .*?\.|Based on .*?\.|we can conclude that .*?\.)"


def _extract_prediction(output):
    """Reduce a raw model completion to the sentence stating its conclusion.

    Only the LAST blank-line-separated paragraph of ``output`` is considered.
    Within it, the last substring matched by the module-level ``pattern``
    (case-insensitively) is returned; if nothing matches, the whole paragraph
    is returned instead.
    """
    final_paragraph = output.strip().split("\n\n")[-1].strip()
    conclusions = re.findall(pattern, final_paragraph, re.IGNORECASE)
    return conclusions[-1] if conclusions else final_paragraph


def _is_correct(prediction, target):
    """Return True if ``prediction`` mentions an accepted form of ``target``.

    ``target`` may list several relations joined by " and "; any one of them
    counts. Matching is a case-insensitive substring test. Each relation is
    searched for with a leading space (so that e.g. "father" does not match
    inside "grandfather"), and the "-in-law" variant of a relation is accepted
    interchangeably with the plain one.
    """
    relations = {" " + rel.strip() for rel in target.lower().strip().split(" and ")}
    accepted = set(relations)
    for rel in relations:
        if "-in-law" not in rel:
            # Plain relation: also accept its in-law form (followed by a space).
            accepted.add(rel.strip() + "-in-law ")
        else:
            # In-law relation: also accept the plain form.
            accepted.add(rel.replace("-in-law", ""))

    prediction = prediction.lower()
    return any(candidate in prediction for candidate in accepted)


def main(args):
    """Run the evaluation for one task split and write predictions and metrics.

    Reads ``<args.data_dir>/<args.task>.jsonl`` (one JSON object per line with
    "story", "question" and "answer" keys), optionally subsamples it to
    ``args.max_num_examples`` examples, builds one prompt per example, generates
    completions with either a local model (``args.model_name_or_path``) or the
    OpenAI API (``args.openai_engine``), scores them and writes
    ``predictions.jsonl`` and ``metrics.json`` under
    ``<args.save_dir>/<args.task>/``.

    The reported score is the fraction of examples whose extracted prediction
    contains an accepted form of the gold relation (see ``_is_correct``). It is
    stored under the "exact_match" key for compatibility with existing result
    files, although it is a containment check rather than a strict exact match.

    Args:
        args: Parsed command-line namespace. Attributes read: ``task``,
            ``data_dir``, ``save_dir``, ``max_num_examples``, ``n_shot``,
            ``no_example``, ``use_chat_format``, ``model_name_or_path``,
            ``tokenizer_name_or_path``, ``load_in_8bit``, ``gptq``,
            ``eval_batch_size`` and ``openai_engine``.
    """
    # random.seed(42)
    task = args.task
    output_dir = os.path.join(args.save_dir, task)

    print("Loading data...")
    test_data = []
    # e.g. test.jsonl
    with open(os.path.join(args.data_dir, "{}.jsonl".format(task)), encoding="utf-8") as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            example = json.loads(line)
            test_data.append({
                "story": example["story"],
                "question": example["question"],
                "answer": example["answer"].strip(),
            })

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        test_data = random.sample(test_data, args.max_num_examples)

    # exist_ok=True already covers the "directory is there" case.
    os.makedirs(output_dir, exist_ok=True)

    prefix = "Follow the instructions and think step by step through inductive reasoning to answer the question. DO NOT directly provide the answer without giving an explanation for the answer. Question: "
    # prefix = "You are an executor navigating through the constructed graph. Follow the instructions and think step by step to answer the question: "   
    # prefix = "Please follow the instructions and think step by step, recording each stage of your thinking to answer the question: "   
    # # prefix = "Based on the story, through inductive reasoning think step by step to answer the question: "   

    # Few-shot demonstrations. The exemplar pool is sampled into a local so the
    # module-level EXAMPLARS list is never rebound by a call to main().
    if args.n_shot:
        examplars = EXAMPLARS
        if len(examplars) > args.n_shot:
            examplars = random.sample(examplars, args.n_shot)
        # --no_example selects the short answer instead of the worked explanation.
        answer_key = "short_answer" if args.no_example else "answer"
        demonstrations = [
            "Story: " + example["story"] + "\n" + prefix + example["question"] + "\n" + "Answer: " + example[answer_key]
            for example in examplars
        ]
        prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
    else:
        prompt_prefix = "Answer the following question.\n\n"
    # NOTE(review): prompt_prefix is built but NOT inserted into the prompts
    # below, so --n_shot / --no_example currently have no effect on what the
    # model sees. To use the demonstrations, place prompt_prefix between
    # `instruct` and "Story: " when building each prompt.

    # v6
    instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will pose a question. This question can be transformed into a triple (s, ?, o), where your primary task is to determine the missing relation (\'?\') that links the subject entity (\'s\') to the object entity (\'o\'). To begin, focus on the subject entity in this triple and choose the most relevant facts to expand from it. Step by step, progress towards the object entity, ensuring that each selected fact contributes to creating a link between the subject and object entities. Finally, utilize the established connection between the subject and object entities to answer the question.\n"
    # v5
    # instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will pose a question. This question can be transformed into a triple (s, ?, o), where your primary task is to determine the missing relation (‘?’) that links the subject entity (‘s’) to the object entity (‘o’). Starting from the subject entity of this triple and select the most pertinent facts related to it. Gradually progress towards the object entity step by step, linking each step with relevant information. Ensure that each fact chosen helps to form a clear connection between the subject and object entities. Finally, utilize the established connection between the subject and object entities to answer the question.\n"    
    # v4
    # instruct = "Given a story, please think step by step to answer the question through inductive reasoning. Start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Please convert the question in the form of a triplet. Starting from the subject entity in the triplet, please select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"     
    # instruct = "Given a story, start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Please convert the question in the form of a triplet. Starting from the subject entity in the triplet, please select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"     
    # v3
    # instruct = "Given a story, please think step by step to answer the question through inductive reasoning. I will post a question and I want you to start from the subject entity in the question, select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"     
    # v2
    # instruct = "I want you to act as an intelligent agent. Given a story, construct a KG by extract facts from the story sentence by sentence. I will ask you a question. Please convert the question in form of triple. Starting from the subject entity in the triple, select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"
    # v1
    # instruct = "I want you to act as a reinforcement agent. Given a story, construct a KG by extract facts from the story sentence by sentence. I will ask you a question. Please convert the question in form of triple. Starting from the subject entity in the triple, select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"
    # instruct = "I'd like you to function as a reinforcement agent. When provided with a story, I'll pose a question to you. First, pinpoint every entity within the question. Beginning with the main entity,  choose the most relevant sentences related to the main entity. Then, choose another entity in the same sentences and infer the relation between main entity and the selected entity. Conduct this progress recursively until reach the end entity.\n"
    # instruct = "I want you to serve as a smart planner. When I give you a story, I will ask you a question. Start by identifying all entities in the question. Begin with the main entity and select the most relevant sentences about it. Then, pick another entity from the same sentences and deduce the relationship between the main entity and the chosen entity. Repeat this process recursively with the selected entity until you are able to answer the question. Summarize the ultimate answer concisely.\n"

    # One prompt per example: instruction, story, then the prefixed question.
    # The chat format only wraps the same body in user/assistant role tags.
    prompts = []
    for example in test_data:
        body = instruct + "Story: " + example["story"] + "\n" + prefix + example["question"].strip()
        if args.use_chat_format:
            prompt = "<|user|>\n" + body + "\n<|assistant|>\n" + "Answer:"
        else:
            prompt = body + "\nAnswer:"
        prompts.append(prompt)

    if args.model_name_or_path:
        print("Loading model and tokenizer...")
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            load_in_half=True,
            gptq_model=args.gptq,
        )
        # Take the last token because the tokenizer may add space tokens at the start.
        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
        # NOTE(review): the newline token is passed as the only stop sequence —
        # confirm this is intended for step-by-step, multi-line answers.
        outputs = generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            max_new_tokens=512,
            batch_size=args.eval_batch_size,
            stop_id_sequences=[[new_line_token]],
        )
    else:
        # The prompt text doubles as the instance id.
        instances = [{"id": prompt, "prompt": prompt} for prompt in prompts]
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=instances,
            batch_size=args.eval_batch_size or 10,
            output_path=os.path.join(output_dir, "openai_results.jsonl"),
            temperature=0,
        )
        outputs = [result["output"] for result in results]

    predictions = [_extract_prediction(output) for output in outputs]

    print("Calculating accuracy...")
    num_correct = sum(
        _is_correct(prediction, example["answer"])
        for prediction, example in zip(predictions, test_data)
    )
    # Guard against an empty evaluation set instead of dividing by zero.
    em_score = num_correct / len(predictions) if predictions else 0.0
    print(f"Exact match : {em_score}")

    records = [{
        "story": example["story"],
        "question": example["question"],
        "answer": example["answer"],
        "model_output": output,
        "prediction": pred,
    } for example, output, pred in zip(test_data, outputs, predictions)]

    with open(os.path.join(output_dir, "predictions.jsonl"), "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump({
            "exact_match": em_score
        }, fout, indent=4)


if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/eval/clutrr/story")
    parser.add_argument("--max_num_examples", type=int, default=200, help="maximum number of examples to evaluate.")
    parser.add_argument("--save_dir", type=str, default="results/clutrr/story/ours4_v6")
    # parser.add_argument("--save_dir", type=str, default="results/clutrr/story/ours3_1shot_gpt4_2")
    parser.add_argument("--task", type=str, default="1.3")
    parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
    parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
    parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.")
    parser.add_argument("--no_example", action="store_true", help="If given, we're evaluating a model without chain-of-thought.")
    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
    parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
    parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
    parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")
    args = parser.parse_args()

    # Exactly one of model_name_or_path / openai_engine must be given.
    # parser.error is used instead of `assert` so the check still runs under
    # `python -O` and the user gets a usage message rather than a traceback.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Either model_name_or_path or openai_engine should be specified.")
    main(args)

