import argparse
import os
import re
import json
import tqdm
import glob
import torch
import random
import evaluate
from eval.utils import load_hf_lm_and_tokenizer, generate_completions, query_openai_chat_model

# Loads the "exact_match" metric from the `evaluate` library when this module
# is imported.
# NOTE(review): no active code visible in this file references `exact_match`
# (scoring in main() uses a substring check instead) - confirm whether this
# import-time load is still needed.
exact_match = evaluate.load("exact_match")

def all_words_in_string(word_list, target_string):
    """Return True when every entry of ``word_list`` occurs in ``target_string``.

    Matching is plain substring containment (``word in target_string``), so it
    is case-sensitive and does not respect word boundaries. An empty
    ``word_list`` yields True.
    """
    for candidate in word_list:
        if candidate not in target_string:
            # One missing entry is enough to fail the whole check.
            return False
    return True

def separate_context_question(sentence):
    """Split ``sentence`` into ``(context, question)`` at its last period.

    The context is everything before the last ``.`` with the period put back;
    the question is the whitespace-stripped remainder after it.

    When ``sentence`` contains no period at all, the context is the whole
    sentence with a ``.`` appended and the question is ``None``.
    """
    head, separator, tail = sentence.rpartition(".")

    if not separator:
        # No period found: nothing follows a split point, so there is no question.
        return sentence + ".", None

    return head + ".", tail.strip()


def main(args):
    """Evaluate one task with a knowledge-graph style instruction prompt.

    Steps:
      1. Load ``<data_dir>/<task>.json`` and read its ``"examples"`` list
         (each example provides ``"input"`` and ``"target"``).
      2. Optionally subsample to ``args.max_num_examples`` examples.
      3. Wrap every example input in the instruction prompt.
      4. Generate one completion per prompt, with a local model when
         ``args.model_name_or_path`` is set and through the OpenAI chat
         engine otherwise.
      5. Score: a prediction counts as correct when the lower-cased target
         occurs as a substring of the lower-cased prediction. The resulting
         accuracy is reported under the name ``exact_match``.
      6. Write ``predictions.jsonl`` and ``metrics.json`` to
         ``<save_dir>/<task>/``.

    Args:
        args: Parsed command-line namespace. Reads ``task``, ``data_dir``,
            ``save_dir``, ``max_num_examples``, ``use_chat_format``,
            ``model_name_or_path``, ``tokenizer_name_or_path``,
            ``load_in_8bit``, ``gptq``, ``openai_engine`` and
            ``eval_batch_size``.
    """
    # NOTE: the RNG is deliberately left unseeded here, so the subsample drawn
    # below differs between runs. Uncomment to make runs reproducible.
    # random.seed(42)
    task = args.task
    output_dir = os.path.join(args.save_dir, task)

    print("Loading data...")
    with open(os.path.join(args.data_dir, "{}.json".format(task)), encoding="utf-8") as fin:
        data = json.load(fin)
    test_data = [
        {"input": example["input"], "target": example["target"].strip()}
        for example in data["examples"]
    ]

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        test_data = random.sample(test_data, args.max_num_examples)

    # exist_ok=True already covers the "directory is present" case.
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Prompt set v7. Exactly one (instruct, prefix) pair is active; the
    # commented pairs are the variants used for other tasks.
    # ------------------------------------------------------------------

    # # tracking_shuffled_objects_n_objects
    # instruct = "Instructions: First, create a knowledge graph by extracting facts from each sentence in the given input story. The graph should evolve as the story progresses. I will present several statements. Your primary task is to determine the correctness of these statements by converting them into triples (s, r, o). Assess each statement's validity against the knowledge graph as it stands at the story's conclusion.\n"
    # prefix = "Follow the instructions and think step by step to select the correct statement from the given options using the option number. DO NOT directly provide the answer without giving an explanation for the answer."

    # # reasoning_about_colored_objects, penguins_in_a_table
    # # instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will pose a question. If the question involves predicting tail entities, start with the subject and methodically follow the steps outlined in the question to identify the tail entity. In cases where the question requires counting objects, use the knowledge graph to accurately perform the count.\n"
    # instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. The graph may evolve as the story progresses. Once this is done, I will pose a question. This question will require you to identify objects that meet specific criteria. Utilize the final state of the knowledge graph, as it exists at the end of the story, to provide the answer to the question.\n"
    # prefix = "Follow the instructions and think step by step to select the correct answer from the given options using the option number. DO NOT directly provide the answer without giving an explanation for the answer."

    # # logical_deduction_n_objects
    # instruct = "Instructions: First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will provide several statements. Your primary task is to determine the correctness of these statements. To assess the validity of a statement, sort the objects in the graph and evaluate the correctness of each statement."
    # # instruct = "Instructions: First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will provide several statements. Your primary task is to determine the correctness of these statements. Convert each statement into a triple (s, r, o). To assess the validity of a statement, trace a logical path in the knowledge graph starting from the subject, moving step-by-step towards the object. Utilize this path to deduce a relationship between the subject and object, thereby determining the correctness of the triple.\n"
    # # instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will provide several statements. Your primary task is to determine the correctness of these statements against the knowledge graph. For tasks involving counting, perform the count utilizing the constructed knowledge graph.\n"
    # prefix = "Follow the instructions and think step by step to select the correct statement from the given options using the option number. DO NOT directly provide the answer without giving an explanation for the answer."

    # # object_counting
    # instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. The graph may evolve as the story progresses. Once this is done, I will pose a question. This question will require you to identify objects that meet specific criteria. Utilize the final state of the knowledge graph, as it exists at the end of the story, to provide the answer to the question.\n"
    # prefix = "Follow the instructions and think step by step to select the correct answer from the given options using the option number. DO NOT directly provide the answer without giving an explanation for the answer. Question: "

    # object_counting (active)
    instruct = "Instructions: First, create a knowledge graph by extracting facts from each sentence in the given input story. The graph may evolve as the story progresses. Once this is done, I will pose a question. This question will require you to identify objects that meet specific criteria. Utilize the final state of the knowledge graph, as it exists at the end of the story, to provide the answer to the question.\n"
    prefix = "Follow the instructions and think step by step to select the correct answer from the given options using the option number. DO NOT directly provide the answer without giving an explanation for the answer. Question: "

    # Earlier instruction / prefix wordings, kept for reference:
    # instruct = "Given an input, start with constructing a knowledge graph by extracting facts from the input sentence by sentence. Then, iteratively modify the knowledge graph based on the input to answer the question step by step.\n"
    # instruct = "Given an input, start with constructing an initial knowledge graph. Then, iteratively modify the knowledge graph based on the input to answer the question.\n"
    # instruct = "Given a story, start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Please convert the question in the form of a triplet. Starting from the subject entity in the triplet, please select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"
    # instruct = "Based on the input, start with constructing an initial knowledge graph that captures the relationships between objects. If further operations are performed on the objects, iteratively update the knowledge graph step by step with each subsequent input to answer the question.\n"
    # instruct = "Given an input, start with constructing an initial knowledge graph by extracting facts from the input sentence by sentence.  Then, modify the knowledge graph based on the input step by step to answer the question.\n"
    # instruct = "Based on the input, start with constructing an initial knowledge graph that captures the relationships between objects. If further operations are performed on the objects, iteratively update the knowledge graph step by step with each subsequent input to answer the question.\n"
    # instruct = "Based on the input, start with constructing an initial knowledge graph. If additional operations are performed on the objects, iteratively modify the knowledge graph step by step to answer the question.\n"
    # prefix = "Given the input, think step by step to answer the question using the option number.\n"
    # prefix = "Let's think step by step explicitly. Then, answer the question using the option number.\n"
    # prefix = "Let's think step by step to answer the question using the option number.\n"
    # prefix = "Think step by step to answer the question.\n"

    prompts = []
    for example in test_data:
        if args.use_chat_format:
            prompt = "<|user|>\n" + instruct + "Input: " + example["input"] + "\n" + prefix + "\n<|assistant|>\n" + "Answer:"
            # Alternative ordering (instruction after the input):
            # prompt = "<|user|>\n" + "Input: "+ example['input'] + "\n"  + instruct + "\n<|assistant|>\n" + "Answer:"
        else:
            prompt = instruct + "Input: " + example["input"] + "\n" + prefix + "\nAnswer:"
            # Alternative ordering (instruction after the input):
            # prompt = "Input: "+ example['input'] + "\n" + instruct + "\nAnswer:"
        prompts.append(prompt)

    if args.model_name_or_path:
        print("Loading model and tokenizer...")
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            load_in_half=True,
            gptq_model=args.gptq
        )
        # Take the last token because the tokenizer may add space tokens at the start.
        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
        # NOTE(review): the newline token is passed as the only stop sequence,
        # which presumably ends generation at the first line break - confirm
        # this is intended for step-by-step answers.
        outputs = generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            max_new_tokens=512,
            batch_size=args.eval_batch_size,
            stop_id_sequences=[[new_line_token]]
        )
    else:
        # The prompt text doubles as the instance id.
        instances = [{"id": prompt, "prompt": prompt} for prompt in prompts]
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=instances,
            batch_size=args.eval_batch_size if args.eval_batch_size else 10,
            output_path=os.path.join(output_dir, "openai_results.jsonl"),
        )
        outputs = [result["output"] for result in results]

    # Keep only the last blank-line-separated paragraph of each completion;
    # the substring check below is run against that paragraph.
    predictions = [output.strip().split("\n\n")[-1].strip() for output in outputs]

    print("Calculating accuracy...")
    targets = [example["target"] for example in test_data]

    num_correct = 0
    for idx, (target, prediction) in enumerate(zip(targets, predictions)):
        print("idx", idx)

        # Lenient match: the target only has to appear somewhere in the prediction.
        if target.lower().strip() in prediction.lower():
            num_correct += 1
            print("match")
        else:
            print("unmatch")
        print("target", target)
        print("prediction", prediction)

    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    em_score = num_correct / len(predictions) if predictions else 0.0
    print(f"Exact match : {em_score}")

    records = [{
        "input": example["input"],
        "target": example["target"],
        "model_output": output,
        "prediction": pred
    } for example, output, pred in zip(test_data, outputs, predictions)]

    with open(os.path.join(output_dir, "predictions.jsonl"), "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump({
            "exact_match": em_score
        }, fout, indent=4)


# Script entry point: parse the command line, validate the backend choice and
# hand the parsed namespace to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--data_dir", type=str, default="data/eval/bbh/bbh/reasoning")
    parser.add_argument("--data_dir", type=str, default="data/eval/bbh/bbh")
    parser.add_argument("--max_num_examples", type=int, default=300, help="maximum number of examples to evaluate.")
    # parser.add_argument("--save_dir", type=str, default="results/bbh/ours4")
    parser.add_argument("--save_dir", type=str, default="results/bbh/ours4_v7")

    # Alternative --task defaults:
    # parser.add_argument("--task", type=str, default="tracking_shuffled_objects_seven_objects")
    # parser.add_argument("--task", type=str, default="reasoning_about_colored_objects")
    # parser.add_argument("--task", type=str, default="logical_deduction_five_objects")
    # parser.add_argument("--task", type=str, default="web_of_lies")
    parser.add_argument("--task", type=str, default="penguins_in_a_table")
    parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
    parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
    # --n_shot and --no_cot are accepted for command-line compatibility; they
    # are not read by main() in the code visible in this file.
    parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.")
    parser.add_argument("--no_cot", action="store_true", help="If given, we're evaluating a model without chain-of-thought.")
    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
    parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
    parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
    parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")
    args = parser.parse_args()

    # model_name_or_path and openai_engine cannot be both None or both not None.
    # parser.error() is used rather than `assert` so the check still runs under
    # `python -O` and the user gets a usage message instead of a traceback.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Exactly one of --model_name_or_path or --openai_engine must be specified.")
    main(args)
