import argparse
import os
import re
import json
import random
import evaluate
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
from eval.clutrr.examplars import EXAMPLARS 
import openai


# Hugging Face `evaluate` metric handle, loaded once at import time.
exact_match = evaluate.load("exact_match")

# Regex for a concluding sentence: one of the lead-in phrases below followed by
# the shortest run of characters that ends at a period. The adjacent raw
# literals are concatenated by the parser into a single alternation group.
pattern = (
    r"(Hence, .*?\."
    r"|Therefore .*?\."
    r"|Based on .*?\."
    r"|we can conclude that .*?\."
    r"|Summarized Answer: .*?\.)"
)

def all_words_in_string(word_list, target_string):
    """Return True when every entry of ``word_list`` is contained in ``target_string``.

    Containment is tested with the ``in`` operator, so for a ``str`` target
    this is a substring check rather than a whole-word match (``"cat"`` is
    found in ``"concatenate"``). An empty ``word_list`` yields True.
    """
    for candidate in word_list:
        if candidate not in target_string:
            # One miss is enough to decide; stop scanning.
            return False
    return True

# ---------------------------------------------------------------------------
# Hard-coded run configuration. Switch runs by (un)commenting the alternatives.
# ---------------------------------------------------------------------------

# Sub-directory name used under both the input and the output roots.
task = "3"

# Output root; results are written to <save_dir>/<task>/.
# save_dir = "results/entailment_bank/comp/ours3"
save_dir = "results/entailment_bank/comp/chatgpt3"
# save_dir = "results/entailment_bank/comp/chatgpt3_v2"
# save_dir = "results/entailment_bank/comp/ours4"
# save_dir = "results/entailment_bank/comp/chatgpt4"
# save_dir = "results/entailment_bank/comp/chatgpt4_v2"

# When true, prompts are wrapped in "<|user|>" / "<|assistant|>" markers.
use_chat_format = True

# Engine name passed to the OpenAI query helper.
# openai_engine = "gpt-3.5-turbo"
openai_engine = "gpt-4"

# Create the output directory up front. `exist_ok=True` already makes this a
# no-op for an existing directory, so no separate existence check is needed.
# Unlike the check-then-create form, this raises FileExistsError immediately
# if the path is occupied by a regular file, instead of failing only when the
# results are written at the very end of the run.
os.makedirs(os.path.join(save_dir, task), exist_ok=True)

# Load the model generations that are going to be judged.
# Each JSONL record must provide "story", "question", "answer" and
# "model_output"; a missing key raises KeyError.
predictions = []
# with open(os.path.join("results/entailment_bank/ours3_chat_v7/{}".format(task), "predictions.jsonl")) as fin:
with open(
    os.path.join("results/entailment_bank/chatgpt3_chat/{}".format(task), "predictions.jsonl"),
    encoding="utf-8",  # explicit, so decoding does not depend on the platform locale
) as fin:
    for line in fin:
        # Tolerate blank lines (e.g. a trailing newline) instead of crashing
        # inside json.loads.
        if not line.strip():
            continue
        example = json.loads(line)
        predictions.append({
            "story": example['story'],
            "question": example["question"],
            "answer": example["answer"].strip(),
            # Keep only the text after the last blank line of the generation.
            "prediction": example["model_output"].strip().split("\n\n")[-1].strip()
        })


print("Loading data...")

# Task instruction placed at the start of every comparison prompt.
instruct = "Compare whether the two answers to the same question express the same meaning.\n"
# Closing request, identical in both prompt layouts.
_reply_hint = "Please respond with either \"yes\" or \"no.\""

# One prompt per loaded prediction. The text after "Question: " is the story
# followed by the question on the next line; "Answer 1" is the reference
# answer and "Answer 2" is the extracted model prediction.
prompts = []
for item in predictions:
    core = "".join([
        instruct,
        "Question: ", item["story"], "\n",
        item["question"], "\n",
        "Answer 1: ", item["answer"], "\n",
        "Answer 2: ", item["prediction"].strip(),
    ])

    if not use_chat_format:
        # Plain layout: the closing request follows on its own line.
        prompts.append("".join([core, "\n", _reply_hint]))
        continue

    # Chat layout: literal role markers surround the core text.
    # NOTE(review): these markers are sent as plain prompt text; whether the
    # queried model gives them any special meaning is not visible here.
    prompts.append("".join(["<|user|>\n", core, "\n<|assistant|>\n", _reply_hint]))


# Wrap every prompt as a request record. The prompt text doubles as the record
# id (presumably the helper keys saved results in `output_path` by this id —
# confirm in eval.utils before changing it).
instances = [dict(id=text, prompt=text) for text in prompts]

# Raw judge responses are also saved next to the other outputs of this run.
_judge_log_path = os.path.join(save_dir, task, "openai_results.jsonl")
results = query_openai_chat_model(
    engine=openai_engine,
    instances=instances,
    output_path=_judge_log_path,
    batch_size=20,
)

# Keep only the "output" field of each returned record.
outputs = [entry["output"] for entry in results]

# Score the judge replies: a reply counts as a match when its lower-cased text
# contains "yes".
# NOTE(review): this is a substring test, so a reply such as "No, ... yes ..."
# would also count as a match; tighten if judge replies turn out to be verbose.
print("Calculating accuracy...")  # announced once, not once per reply

ACC = 0
for output in outputs:
    if "yes" in output.lower():
        print("match")
        ACC += 1
    else:
        print("unmatch")
    print("output", output)

print("ACC", ACC)
# Fraction of loaded predictions judged equivalent to the reference answer.
# Although labelled "Exact match", the value is the judge-agreement rate.
# An empty prediction list yields 0.0 instead of raising ZeroDivisionError.
em_score = ACC / len(predictions) if predictions else 0.0
print(f"Exact match : {em_score}")
            
# Fields copied unchanged from each loaded prediction into the saved record.
_copied_keys = ("story", "question", "answer", "prediction")

# Pair each prediction with the judge reply at the same position; the reply is
# stored under "model_output". zip() stops at the shorter of the two lists.
res = [
    {**{key: example[key] for key in _copied_keys}, "model_output": output}
    for example, output in zip(predictions, outputs)
]

# One JSON object per line.
with open(os.path.join(save_dir, task, "predictions.jsonl"), "w") as fout:
    fout.writelines(json.dumps(pred) + "\n" for pred in res)

# Aggregate score, stored under the key "exact_match".
metrics = {"exact_match": em_score}
with open(os.path.join(save_dir, task, "metrics.json"), "w") as fout:
    json.dump(metrics, fout, indent=4)
