
import json
import os
from datasets import load_dataset
import re
import pandas as pd

template = "What is the family relationship between: [{}] with [{}]?"

tasks = ["1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "1.9", "1.10"]

# Input/output locations. The CSV directory used to be a hard-coded absolute
# path; it keeps that default but can be overridden with CLUTRR_CSV_DIR.
CSV_DIR = os.environ.get(
    "CLUTRR_CSV_DIR", "/Users/keweicheng/Downloads/open-instruct-main/clutrr1"
)
PRED_DIR = "results/clutrr/ans/chatgpt3_chat"
OUT_DIR = "data/eval/clutrr"

# Pulls the single-quoted names out of a CSV "query" cell; the first two
# matches are used as the (head, tail) pair of the question.
QUERY_NAME_RE = re.compile(r"'(.*?)'")


def _load_predictions(path):
    """Read a predictions JSONL file into a list of dicts.

    Each non-blank line must be a JSON object with "story", "question" and
    "prediction" keys. "story" and "prediction" are whitespace-stripped.
    The returned list is in file order, which is assumed to match the CSV
    row order.
    """
    preds = []
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            example = json.loads(line)
            preds.append({
                "story": example["story"].strip(),
                "question": example["question"],
                "prediction": example["prediction"].strip(),
            })
    return preds


def _write_jsonl(path, items):
    """Write *items* to *path* as JSON Lines (creating parent dirs); return the count."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # `with` guarantees the file is flushed and closed before returning.
    with open(path, "w", encoding="utf-8") as fout:
        for item in items:
            fout.write(json.dumps(item) + "\n")
    return len(items)


for task_name in tasks:
    print("task", task_name)
    df = pd.read_csv(os.path.join(CSV_DIR, "{}_test.csv".format(task_name)))
    ans = _load_predictions(
        os.path.join(PRED_DIR, task_name, "predictions.jsonl")
    )

    task = []
    clean_task = []

    # Predictions are aligned with CSV rows by position, not by index label.
    for pos, (_, item) in enumerate(df.iterrows()):
        if pos >= len(ans):
            # Fewer predictions than CSV rows: report instead of crashing.
            print("unmatched")
            print("(no prediction for row {})".format(pos))
            continue

        matches = QUERY_NAME_RE.findall(item["query"])
        h, t = matches[0], matches[1]
        question = template.format(h, t)
        pred = ans[pos]

        # Keep the row only if the prediction was made for the same question
        # and its story is not the literal string "[]".
        if pred["question"] == question and pred["story"] != "[]":
            task.append({
                "story": item["story"],
                "question": question,
                "answer": pred["prediction"],
            })
            clean_task.append({
                "story": item["clean_story"],
                "question": question,
                "answer": pred["prediction"],
            })
        else:
            print("unmatched")
            print(pred["story"])

    print("Save story")
    cnt = _write_jsonl(
        os.path.join(OUT_DIR, "story", "{}.jsonl".format(task_name)), task
    )
    print("cnt", cnt)

    print("Save clean story")
    cnt = _write_jsonl(
        os.path.join(OUT_DIR, "clean_story", "{}.jsonl".format(task_name)),
        clean_task,
    )
    print("cnt", cnt)
        
    


        