import argparse
import os
import re
import json
import tqdm
import glob
import torch
import random
import evaluate
from eval.utils import load_hf_lm_and_tokenizer, generate_completions, query_openai_chat_model


import argparse
import os
import re
import json
import random
import evaluate
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
from eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS
import openai


# `evaluate` exact-match metric, loaded once when this module is imported.
# NOTE(review): main() does not use this object; its scoring is a custom
# target-in-prediction check. Consider removing it if nothing else imports it.
exact_match = evaluate.load("exact_match")

def all_words_in_string(word_list, target_string):
    """Check that every entry of ``word_list`` is a substring of ``target_string``.

    Containment is plain substring search (no tokenization), checked in order,
    stopping at the first missing entry. An empty ``word_list`` gives True.
    """
    for candidate in word_list:
        if candidate not in target_string:
            return False
    return True


# Instruction prepended to every prompt.
_YES_NO_PREFIX = "Given the input, answer the question using yes or no.\n"
# Alternative instructions kept for experimentation:
#   "Given the input, think step by step to answer the question. Summarize the answer using yes or no.\n"
#   "Given the input, answer the question based on the given options.\n"
#   "Given the input, think step by step to answer the question using the option number.\n"


def _load_test_data(data_dir, task):
    """Read ``<data_dir>/<task>.json`` and return a list of ``{"input", "target"}`` dicts.

    The file must contain a top-level ``"examples"`` list whose items have
    ``"input"`` and ``"target"`` keys; targets are whitespace-stripped.
    """
    path = os.path.join(data_dir, "{}.json".format(task))
    with open(path, encoding="utf-8") as fin:
        data = json.load(fin)
    return [
        {"input": example["input"], "target": example["target"].strip()}
        for example in data["examples"]
    ]


def _build_prompt(input_text, use_chat_format, prefix=_YES_NO_PREFIX):
    """Build the zero-shot prompt for one example, optionally wrapped in chat role tags."""
    body = prefix + "Input: " + input_text + "\n"
    if use_chat_format:
        return "<|user|>\n" + body + "\n<|assistant|>\n" + "Answer:"
    return body + "\nAnswer:"


def _generate_outputs(args, prompts, task):
    """Return one completion string per prompt, from a local HF model or the OpenAI API."""
    if args.model_name_or_path:
        print("Loading model and tokenizer...")
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            load_in_half=True,
            gptq_model=args.gptq,
        )
        # Take the last token id because the tokenizer may add space tokens at the start.
        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
        return generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            max_new_tokens=512,
            batch_size=args.eval_batch_size,
            # Presumably ends generation at the first newline token; verify in eval.utils.
            stop_id_sequences=[[new_line_token]],
        )

    # NOTE(review): the prompt text doubles as the instance id, so duplicate
    # inputs produce duplicate ids. Confirm query_openai_chat_model tolerates this.
    instances = [{"id": prompt, "prompt": prompt} for prompt in prompts]
    results = query_openai_chat_model(
        engine=args.openai_engine,
        instances=instances,
        batch_size=args.eval_batch_size if args.eval_batch_size else 10,
        output_path=os.path.join(args.save_dir, task, "openai_results.jsonl"),
    )
    return [result["output"] for result in results]


def _extract_prediction(output):
    """Keep only the last blank-line-separated paragraph of a model output.

    Chatty models may reason first and answer last, so the final paragraph is
    taken as the prediction.
    """
    return output.strip().split("\n\n")[-1].strip()


def _target_in_prediction(target, prediction):
    """Return True if ``target`` occurs in ``prediction`` as a whole token, ignoring case.

    The match may not touch a word character on either side. This stops
    "no" from matching inside "not"/"know" and "valid" from matching inside
    "invalid", while punctuated labels such as "(A)" still match "The answer
    is (A).". An empty target never matches.
    """
    needle = target.strip()
    if not needle:
        return False
    pattern = r"(?<!\w)" + re.escape(needle) + r"(?!\w)"
    return re.search(pattern, prediction, flags=re.IGNORECASE) is not None


def _score_predictions(predictions, targets):
    """Print per-example match details and return the fraction of matched predictions.

    Raises:
        ValueError: if there are no predictions to score.
    """
    if not predictions:
        raise ValueError("No predictions to score.")
    num_correct = 0
    for idx, (target, prediction) in enumerate(zip(targets, predictions)):
        print("idx", idx)
        if _target_in_prediction(target, prediction):
            num_correct += 1
            print("match")
        else:
            print("unmatch")
        print("target", target)
        print("prediction", prediction)
    return num_correct / len(predictions)


def _save_results(save_dir, task, test_data, outputs, predictions, em_score):
    """Write per-example predictions (JSONL) and the aggregate metric (JSON) under save_dir/task."""
    task_dir = os.path.join(save_dir, task)
    records = [
        {
            "input": example["input"],
            "target": example["target"],
            "model_output": output,
            "prediction": pred,
        }
        for example, output, pred in zip(test_data, outputs, predictions)
    ]
    with open(os.path.join(task_dir, "predictions.jsonl"), "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    with open(os.path.join(task_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump({"exact_match": em_score}, fout, indent=4)


def main(args):
    """Evaluate a model on one task and write predictions and metrics to ``args.save_dir/<task>``.

    Reads ``args.task``, ``args.data_dir``, ``args.max_num_examples``,
    ``args.use_chat_format``, ``args.save_dir`` and the generation-backend
    options. The reported "exact_match" score is the fraction of predictions
    that contain the target as a whole, case-insensitive token.

    Raises:
        ValueError: if the task file contains no examples.
    """
    task = args.task

    print("Loading data...")
    test_data = _load_test_data(args.data_dir, task)
    if not test_data:
        # Fail before loading a model; scoring would otherwise divide by zero.
        raise ValueError(f"No examples found for task {task!r} in {args.data_dir}.")

    # NOTE(review): no random seed is set in this file, so the sampled subset
    # can differ between runs.
    if args.max_num_examples and len(test_data) > args.max_num_examples:
        test_data = random.sample(test_data, args.max_num_examples)

    # Must exist before generation: the OpenAI backend writes its cache here.
    os.makedirs(os.path.join(args.save_dir, task), exist_ok=True)

    prompts = [_build_prompt(example["input"], args.use_chat_format) for example in test_data]
    outputs = _generate_outputs(args, prompts, task)
    predictions = [_extract_prediction(output) for output in outputs]

    print("Calculating accuracy...")
    targets = [example["target"] for example in test_data]
    em_score = _score_predictions(predictions, targets)
    print(f"Exact match : {em_score}")

    _save_results(args.save_dir, task, test_data, outputs, predictions, em_score)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--data_dir", type=str, default="data/eval/bbh/bbh/reasoning")
    parser.add_argument("--data_dir", type=str, default="data/eval/bbh/bbh")
    parser.add_argument("--max_num_examples", type=int, default=300, help="maximum number of examples to evaluate.")
    # parser.add_argument("--save_dir", type=str, default="results/bbh/chatgpt3_chat")
    parser.add_argument("--save_dir", type=str, default="results/bbh/chatgpt4_chat")
    # parser.add_argument("--task", type=str, default="logical_deduction_three_objects")
    # parser.add_argument("--task", type=str, default="tracking_shuffled_objects_three_objects")
    # parser.add_argument("--task", type=str, default="logical_deduction_seven_objects")
    parser.add_argument("--task", type=str, default="navigate")
    parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
    parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
    # NOTE(review): --n_shot and --no_cot are accepted but not read by main() in this file.
    parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.")
    parser.add_argument("--no_cot", action="store_true", help="If given, we're evaluating a model without chain-of-thought.")
    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
    parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
    parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
    parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")
    args = parser.parse_args()

    # Exactly one backend must be chosen: model_name_or_path and openai_engine
    # cannot both be None or both be set. parser.error (unlike assert) is not
    # stripped under `python -O`; it prints usage and exits with status 2.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Either model_name_or_path or openai_engine should be specified.")
    main(args)
