import os
import json
import re

def is_option_correct(text, option):
    """Return True if *text* states that *option* is the correct one.

    Text and option are compared case-insensitively. The option is matched
    literally (regex metacharacters such as ``(`` or ``+`` are escaped) and only
    as a standalone token, so option ``"a"`` does not match the ``a`` inside
    ``"cannot"``. The word ``correct`` must begin a word, so ``"incorrect"``
    is not treated as an endorsement.

    Args:
        text: Free-form model prediction.
        option: Option label to look for, e.g. ``"a"``. Empty/None never matches.

    Returns:
        True if the option and the word "correct" appear on the same line
        (in either order), False otherwise.
    """
    if not option:
        # An empty label carries no information; don't count any text as a hit.
        return False

    text = text.lower()
    # Literal, case-folded option that must not be glued to neighbouring word chars.
    token = r"(?<!\w)" + re.escape(option.lower()) + r"(?!\w)"

    patterns = [
        token + r".*\bcorrect",   # e.g. "(a) is correct"
        r"\bcorrect.*" + token,   # e.g. "the correct answer is (a)"
        # Add more patterns as needed
    ]

    # `.` does not match newlines, so option and "correct" must share a line.
    return any(re.search(pattern, text) for pattern in patterns)


def is_option_content_correct(text, option_content, *, verbose=False):
    """Return True if a rephrased copy of *text* contains *option_content*.

    Predictions often hedge ("the owl must be second") where the option text is
    declarative ("the owl is second"). For every rewrite rule whose source phrase
    occurs in the lower-cased text, a rewritten copy of the text is produced, and
    the option content is searched for in those copies only. A plain substring
    match against the unmodified text is left to the caller.

    Args:
        text: Free-form model prediction.
        option_content: The option's text. Compared case-insensitively;
            None or empty never matches.
        verbose: If True, print the intermediate values for debugging.

    Returns:
        True if any rewritten copy of the text contains the option content.
    """
    if not option_content:
        # An empty needle is "in" every string, which would be a false positive.
        return False

    text = text.lower()
    # Normalise the needle the same way as the rewritten statements below.
    needle = " ".join(option_content.lower().split())

    words_to_replace = (
        ("must be", "is"),
        ("must be", "are"),
        ("must have", ""),
        (", must be", ""),
    )
    statements = [
        # Collapse whitespace: deleting a phrase (e.g. "must have") leaves a
        # double space that would otherwise defeat the substring test.
        " ".join(text.replace(old, new).split())
        for old, new in words_to_replace
        if old in text
    ]

    if verbose:
        print("text\n", text)
        print("statements\n", statements)
        print("option_content\n", option_content)

    return any(needle in statement for statement in statements)

def separate_context_question(sentence):
    """Split *sentence* at its final period into ``(context, question)``.

    The context is everything before the last period, with the period re-added.
    The question is the stripped remainder after it, or None when the sentence
    contains no period at all. In that case the whole sentence, with a period
    appended, is returned as the context.
    """
    head, dot, tail = sentence.rpartition(".")
    if not dot:
        # No period anywhere: rpartition put the whole sentence in `tail`.
        return sentence + ".", None
    return head + ".", tail.strip()

task = "predictions"
# Directory containing the `<task>.jsonl` file to score. Other result
# directories this script has been pointed at:
#   results/bbh/chatgpt4_chat/penguins_in_a_table
#   results/bbh/ours3_v7/tracking_shuffled_objects_seven_objects
#   results/bbh/ours4_v7/penguins_in_a_table
RESULTS_DIR = "results/bbh/ours4_v7/logical_deduction_five_objects"

# One dict per JSONL record, keeping only the fields the scorer needs.
predictions = []
with open(os.path.join(RESULTS_DIR, f"{task}.jsonl"), encoding="utf-8") as fin:
    for line in fin:
        if not line.strip():
            # Tolerate blank lines instead of failing inside json.loads.
            continue
        example = json.loads(line)
        predictions.append({
            "input": example["input"],
            "target": example["target"].strip(),
            "prediction": example["prediction"].strip(),
            "model_output": example["model_output"].strip(),
        })

def _extract_option_content(target, model_input):
    """Return the lower-cased text following label *target* (e.g. "(A)") in
    *model_input*, or "" when the label is not found."""
    # Inputs list their choices after an "Options:" header; search only there
    # when it exists so the same label appearing in the context is ignored.
    options_text = model_input.partition("Options:")[2] or model_input
    match = re.search(re.escape(target) + r" ([^\n]+)", options_text)
    return match.group(1).strip().lower() if match else ""


def _prediction_matches(example, option_content):
    """Return True if the prediction endorses the target by label or by content."""
    prediction = example["prediction"]
    # "(A)" -> "a": the bare label used by is_option_correct.
    option = example["target"].strip().lower().replace("(", "").replace(")", "")

    if is_option_correct(prediction, option):
        return True
    if not option_content:
        # Option text could not be located; an empty string would otherwise
        # be "in" every prediction and count as a match.
        return False
    if option_content in prediction.lower():
        return True
    return is_option_content_correct(prediction, option_content)


# Number of predictions judged to match their target.
ACC = 0
for idx, example in enumerate(predictions):
    print("idx", idx)

    option_content = _extract_option_content(example["target"].strip(), example["input"])

    if _prediction_matches(example, option_content):
        ACC += 1
    else:
        print("unmatch")
        print("target", example["target"])
        print("option_content", option_content)
        print("prediction", example["prediction"])

# NOTE: free-form tasks (e.g. web_of_lies, navigate) were previously scored by
# checking whether the lower-cased target occurs in example["model_output"].

# Report 0.0 for an empty results file instead of raising ZeroDivisionError.
em_score = ACC / len(predictions) if predictions else 0.0
print(f"Exact match : {em_score}")
