import json
import os
import re

def is_option_correct(text, option):
    """Return True when ``text`` names ``option`` and "correct" on the same line.

    The comparison is case-insensitive and accepts either order: the option
    followed by "correct", or "correct" followed by the option. Because ``.``
    does not match a newline, both must appear on one line of ``text``.

    ``option`` is matched literally: regex metacharacters in it (parentheses,
    ``+``, ``.`` and so on) carry no special meaning.

    NOTE: matching is by substring, so a one-letter option such as ``"a"``
    also matches when that letter occurs inside a longer word.

    Args:
        text: Free-form text to scan.
        option: Option label to look for, e.g. ``"a"``.

    Returns:
        True if either pattern is found in ``text``, otherwise False.
    """
    text = text.lower()

    # Escape the label so it cannot inject regex syntax, and lowercase it to
    # agree with the lowercased text.
    literal = re.escape(option.lower())

    patterns = [
        f".*{literal}.*correct",
        f".*correct.*{literal}",
        # Add more patterns as needed
    ]

    return any(re.search(pattern, text) for pattern in patterns)


def is_option_content_correct(text, option_content):
    """Check whether a rephrased form of ``text`` contains ``option_content``.

    ``text`` is lowercased, then each (phrase, substitute) pair below is tried
    independently: if the phrase occurs in the lowercased text, every
    occurrence is replaced by the substitute and the result is searched for
    ``option_content`` as a plain substring. ``option_content`` itself is used
    as given (it is not lowercased here).

    Returns:
        True if any rephrased variant contains ``option_content``; False
        otherwise, including when none of the phrases occur in ``text``.
    """
    lowered = text.lower()

    # Each rewrite is applied to the original lowercased text, not chained.
    rewrites = (
        ("must be", "is"),
        ("must be", "are"),
        ("must have", ""),
        (", must be", ""),
    )

    return any(
        option_content in lowered.replace(phrase, substitute)
        for phrase, substitute in rewrites
        if phrase in lowered
    )

def all_words_in_string(word_list, target_string):
    """Return True when every entry of ``word_list`` occurs in ``target_string``.

    Uses the ``in`` operator, so for a string target each word is a substring
    test. An empty ``word_list`` yields True.
    """
    for word in word_list:
        if word not in target_string:
            return False
    return True

# Dataset and task whose two prediction files are compared below.
dataset = "bbh"
# task = "tracking_shuffled_objects_three_objects"
# task = "logical_deduction_three_objects"
task = "penguins_in_a_table"


def _load_jsonl(path):
    """Read a JSON Lines file and return its records as a list, one per line.

    Whitespace-only lines are skipped so a trailing blank line does not make
    ``json.loads`` fail.
    """
    # Decode as UTF-8 explicitly; the default for open() depends on the
    # platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


# Predictions for the run labelled "COT" in the report.
cot = _load_jsonl('results/{}/chatgpt4_chat/{}/predictions.jsonl'.format(dataset, task))
# Predictions for the run labelled "ours" in the report.
ours = _load_jsonl('results/{}/ours4_v7/{}/predictions.jsonl'.format(dataset, task))

if len(cot) != len(ours):
    # Only a warning: the loop below pairs records positionally and stops at
    # the end of the shorter list.
    print ("the length of dataset unmatched")


def _is_prediction_correct(prediction, option, option_content):
    """Decide whether ``prediction`` agrees with the target option.

    A prediction counts as correct when any of these holds, tried in order:
      1. ``is_option_correct`` accepts it for the option label;
      2. the lowercased prediction contains the option text verbatim;
      3. ``is_option_content_correct`` accepts it for the option text.

    Checks 2 and 3 are skipped when ``option_content`` is None or empty:
    an empty string is a substring of everything and would otherwise mark
    every prediction as correct.
    """
    if is_option_correct(prediction, option):
        return True
    if not option_content:
        return False
    if option_content in prediction.lower():
        return True
    return is_option_content_correct(prediction, option_content)


# Correctly classified by COT but not ours
res1 = []

# Correctly classified by ours but not COT
res2 = []

# zip() stops at the shorter list, so a length mismatch cannot raise IndexError.
for idx, (item_cot, item_ours) in enumerate(zip(cot, ours)):
    # Look up the text that follows the target label (e.g. "(A) ...") in the
    # question; it runs to the end of that line.
    option_pattern = re.escape(item_cot["target"].strip()) + r" ([^\n]+)"
    option_match = re.search(option_pattern, item_cot['input'])
    # None when the label is not found in the input; grading then falls back
    # to the label-only check instead of crashing on None.strip().
    option_content = option_match.group(1).strip().lower() if option_match else None

    # Bare lowercase label with the parentheses removed.
    option = item_cot["target"].lower().replace("(", "").replace(")", "")

    flag_cot = _is_prediction_correct(item_cot["prediction"], option, option_content)
    flag_ours = _is_prediction_correct(item_ours["prediction"], option, option_content)

    # Correctly classified by COT but not ours
    if flag_cot and not flag_ours:
        res1.append((idx, item_cot, item_ours))

    # Correctly classified by ours but not COT
    if flag_ours and not flag_cot:
        res2.append((idx, item_cot, item_ours))

# Report: examples graded correct for COT but not for ours.
# Each res1 entry is an (index, cot record, ours record) tuple.
print("-----------------------Correctly classified by COT but not ours-------------------------")
print("length: {}".format(len(res1)))
for example_idx, cot_record, ours_record in res1:
    print("-----------------------{}-------------------------".format(example_idx))
    print("input at {}: {}".format(example_idx, cot_record["input"]))
    print("ans at {}: {}".format(example_idx, cot_record["target"]))
    print("cot: {}".format(cot_record["model_output"]))
    # Blank separator between the two outputs.
    print("\n")
    print("ours: {}".format(ours_record["model_output"]))
    # To inspect the "prediction" field instead of "model_output", change the
    # key in the two prints above.

# print ("-----------------------Correctly classified by ours but not COT-------------------------")     
# print ("length: {}".format(len(res2)))        
# for idx, cot_item, our_item in res2:
#     print ("-----------------------{}-------------------------".format(idx))
#     print ("input at {}: {}".format(idx, cot_item["input"]))
#     print ("ans at {}: {}".format(idx, cot_item["target"]))
#     print ("cot: {}".format(cot_item["model_output"]))
#     print ("\n")
#     print ("ours: {}".format(our_item["model_output"]))
#     # print ("cot: {}".format(cot_item["prediction"]))
#     # print ("\n")
#     # print ("ours: {}".format(our_item["prediction"]))