import json
import os
import sys
sys.path.append("/home/[USER]/workshop/wikihow")
os.chdir("/home/[USER]/workshop/wikihow")
import random

def run1(res_file, test_file, tag, seed=None, out_dir="./data/wikihow"):
    """Split a results JSON into train/dev/test files for reranking.

    An entry ``k -> v`` of the results file goes to the test split when the
    line ``f"{k}\\t{v['gold_goal']}"`` appears in ``test_file``. All other
    entries are shuffled. The first ``len(test)`` of them become the dev
    split and the remainder the train split. So dev has the same size as
    test, or is smaller when there are too few non-test entries.

    Args:
        res_file: Path to a JSON object mapping keys to dicts that contain a
            ``'gold_goal'`` field.
        test_file: Text file with one ``"<key>\\t<gold_goal>"`` per line.
            Blank lines are ignored.
        tag: Inserted into the output names
            ``gold.rerank.{tag}.{train,dev,test}.json``.
        seed: Optional seed for a reproducible train/dev shuffle. ``None``
            keeps the original behaviour of using the global ``random`` state.
        out_dir: Directory the three split files are written to. It is
            created if missing.
    """
    with open(res_file, "r", encoding="utf-8") as f:
        res = json.load(f)

    # Skip blank lines: "" can never match a "<key>\t<goal>" entry and would
    # only inflate the reported count.
    with open(test_file, "r", encoding="utf-8") as f:
        test_keys = {line.strip() for line in f if line.strip()}
    print(f"number of test data: {len(test_keys)}")

    test_data = {}
    rest_data = {}
    for k, v in res.items():
        if f"{k}\t{v['gold_goal']}" in test_keys:
            test_data[k] = v
        else:
            rest_data[k] = v

    rest_keys = list(rest_data)
    if seed is None:
        random.shuffle(rest_keys)
    else:
        # Private RNG so a fixed seed does not disturb the global random state.
        random.Random(seed).shuffle(rest_keys)

    # Dev is sized to match the test split; train gets everything else.
    n_dev = len(test_data)
    train_data = {k: rest_data[k] for k in rest_keys[n_dev:]}
    dev_data = {k: rest_data[k] for k in rest_keys[:n_dev]}

    print(f"train: {len(train_data)}, dev: {len(dev_data)}, test: {len(test_data)}")

    os.makedirs(out_dir, exist_ok=True)
    for split, data in (("train", train_data), ("dev", dev_data), ("test", test_data)):
        out_path = os.path.join(out_dir, f"gold.rerank.{tag}.{split}.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

def run2(in_file="./data/wikihow/step_goal.json", out_file="./data/wikihow/step2goal.json"):
    """Build a step-text -> goal mapping and write it as JSON.

    ``in_file`` must hold a JSON list of records, each with a ``'task'``
    string and a ``'caption'`` list of step strings. Each step is normalised
    by stripping surrounding whitespace and one trailing period, then mapped
    to its record's task. If the same normalised step appears under several
    tasks, the last one read wins.

    Prints the number of distinct steps and the average number of captions
    per record (0.0 for an empty input).

    Args:
        in_file: Path of the source JSON list.
        out_file: Path the mapping is written to. Its directory is created if
            missing.

    Returns:
        The step -> goal dict that was written.
    """
    with open(in_file, "r", encoding="utf-8") as f:
        records = json.load(f)

    step2goal = {}
    caption_counts = []
    for record in records:
        for step in record['caption']:
            # Normalise every caption the same way, whether or not it ends
            # in "." (previously "foo " and " foo. " kept stray characters).
            step = step.strip()
            if step.endswith("."):
                step = step[:-1].rstrip()
            step2goal[step] = record['task']
        caption_counts.append(len(record['caption']))

    print(len(step2goal))
    # Guard against an empty input list, which would divide by zero.
    print(sum(caption_counts) / len(caption_counts) if caption_counts else 0.0)

    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(step2goal, f, indent=2)
    return step2goal

if __name__ == "__main__":
    # Alternative configurations kept for reference (swap in as needed):
    # no finetune
    # res_file = "/projects/[SERVER]2/users/[USER]/para/wikihow/para_base_sz_all_base_0.0_-1.0_10.0_False_links.json"
    # tag = 'org'
    # finetune
    # res_file = "/projects/[SERVER]2/users/[USER]/para/wikihow/para_ft_base_sz_all_base_0.0_-1.0_10.0_False_links.json"
    # tag = 'ft'

    # no finetune; n selects which results file is read and is echoed in the tag
    n = 50
    res_file = f"/projects/[SERVER]2/users/[USER]/para/wikihow/para_base_sz_all_base_0.0_-1.0_10.0_{n}_links.json"
    # Derive the tag from n (not a hard-coded 50) so the output file names
    # always match the results file that produced them.
    tag = f'org.t{n}'
    test_file = "/projects/[SERVER]2/users/[USER]/para/scratch/gold.para.base.dev.txt"
    run1(res_file, test_file, tag)

    # run2()