import argparse
import os
import re
import json
import random
import evaluate
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
from eval.clutrr.examplars import EXAMPLARS 
import openai


# Hugging Face `evaluate` exact-match metric, loaded once at import time.
# NOTE(review): no active code in this file calls it (main() scores with
# all_words_in_string instead) — confirm whether this load is still needed.
exact_match = evaluate.load("exact_match")
# Regex alternation matching a concluding sentence: text that starts with
# "Hence, ", "Therefore ", "Based on ", "we can conclude that " or
# "Summarized Answer: " and runs (non-greedily) up to the next period.
# Not referenced by any active code in this file.
pattern = r"(Hence, .*?\.|Therefore .*?\.|Based on .*?\.|we can conclude that .*?\.|Summarized Answer: .*?\.)"

def all_words_in_string(word_list, target_string):
    """Tell whether every entry of ``word_list`` occurs inside ``target_string``.

    Matching is plain substring containment (``word in target_string``), so
    "son" is considered present in "grandson". An empty ``word_list`` yields
    ``True``.
    """
    for candidate in word_list:
        if candidate not in target_string:
            return False
    return True


def _load_test_data(path):
    """Read a JSON-lines file and keep the story, question and stripped answer of each record."""
    test_data = []
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            if not line.strip():
                continue
            example = json.loads(line)
            test_data.append({
                "story": example["story"],
                "question": example["question"],
                "answer": example["answer"].strip(),
            })
    return test_data


def _build_prompt_prefix(args, prefix):
    """Build the few-shot demonstration prefix from EXAMPLARS without rebinding the global."""
    if not args.n_shot:
        return "Answer the following question.\n\n"

    # Sample into a local name so repeated calls always draw from the full pool.
    examplars = EXAMPLARS
    if len(examplars) > args.n_shot:
        examplars = random.sample(examplars, args.n_shot)

    # --no_example switches the demonstrations to their short answers.
    answer_key = "short_answer" if args.no_example else "answer"
    demonstrations = [
        "Story: " + example["story"] + "\n" + prefix + example["question"] + "\n" + "Answer: " + example[answer_key]
        for example in examplars
    ]
    return "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"


def _containment_accuracy(predictions, targets):
    """Return the fraction of predictions that contain every word of their target.

    Mismatches are printed for inspection. Returns 0.0 when there are no
    predictions instead of dividing by zero.
    """
    if not predictions:
        return 0.0

    hits = 0
    for target, prediction in zip(targets, predictions):
        word_list = target.split(" ")
        if all_words_in_string(word_list, prediction.lower()):
            hits += 1
        else:
            print("unmatch")
            print("target", target)
            print("prediction", prediction)
            print(word_list)
    return hits / len(predictions)


def main(args):
    """Run the evaluation: load data, build prompts, generate, score and save results.

    Reads ``<args.data_dir>/<args.task>.jsonl``, optionally subsamples it to
    ``args.max_num_examples``, prompts either a local Hugging Face model
    (``args.model_name_or_path``) or an OpenAI chat model
    (``args.openai_engine``), and writes ``predictions.jsonl`` and
    ``metrics.json`` (plus ``openai_results.jsonl`` on the OpenAI path) under
    ``<args.save_dir>/<args.task>/``.

    Args:
        args: Parsed command-line namespace; see the argument parser at the
            bottom of this file for the available attributes.
    """
    # random.seed(42)
    task = args.task
    output_dir = os.path.join(args.save_dir, task)

    print("Loading data...")
    # test.jsonl
    test_data = _load_test_data(os.path.join(args.data_dir, "{}.jsonl".format(task)))

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        test_data = random.sample(test_data, args.max_num_examples)

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(output_dir, exist_ok=True)

    # prefix = "Based on the story, through deductive reasoning think step by step to answer the question explicitly then summarize the answer in few words: "
    # prefix = "Based on the story, think step by step to answer the question explicitly. Then, summarize the answer in few words: "
    prefix = "Follow the instructions and think step by step through deductive reasoning to answer the question. DO NOT directly provide the answer without giving an explanation for the answer. Conclude by summarizing the answer in a few words. Question: "

    # NOTE: the active (v7) prompts below do not include prompt_prefix; it is
    # still built so the few-shot variants (commented out in the loop) can be
    # re-enabled with a one-line change.
    prompt_prefix = _build_prompt_prefix(args, prefix)  # noqa: F841

    # # v3
    # instruct = "Given a story, start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Please convert the question in the form of a triplet. Starting from the subject entity in the triplet, please select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"
    # v4
    # instruct = "Given a story, start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Please convert the question in the form of a triplet. Starting from the subject entity in the triplet, please select the most promising facts to extend from the subject entity step by step to answer the question.\n"
    # v5
    # instruct = "Given a story, start with constructing a knowledge graph by extracting facts from the story sentence by sentence. Then, I will post a question. Starting from the subject entity in the question, please select the most promising facts to extend from the subject entity until reaching the object entity step by step to answer the question.\n"
    # v6
    # instruct = "Given a story, start with constructing a logical causal graph by extracting the premise and conclusion for each sentence sentence by sentence. Then, I will post a question. Use logical deduction, guided step by step by your graph, to arrive at the answer to the question. Answer the question using the exact wording from the original context.\n"
    # v7
    # The quotes around r used to be written as backslash + typographic quote,
    # an invalid escape that put literal backslashes into the prompt; plain
    # apostrophes are used for all three placeholders now.
    instruct = "First, create a knowledge graph by extracting facts from each sentence in the given input story. Once this is done, I will pose a question.  This question can be transformed into a triple (s, r, ?), where your primary task is to determine the missing tail entity ('?') that connects the subject entity ('s') through the relation ('r'). Start by concentrating on the subject entity in this triple and follow a logical path within the knowledge graph. Progress step by step from the statement related to the subject, using a forward chaining process. At each step, combine the conclusions with the facts in the knowledge graph to deduce new conclusions. The final conclusion along this logical path will serve as the answer to the question.\n"

    prompts = []
    for example in test_data:
        question_block = "Story: " + example["story"] + "\n" + prefix + example["question"].strip()
        if args.use_chat_format:
            # prompt = "<|user|>\n" + instruct + prompt_prefix + question_block + "\n<|assistant|>\n" + "Answer:"
            # NOTE(review): "Answer :" (with a space) differs from the
            # non-chat "Answer:" below — confirm this is intended.
            prompt = "<|user|>\n" + instruct + question_block + "\n<|assistant|>\n" + "Answer :"
        else:
            # prompt = instruct + prompt_prefix + question_block + "\nAnswer:"
            prompt = instruct + question_block + "\nAnswer:"
        prompts.append(prompt)

    if args.model_name_or_path:
        print("Loading model and tokenizer...")
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            load_in_half=True,
            gptq_model=args.gptq
        )
        # Take the last token because the tokenizer may add space tokens at the start.
        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
        # NOTE(review): presumably this halts generation at the first newline —
        # confirm that suits multi-line step-by-step answers.
        outputs = generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            max_new_tokens=512,
            batch_size=args.eval_batch_size,
            stop_id_sequences=[[new_line_token]]
        )
    else:
        # The prompt text doubles as the instance id.
        instances = [{"id": prompt, "prompt": prompt} for prompt in prompts]
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=instances,
            batch_size=args.eval_batch_size if args.eval_batch_size else 10,
            output_path=os.path.join(output_dir, "openai_results.jsonl"),
            temperature=0
        )
        outputs = [result["output"] for result in results]

    # Keep only the LAST blank-line-separated paragraph of each output: the
    # prompt asks the model to conclude by summarizing its answer.
    # Regex-based extraction of the concluding sentence (module-level
    # `pattern`) is currently disabled; the whole last paragraph is scored.
    predictions = [output.strip().split("\n\n")[-1].strip() for output in outputs]

    print("Calculating accuracy...")
    targets = [example["answer"].lower().replace(".", "") for example in test_data]

    # The score is word-containment accuracy rather than strict exact match;
    # the "Exact match" label and metrics key are kept for compatibility with
    # existing result files.
    em_score = _containment_accuracy(predictions, targets)
    print(f"Exact match : {em_score}")

    records = [{
        "story": example["story"],
        "question": example["question"],
        "answer": example["answer"],
        "model_output": output,
        "prediction": pred
    } for example, output, pred in zip(test_data, outputs, predictions)]

    with open(os.path.join(output_dir, "predictions.jsonl"), "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump({
            "exact_match": em_score
        }, fout, indent=4)


if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the backend choice, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/eval/entailment_bank")
    parser.add_argument("--max_num_examples", type=int, default=200, help="maximum number of examples to evaluate.")
    parser.add_argument("--save_dir", type=str, default="results/entailment_bank/ours4_chat_v7")
    # parser.add_argument("--save_dir", type=str, default="results/clutrr/story/ours3_1shot_gpt4_2")
    parser.add_argument("--task", type=str, default="1")
    parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
    parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
    parser.add_argument("--n_shot", type=int, default=8, help="max number of examples to use for demonstration.")
    parser.add_argument("--no_example", action="store_true", help="If given, we're evaluating a model without chain-of-thought.")
    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
    parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
    parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
    parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")
    args = parser.parse_args()

    # Exactly one backend must be selected: model_name_or_path and
    # openai_engine cannot both be None, nor both be set. parser.error is used
    # instead of assert so the check survives `python -O` and the user gets a
    # usage message rather than a traceback.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Exactly one of --model_name_or_path or --openai_engine must be specified.")
    main(args)

