import json
import os
import re

def all_words_in_string(word_list, target_string):
    """Return True when every entry of word_list occurs in target_string.

    Each entry is checked with ``in``, i.e. plain substring containment, so
    an entry also matches when it appears inside a longer word. An empty
    word_list yields True.
    """
    for candidate in word_list:
        if candidate not in target_string:
            # One missing entry is enough to reject the whole list.
            return False
    return True

# Which results directory to analyse: results/<dataset>/<run>/<task>/.
dataset = "entailment_bank"
task = "2"


def _load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: Location of a file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not valid JSON
            (json.JSONDecodeError is a subclass of ValueError).
    """
    records = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # A blank line (e.g. a trailing newline at the end of the file)
            # is not a record, and json.loads would raise on it.
            if not line:
                continue
            records.append(json.loads(line))
    return records


# Predictions of the chatgpt4_chat run (called "COT" in the rest of the script).
cot = _load_jsonl('results/{}/chatgpt4_chat/{}/predictions.jsonl'.format(dataset, task))
# Predictions of the ours4_chat_v7 run (called "ours" in the rest of the script).
ours = _load_jsonl('results/{}/ours4_chat_v7/{}/predictions.jsonl'.format(dataset, task))

# The two runs are compared position by position, so they are expected to
# contain the same number of records.
if len(cot) != len(ours):
    print("the length of dataset unmatched")

# Correctly classified by COT but not ours
res1 = []

# Correctly classified by ours but not COT
res2 = []

# zip() stops at the shorter list, so a length mismatch (reported above) no
# longer ends in an IndexError when `ours` has fewer records than `cot`.
for idx, (item_cot, item_ours) in enumerate(zip(cot, ours)):
    # The reference answer is read from the COT record and used to judge both
    # runs. It is lower-cased, stripped of periods and split on single spaces.
    ans = item_cot["answer"].lower().replace(".", "")
    word_list = ans.split(" ")

    # A prediction counts as correct when every answer word occurs in the
    # lower-cased prediction text.
    flag_cot = all_words_in_string(word_list, item_cot["prediction"].lower())
    flag_ours = all_words_in_string(word_list, item_ours["prediction"].lower())

    if flag_cot and not flag_ours:
        # Correctly classified by COT but not ours
        res1.append((idx, item_cot, item_ours))
    elif flag_ours and not flag_cot:
        # Correctly classified by ours but not COT
        res2.append((idx, item_cot, item_ours))

# Report every example that COT answered correctly and ours did not: the
# story, question and reference answer, followed by both raw model outputs.
# (Use the "prediction" key instead of "model_output" to print the
# prediction text that the comparison was made on.)
item_banner = "-----------------------{}-------------------------"
story_line = "story at {}: {}"
question_line = "question at {}: {}"
answer_line = "ans at {}: {}"

print("-----------------------Correctly classified by COT but not ours-------------------------")
print("length: {}".format(len(res1)))
for example_idx, cot_record, ours_record in res1:
    print(item_banner.format(example_idx))
    print(story_line.format(example_idx, cot_record["story"]))
    print(question_line.format(example_idx, cot_record["question"]))
    print(answer_line.format(example_idx, cot_record["answer"]))
    print("cot: {}".format(cot_record["model_output"]))
    # Printing "\n" leaves a visual gap between the two model outputs.
    print("\n")
    print("ours: {}".format(ours_record["model_output"]))
# print ("-----------------------Correctly classified by ours but not COT-------------------------")     
# print ("length: {}".format(len(res2)))        
# for idx, cot_item, our_item in res2:
#     print ("-----------------------{}-------------------------".format(idx))
#     print ("story at {}: {}".format(idx, cot_item["story"]))
#     print ("question at {}: {}".format(idx, cot_item["question"]))
#     print ("ans at {}: {}".format(idx, cot_item["answer"]))
#     print ("cot: {}".format(cot_item["model_output"]))
#     print ("\n")
#     print ("ours: {}".format(our_item["model_output"]))
#     # print ("cot: {}".format(cot_item["prediction"]))
#     # print ("\n")
#     # print ("ours: {}".format(our_item["prediction"]))