import os
import sys
# Add the project checkout to the module search path and make it the working
# directory. The functions below import from `external.` and open files under
# relative "./data/..." and "./external/..." paths.
sys.path.append("/home/[USER]/workshop/wikihow")
os.chdir("/home/[USER]/workshop/wikihow")

def func1():
    """Stream raw_caption2.json and write it out as JSON Lines.

    The input is parsed incrementally with ijson. Every top-level key starts
    a new record ``{"id": <key>, "text": <joined text items>}``; the values
    found under ``<key>.text`` are joined with single spaces and embedded
    newlines are replaced by spaces. One record is written per output line.
    """
    import ijson
    import json

    def write_record(record, fout):
        # Collapse the collected text pieces into one single-line string.
        record["text"] = " ".join(record["text"]).replace("\n", " ")
        json.dump(record, fout)
        fout.write("\n")

    with open("/projects/[SERVER]2/users/[USER]/wikihow/howto100m/raw_caption2.json", "r") as f:
        with open("/projects/[SERVER]2/users/[USER]/wikihow/howto100m/raw_caption.jsonl", "w", encoding="utf-8") as fout:
            cur_item = None
            cur_id = None
            for prefix, event, value in ijson.parse(f):
                if prefix.strip() == "" and event == "map_key":  # a new item
                    if cur_item is not None:
                        write_record(cur_item, fout)
                    cur_item = {"id": value, "text": []}
                    cur_id = value
                elif prefix == f"{cur_id}.text.item":
                    cur_item["text"].append(value)
            # A record is only written when the *next* key is seen, so the
            # final record has to be flushed explicitly once parsing ends.
            if cur_item is not None:
                write_record(cur_item, fout)

def func2():
    """Copy raw_caption.json to raw_caption2.json with NaN list entries removed.

    The file is processed in 10240-character chunks. The three textual
    patterns ``', NaN]'``, ``'[NaN,'`` and ``', NaN,'`` are rewritten so the
    list stays syntactically valid. Whenever a pattern is found, the buffer
    is printed before and after the replacement.
    """
    replacements = ((', NaN]', ']'), ('[NaN,', '['), (', NaN,', ','))
    # A pattern can straddle two reads. Holding back (longest pattern - 1)
    # characters guarantees any split pattern is seen whole on the next pass.
    hold_back = max(len(pattern) for pattern, _ in replacements) - 1
    with open("/projects/[SERVER]2/users/[USER]/wikihow/howto100m/raw_caption.json", "r", encoding="utf-8") as f:
        with open("/projects/[SERVER]2/users/[USER]/wikihow/howto100m/raw_caption2.json", "w+") as fout:
            carry = ''
            while True:
                chunk = f.read(10240)
                # Only an empty read means end of file; a whitespace-only
                # chunk in the middle of the file must not stop the copy.
                at_eof = chunk == ''
                c = carry + chunk
                for pattern, substitute in replacements:
                    if pattern in c:
                        print(c)
                        c = c.replace(pattern, substitute)
                        print(c)
                if at_eof:
                    fout.write(c)
                    break
                fout.write(c[:-hold_back])
                carry = c[-hold_back:]

# sample a subset from the whole
def func3():
    """Write 5000 randomly chosen lines of raw_caption.jsonl as one JSON list.

    Line indices are drawn from ``range(1238912)``. They are drawn without
    replacement, so 5000 distinct indices are always selected; sampling with
    replacement produced duplicates and therefore fewer selected lines.
    """
    import random
    import json
    # A set gives O(1) membership tests; the check runs once per input line.
    chosen = set(random.sample(range(1238912), k=5000))
    data = []
    with open("./data/howto100m/raw_caption.jsonl", "r", encoding="utf-8") as f:
        with open("./data/howto100m/selected_raw_caption.json", "w+", encoding="utf-8") as fout:
            for idx, line in enumerate(f):
                if idx in chosen:
                    # Only selected lines need to be parsed.
                    data.append(json.loads(line))
            json.dump(data, fout, indent=2)

def func4():
    """Run the autopunct `correct` helper over every sampled caption.

    Loads selected_raw_caption.json, replaces each item's "text" with the
    output of `correct`, and dumps the list to
    selected_raw_caption.punct.json. The index is printed every 50 items.
    """
    import json
    import spacy
    from external.autopunct.correct import correct
    punct_model = spacy.load("./external/autopunct/punct-model")
    caps_model = spacy.load("./external/autopunct/caps-model")
    src_path = "./data/howto100m/selected_raw_caption.json"
    dst_path = "./data/howto100m/selected_raw_caption.punct.json"
    with open(src_path, "r", encoding="utf-8") as f, open(dst_path, "w+", encoding="utf-8") as fout:
        records = json.load(f)
        for position, record in enumerate(records):
            if position % 50 == 0:
                print(position)  # progress indicator
            record["text"] = correct(record["text"], punct_model, caps_model)
        json.dump(records, fout)


def func5():
    """Copy the rows of HowTo100M_v1.csv whose rank column is <= 3.

    The first line of the input is skipped. Every other line must split into
    exactly five comma-separated fields; the second-to-last field is parsed
    as the integer rank. Kept rows are written unchanged to high_rank.csv
    and their count is printed.
    """
    kept = 0
    with open("./data/howto100m/HowTo100M_v1.csv",  "r", encoding="utf-8") as fin, \
            open("./data/howto100m/high_rank.csv", "w+", encoding="utf-8") as fout:
        fin.readline()  # skip the first line
        for row in fin:
            fields = row.strip().split(",")
            assert len(fields) == 5
            if int(fields[-2]) > 3:
                continue
            kept += 1
            fout.write(row)
    print(kept)

def func6():
    """Restrict caption.json to the ids listed in sampled/high_rank.csv.

    The id is the first comma-separated field of each CSV line. The number
    of distinct ids and the number of captions kept are printed; the
    filtered mapping is written to sampled/caption.json.
    """
    import json
    # get the captions
    with open("./data/howto100m/sampled/high_rank.csv", "r", encoding="utf-8") as fin:
        ids = {row.strip().split(",")[0] for row in fin}
    print(len(ids))

    with open("./data/howto100m/caption.json", "r") as f:
        all_captions = json.load(f)

    selected = {vid: cap for vid, cap in all_captions.items() if vid in ids}
    print(len(selected))

    with open("./data/howto100m/sampled/caption.json", "w+") as f:
        json.dump(selected, f)


def func7():
    """Build sampled/all_in_one.json, one record per row of high_rank.csv.

    Each record holds the id, the two category columns, the rank, the task
    name (looked up in task_ids.csv) and the caption: the string items of
    the caption's "text" list joined with " || ".
    """
    import json
    # generate data for train or test, one line a sample
    # get task id task name map
    with open("./data/howto100m/task_ids.csv",  "r", encoding="utf-8") as f:
        task_id_map = {}
        for fields in (row.strip().split("\t") for row in f):
            task_id_map[fields[0]] = fields[1]

    with open("./data/howto100m/sampled/caption.json", "r") as f:
        caption = json.load(f)

    data = []
    with open("./data/howto100m/sampled/high_rank.csv", "r", encoding="utf-8") as f, \
            open("./data/howto100m/sampled/all_in_one.json", "w", encoding="utf-8") as fout:
        for row in f:
            fields = row.strip().split(",")
            vid = fields[0]
            # Non-string entries in the caption text are dropped.
            pieces = [piece for piece in caption[vid]["text"] if isinstance(piece, str)]
            caption_text = " || ".join(pieces)
            data.append({
                "id": vid,
                "category": fields[1:3],
                "rank": fields[3],
                "task": task_id_map[fields[4]],
                'caption': caption_text,
            })
        print(len(data))
        json.dump(data, fout, indent=2)

def func8():
    """Write task_category.map: one "task || cat || cat ..." line per task.

    Task ids are resolved to names through task_ids.csv. For every data row
    of HowTo100M_v1.csv, columns 1 and 2 are added to the category set of
    the task named by the last column. Blank categories are left out of the
    output.
    """
    from collections import defaultdict
    # get the category of each task
    task_cat = defaultdict(set)
    with open("./data/howto100m/task_ids.csv", "r", encoding="utf-8") as f:
        id2task = {}
        task2id = {}
        for row in f:
            task_id, task_name = row.strip().split("\t")[:2][0], row.strip().split("\t")[1]
            id2task[task_id] = task_name
            task2id[task_name] = task_id
        print(len(id2task), len(task2id))

    with open("./data/howto100m/HowTo100M_v1.csv", "r", encoding="utf-8") as f:
        f.readline()  # skip the first line
        for row in f:
            cols = row.strip().split(",")
            categories = task_cat[id2task[cols[-1]]]
            categories.add(cols[1])
            categories.add(cols[2])

    with open("./data/howto100m/task_category.map", "w+", encoding="utf-8") as f:
        print(len(task_cat))
        for task_name, categories in task_cat.items():
            non_blank = [cat for cat in categories if cat.strip() != '']
            f.write(f"{task_name} || {' || '.join(non_blank)}\n")

def func91():
    """Flatten goal2steps_full.p into step_goal.json and split it.

    Every (goal, method) pair becomes one record with the task name (a
    leading "How to " removed), the category, the method index and name, and
    the method's steps with brace-enclosed link/image markup stripped. Goals
    without any method are printed. After writing, `func10` creates the
    train/dev split of the saved file.
    """
    import pickle
    import json
    import re
    with open("./data/wikihow/goal2steps_full.p", "rb") as f:
        d = pickle.load(f)
    all_d = []

    tasks = set()
    for goal, info in d.items():
        if not isinstance(goal, str):
            continue
        assert len(info.keys()) == 2, info.keys()
        task = goal[7:] if goal.startswith('How to ') else goal
        has_method = False
        for method_idx, (method, steps) in enumerate(info['methods'].items()):
            has_method = True
            tasks.add(task)
            cleaned = [re.sub(r'\{.*?(http|png|jpg|www)+.*?\}', '', step) for step in steps]
            all_d.append({
                'task': task,
                'category': info['category'],
                'method_idx': method_idx,
                'method': method,
                'caption': cleaned,
            })
        if not has_method:
            print(task)
    fsave = "./data/wikihow/step_goal.json"
    with open(fsave, "w", encoding="utf-8") as f:
        print(f"{len(d)}, {len(all_d)}")
        json.dump(all_d, f, indent=2)

    func10(fsave)

def func10(fpath, dev_size=5000):
    """Shuffle a JSON list and write train/dev splits next to the input.

    Args:
        fpath: Path of a JSON file holding a list.
        dev_size: Number of shuffled items placed in the dev split (default
            5000). If the list has fewer items, all of them go to dev and
            train is empty.

    For ``name.json`` the outputs are ``name.train.json`` and
    ``name.dev.json``. Only the file extension is touched, so a ".json"
    elsewhere in the path is left alone, and a path without an extension no
    longer overwrites the input file.
    """
    import json
    import os
    import random
    with open(fpath, "r", encoding="utf-8") as f:
        d = json.load(f)
    random.shuffle(d)
    root, ext = os.path.splitext(fpath)
    # Clamp at 0 so an oversized dev_size yields an empty train split.
    split_at = max(len(d) - dev_size, 0)
    for tag, part in (("train", d[:split_at]), ("dev", d[split_at:])):
        with open(f"{root}.{tag}{ext}", "w", encoding="utf-8") as f:
            json.dump(part, f, indent=2)

def func92():
    """Build the goal/step graph and pickle it to wikihow_graph_v2.pkl.

    Goals and steps from step_goal.json become nodes joined by 'goal-step'
    edges carrying the method and step indices. Paraphrase links from
    para_step_goal_links.json are added as 'step-goal-paraphrase' edges when
    both endpoints already exist in the graph; the rest are counted as
    skipped.
    """
    import networkx as nx
    import json
    from tqdm import tqdm
    import pickle
    import re
    G = nx.Graph()

    with open("./data/wikihow/step_goal.json", "r") as f:
        data = json.load(f)

    for record in tqdm(data, desc='main'):
        goal_name = record["task"]
        method_idx = record["method_idx"]
        G.add_node(goal_name, type='goal')
        for step_idx, step_text in enumerate(record["caption"]):
            G.add_node(step_text, type="step")
            G.add_edge(goal_name, step_text, type='goal-step', method_idx=method_idx, step_idx=step_idx)

    skip = 0
    add = 0
    with open('./data/wikihow/para_step_goal_links.json') as f:
        jobj = json.load(f)
        for step, things in tqdm(jobj.items(), desc='para'):
            if not things["retrieved_goals"]:
                continue
            # Steps unknown to the graph get the same markup stripping that
            # was applied when step_goal.json was produced.
            if step not in G.nodes:
                step = re.sub(r'\{.*?(http|png|jpg|www)+.*?\}', '', step)
            linked = zip(things["retrieved_goals"], things["retrieved_goals_similarity"])
            for target_goal, similarity in linked:
                if target_goal in G.nodes and step in G.nodes:
                    G.add_edge(step, target_goal, type="step-goal-paraphrase", similarity=similarity)
                    add += 1
                else:
                    skip += 1

    print(f"skip {skip} nodes b.c. mismatch of the steps, add {add}")

    with open('./data/wikihow/wikihow_graph_v2.pkl', 'wb') as fw:
        pickle.dump(G, fw)


def func11():
    """Report how many wikiHow tasks also appear in the HowTo100M category map.

    Prints every wikiHow task (with "How to " removed) that is missing from
    howto100m/task_category.map, then the hit count, the total and the
    ratio.
    """
    # select cat for each task
    import json
    from collections import defaultdict
    with open("./data/wikihow/step_goal.json", "r", encoding="utf-8") as f:
        d = json.load(f)
    task2cat = defaultdict(list)
    cat2task = defaultdict(list)
    with open("./data/howto100m/task_category.map", "r", encoding="utf-8") as f:
        for row in f:
            task, *cats = row.strip().split(" || ")
            task2cat[task] = cats
            for one_cat in cats:
                cat2task[one_cat].append(task)

    k_list = {record["task"] for record in d}
    tot = len(k_list)
    hit = 0
    for name in k_list:
        key = name.replace("How to ", "").strip()
        if key not in task2cat:
            print(key)
            continue
        hit += 1
    print(hit, tot, hit/tot)

def func12():
    """Write wikihow/task_category.map from the categories in step_goal.json.

    Each output line is "task || cat || cat ...", using the set of all
    categories seen for that task.
    """
    import json
    from collections import defaultdict
    with open("./data/wikihow/step_goal.json", "r", encoding="utf-8") as f:
        records = json.load(f)
    task_map = defaultdict(set)
    with open("./data/wikihow/task_category.map", "w+", encoding="utf-8") as f:
        for record in records:
            task = record["task"]
            # Added one at a time so a task without categories creates no entry.
            for category in record["category"]:
                task_map[task].add(category)
        for task, categories in task_map.items():
            assert len(categories) != 0
            f.write(f"{task} || {' || '.join(categories)}\n")

def func13():
    """Expand each sample's steps through paraphrase links in the wikiHow graph.

    For every record of step_goal.json whose task is a 'goal' node, each step
    is replaced by the result of `find_final_steps`. The original list is
    kept under "original_caption", the expanded list is stored under
    "caption", and the result is written to step_goal.para.json and then
    split by `func10`. Expansion draws from the `random` module, so the
    output differs between runs unless the RNG is seeded by the caller.
    """
    # get steps from the para
    import networkx
    import json
    import pickle
    import random
    import numpy as np
    from tqdm import tqdm

    def get_steps(G, goal, method_idx):
        """Return the steps of `goal` for one method, ordered by 'step_idx'."""
        steps = [neighbor for neighbor in G.adj[goal] if G.edges[goal, neighbor]['type'] == f'goal-step' and G.edges[goal, neighbor]['method_idx'] == method_idx]
        # sort the steps by order
        steps = sorted(steps, key=lambda x: G.edges[goal, x]['step_idx'])
        return steps

    def get_para_goal(G, step):
        """Return the neighbours of `step` joined by a 'step-goal-paraphrase' edge."""
        goal = [neighbor for neighbor in G.adj[step] if G.edges[step, neighbor]['type'] == 'step-goal-paraphrase']
        return goal

    def find_final_steps(step, G, remain):
        """Return the list of steps that replaces `step`.

        `[step]` is returned unchanged when the step has no paraphrase
        neighbour, when the sampled neighbour is not a 'goal' node, or when
        that goal has no method. Otherwise one paraphrased goal is sampled
        with probability proportional to the edge similarity, one of its
        methods is chosen uniformly, and that method's steps are returned.
        While `remain` > 0 each of those steps is expanded recursively.
        """
        # find para
        para = get_para_goal(G, step)

        # no para
        if len(para) == 0:
            res = [step]
            mode = "[not find]"
        else:
            sim = [float(G.edges[step, x]['similarity']) for x in para]
            # normalise the similarities so they sum to 1
            sim = [x / sum(sim) for x in sim]
            # sample with similarity
            para_goal = random.choices(para, weights=sim, k=1)[0]
            # only expand to this level
            if G.nodes[para_goal].get('type', '') != 'goal': # error cases, return the step directly
                res = [step]
                mode = "[error node type]"
            else:
                goal_method_num = find_method_num(G, para_goal)
                if goal_method_num <= 0:
                    res = [step]
                    mode = "[no method]"
                else:
                    # sample a method
                    method_idx = random.choice(list(range(goal_method_num)))
                    goal_steps = get_steps(G, para_goal, method_idx)
                    if remain == 0:
                        # recursion budget used up: take the steps as they are
                        res = goal_steps
                        mode = f"[max depth of {para_goal}]"
                    else:
                        extend_steps = []
                        for cur_step in goal_steps:
                            extend_steps += find_final_steps(cur_step, G, remain=remain-1)
                        res = extend_steps
                        mode = f"[extend {para_goal}]"
        # `mode` is only consumed by the debugging print below
        # print(remain, mode, step, res)
        return res

    def find_method_num(G, goal):
        """Return the largest 'method_idx' on the goal's 'goal-step' edges, plus one.

        The result is -99 when the goal has no 'goal-step' edge; callers
        treat any value <= 0 as "no method".
        """
        max_midx = -100
        for neighbor in G.adj[goal]:
            if G.edges[goal, neighbor]['type'] == 'goal-step':
                max_midx = max(max_midx, G.edges[goal, neighbor]['method_idx'])
        return max_midx + 1

    def get_goal(G, goal):
        """Return `goal` after asserting it is a node of `G` (not called in this function)."""
        assert goal in G.nodes
        return goal

    # recursion budget handed to find_final_steps as `remain`
    max_depth = 2

    # NOTE(review): func92 in this file writes wikihow_graph_v2.pkl, while this
    # reads wikihow_graph.pkl — confirm which graph file is intended.
    with open("./data/wikihow/wikihow_graph.pkl", "rb") as f:
        G = pickle.load(f)

    # get all goals
    with open("./data/wikihow/step_goal.json", "r") as f:
        step_goal = json.load(f)

    # number of samples whose caption list changed
    update = 0
    for pidx, sg in enumerate(tqdm(step_goal, disable=False)):
        goal = sg["task"]
        # skip tasks whose node is not typed as a goal
        if G.nodes[goal]['type'] != 'goal':
            continue

        # traverse the tree
        final_steps = []
        for step in sg["caption"]:
            cur_final_steps = find_final_steps(step, G, remain=max_depth)
            final_steps +=  cur_final_steps

        sg["original_caption"] = sg["caption"]
        sg["caption"] = final_steps
        if sg['caption'] != sg['original_caption']:
            update += 1

        # if pidx == 4000:
        #     break

    print(f"{update}/{len(step_goal)} samples have newer captions")
    fsave = "./data/wikihow/step_goal.para.json"
    with open(fsave, "w+") as f:
        json.dump(step_goal, f, indent=2)
    func10(fsave)



def func14():
    """Count captions longer than 300 RoBERTa tokens.

    The " || " separators are replaced by spaces before tokenising. Each
    offending task is printed with its token count, then the total.
    """
    import json
    from transformers import RobertaTokenizer
    tkz = RobertaTokenizer.from_pretrained('roberta-base')
    with open("./data/wikihow/step_goal.json", "r") as f:
        records = json.load(f)
    too_long = 0
    for record in records:
        flat_caption = ' '.join(record['caption'].split(" || "))
        encoded = tkz(flat_caption)
        if len(encoded['input_ids']) <= 300:
            continue
        print(record['task'], len(encoded['input_ids']))
        too_long += 1
    print(too_long)

def func15():
    """Count captions longer than 200 whitespace-separated words.

    The " || " separators are replaced by spaces before counting. Each
    offending task is printed, then the total.
    """
    import json
    with open("./data/wikihow/step_goal.json", "r") as f:
        records = json.load(f)
    too_long = 0
    for record in records:
        words = record["caption"].replace(" || ", " ").split()
        if len(words) <= 200:
            continue
        print(record['task'])
        too_long += 1
    print(too_long)

def func16():
    """Print how often the gold label is among the predictions in a result log.

    The first two lines of the log are skipped. On every remaining
    tab-separated line, column 1 is the gold label and column 2 is a Python
    literal holding the predictions. Hits are printed, then the hit count,
    the total and the ratio.
    """
    import ast
    with open("./data/result/roberta.wikihow.aug.ep5.log", "r", encoding="utf-8") as f:
        for _ in range(2):
            f.readline()
        tot = 0
        acc = 0
        for row in f:
            cols = row.strip().split("\t")
            gold = cols[1]
            predictions = ast.literal_eval(cols[2])
            tot += 1
            if gold not in predictions:
                continue
            assert gold.strip()
            acc += 1
            print(gold, predictions)
    print(acc, tot, acc/tot)

def func17():
    """Attach the raw captions to the sampled all_in_one records.

    Captions for the sampled ids are collected from raw_caption.jsonl and
    saved to sampled/raw_caption.json. A copy of all_in_one.json with
    "caption" replaced by the raw caption is saved to
    sampled/all_in_one_raw_caption.json.
    """
    import json
    with open("./data/howto100m/sampled/all_in_one.json", "r") as f:
        aio = json.load(f)
        sampled_ids = {entry["id"] for entry in aio}
    captions = {}
    with open("./data/howto100m/raw_caption.jsonl", "r") as f:
        for record in map(json.loads, f):
            vid = record["id"]
            if vid in sampled_ids:
                captions[vid] = record["text"]

    with open("./data/howto100m/sampled/raw_caption.json", "w+") as f:
        json.dump(captions, f)

    for entry in aio:
        entry["caption"] = captions[entry["id"]]
    with open("./data/howto100m/sampled/all_in_one_raw_caption.json", "w+") as f:
        json.dump(aio, f, indent=2)

def func18():
    """Print recall@1/3/5/10 for a result log.

    The first two lines are skipped and lines without exactly four
    tab-separated columns are printed and ignored. Only captions with at
    most `max_len` whitespace-separated words are scored. A gold label found
    at 0-based position p counts towards recall@k when k >= p + 1.
    """
    import ast
    # eval captions <= max_len
    max_len = 100000
    recall = {1: 0, 3: 0, 5: 0, 10: 0}
    with open("./data/result/roberta.wikihow.aug.ep5.raw_cap.cut_head.log", "r") as f:
        for _ in range(2):
            f.readline()
        tot = 0
        for row in f:
            cols = row.strip().split("\t")
            if len(cols) != 4:
                print(row)
                continue
            cap, gold, raw_pred = cols[:3]
            predictions = ast.literal_eval(raw_pred)
            if len(cap.strip().split()) > max_len:
                continue
            tot += 1
            # 10000 acts as "not found": it is larger than every cut-off.
            position = predictions.index(gold) if gold in predictions else 10000
            for cutoff in recall:
                recall[cutoff] += int(cutoff >= position + 1)
        report = {cutoff: f"{hits}/{tot}={hits / tot}" for cutoff, hits in recall.items()}
        print(report)

def func19():
    """Count correct predictions per caption-length bucket of 100 words.

    The first two lines are skipped and lines without exactly four
    tab-separated columns are printed and ignored. A row is counted in
    bucket ``words // 100`` when its gold label is among the predictions.
    The resulting Counter is printed.
    """
    import matplotlib.pyplot as plt
    import ast
    from collections import Counter
    import numpy as np
    bucket = Counter()
    with open("./data/result/roberta.wikihow.aug.ep5.raw_cap.cut_head.log", "r") as f:
        for _ in range(2):
            f.readline()
        for row in f:
            cols = row.strip().split("\t")
            if len(cols) != 4:
                print(row)
                continue
            cap, gold, raw_pred = cols[:3]
            predictions = ast.literal_eval(raw_pred)
            n_words = len(cap.strip().split())
            if gold in predictions:
                bucket[n_words // 100] += 1

    print(bucket)

def func20():
    """Dump cases the cut-head run got right and the uncut run got wrong.

    Both logs are read in lockstep after skipping their first two lines.
    Rows of the cut-head log without exactly four tab-separated columns are
    ignored. For each selected pair a " ||| "-separated line is written to
    tmp.log: the first 50 caption words, plus words 512..561 for captions
    longer than 512 words, followed by the gold label and both predictions.

    Raises:
        ValueError: if a selected caption has fewer than 50 words.
    """
    # compare output
    # Context managers close all three files even when ValueError is raised.
    with open("./data/result/roberta.wikihow.aug.ep5.raw_cap.cut_head.log", "r") as f1, \
            open("./data/result/roberta.wikihow.aug.ep5.raw_cap.log", "r") as f2, \
            open("./data/result/tmp.log", "w+") as f3:
        for log in (f1, f2):
            log.readline()
            log.readline()
        for l1, l2 in zip(f1, f2):
            tks1 = l1.strip().split("\t")
            tks2 = l2.strip().split("\t")
            if len(tks1) != 4:
                continue
            if not (tks1[-1] == 'True' and tks2[-1] == 'False'):
                continue
            cap_tks = tks2[0].strip().split()
            if len(cap_tks) < 50:
                raise ValueError(f"caption has only {len(cap_tks)} words, expected at least 50")
            head = ' '.join(cap_tks[:50])
            # Inclusive bounds: captions of exactly 50 or 512 words used to
            # match no branch and were dropped without a trace.
            if len(cap_tks) <= 512:
                f3.write(f"< 512 ||| {head} ||| {tks1[1]} ||| {tks1[2]} ||| {tks2[2]}\n")
            else:
                f3.write(f"> 512 ||| {head} ||| {' '.join(cap_tks[512: 512 + 50])} ||| {tks1[1]} ||| {tks1[2]} ||| {tks2[2]}\n")

def func21():
    """Punctuate the sampled raw captions in 12 worker processes.

    Every key of sampled/raw_caption.json is queued as a task. Workers run
    `correct` on the corresponding text and store the result in a shared
    dict. The results are sentence-tokenised and written to
    sampled/raw_caption.tok.json as a list of ``{"id", "caption"}`` records.
    """
    from torch import multiprocessing as mp
    from external.autopunct.correct import correct
    import queue
    import spacy
    import time
    import json
    from nltk.tokenize import sent_tokenize
    punct_model = spacy.load("./external/autopunct/punct-model")
    caps_model = spacy.load("./external/autopunct/caps-model")

    def punct(punct_res):
        """Worker loop: process tasks until the queue is empty."""
        while True:
            # Checking qsize() and then calling a blocking get() is a race:
            # another worker can take the last task in between and this one
            # would wait forever. A non-blocking get avoids that.
            try:
                task = task_queue.get_nowait()
            except queue.Empty:
                break
            text = i_list[task]
            punct_text = correct(text, punct_model, caps_model)
            punct_res[task] = punct_text
            print(f"there are {task_queue.qsize()} left")

    with open("./data/howto100m/sampled/raw_caption.json", "r") as f:
        i_list = json.load(f)
    k_list = list(i_list.keys())
    manager = mp.Manager()
    task_queue = manager.Queue()
    for k in k_list:
        task_queue.put(k)

    punct_res = manager.dict()
    procs = []
    for ps in range(12):
        proc = mp.Process(target=punct, args=(punct_res,))
        proc.start()
        procs.append(proc)
        time.sleep(1)

    for proc in procs:
        proc.join()

    with open("./data/howto100m/sampled/raw_caption.tok.json", "w+") as f:
        print(f"{len(punct_res)} files")
        dd = [{'id': k, "caption": sent_tokenize(v)} for k, v in punct_res.items()]
        json.dump(dd, f, indent=2)

def func22():
    """Sentence-tokenise the punctuated captions.

    Reads the id -> text mapping in sampled/raw_caption.punct.json and
    writes a list of ``{"id", "caption"}`` records, where "caption" is the
    list returned by `sent_tokenize`, to sampled/raw_caption.punct.tok.json.
    """
    import json
    from nltk.tokenize import sent_tokenize

    with open("./data/howto100m/sampled/raw_caption.punct.json", "r") as f:
        punctuated = json.load(f)

    records = [{'id': vid, "caption": sent_tokenize(text)} for vid, text in punctuated.items()]

    with open("./data/howto100m/sampled/raw_caption.punct.tok.json", "w+") as f:
        json.dump(records, f, indent=2)

def func23():
    """Swap in the tokenised raw captions for the all_in_one records.

    Records whose id is present in sampled/raw_caption.tok.json get that
    entry as their "caption"; the others are dropped. The kept records are
    counted and written to sampled/all_in_one_raw.json.
    """
    import json
    with open("./data/howto100m/sampled/raw_caption.tok.json", "r") as f:
        tokenised = json.load(f)

    with open("./data/howto100m/sampled/all_in_one.json", "r") as f:
        records = json.load(f)

    kept = []
    for record in records:
        if record["id"] not in tokenised:
            continue
        record["caption"] = tokenised[record["id"]]
        kept.append(record)

    with open("./data/howto100m/sampled/all_in_one_raw.json", "w+") as f:
        print(f"{len(kept)} items")
        json.dump(kept, f, indent=2)

def func231():
    """Swap in the filtered captions for the all_in_one records.

    Records whose id is present in the pickled sampled/caption_filtered.p
    mapping get that entry as their "caption"; the others are dropped. The
    kept records are counted and written to
    sampled/all_in_one_filtered.json.
    """
    import json
    import pickle
    with open("./data/howto100m/sampled/caption_filtered.p", "rb") as f:
        filtered = pickle.load(f)

    with open("./data/howto100m/sampled/all_in_one.json", "r") as f:
        records = json.load(f)

    kept = []
    for record in records:
        if record["id"] not in filtered:
            continue
        record["caption"] = filtered[record["id"]]
        kept.append(record)

    with open("./data/howto100m/sampled/all_in_one_filtered.json", "w+") as f:
        print(f"{len(kept)} items")
        json.dump(kept, f, indent=2)

def func24():
    """Print statistics on filtered captions of 600 words or more.

    Prints the number of such captions, the smallest and largest word count
    among them, and the Counter mapping word count to frequency.
    """
    # measure the length
    import json
    from collections import Counter
    with open("./data/howto100m/sampled/all_in_one_filtered.json", "r") as f:
        records = json.load(f)
    tot = 0
    length_hist = Counter()
    for record in records:
        # The caption is a sequence of strings; count words over the joined text.
        n_words = len(" ".join(record["caption"]).strip().split())
        if n_words < 600:
            continue
        tot += 1
        length_hist[n_words] += 1
    print(tot)
    print(min(length_hist.keys()))
    print(max(length_hist.keys()))
    print(length_hist)

def func25():
    """Download YouTube transcripts for the sampled videos with 16 processes.

    The video id is the first comma-separated field of each line in
    sampled/high_rank.csv. Transcripts that download successfully are stored
    by id and written to sampled/raw_caption.json. Failures are only
    counted, per worker.
    """
    # import subprocess
    import os
    import queue
    import time
    import json
    from youtube_transcript_api import YouTubeTranscriptApi
    import torch.multiprocessing as mp

    def download(d):
        """Worker loop: fetch transcripts until the queue is empty."""
        fail = 0
        while True:
            # Checking qsize() and then calling a blocking get() is a race:
            # another worker can take the last task in between and this one
            # would wait forever. A non-blocking get avoids that.
            try:
                task = task_queue.get_nowait()
            except queue.Empty:
                break
            try:
                cp = YouTubeTranscriptApi.get_transcript(task)
                d[task] = cp
            except Exception:
                # Best effort: skip videos that fail, but let
                # KeyboardInterrupt/SystemExit propagate.
                fail += 1
            print(f"there are {task_queue.qsize()} left, {fail} errors already")

        print(fail)

    manager = mp.Manager()
    task_queue = manager.Queue()
    with open("./data/howto100m/sampled/high_rank.csv", "r") as f:
        for line in f:
            tks = line.strip().split(",")
            task_queue.put(tks[0])

    res = manager.dict()

    ps = []
    for i in range(16):
        p = mp.Process(target=download, args=(res,))
        p.start()
        time.sleep(1)
        ps.append(p)

    for p in ps:
        p.join()

    with open("./data/howto100m/sampled/raw_caption.json", "w+") as f:
        # Copy the manager proxy into a plain dict so json can serialise it.
        res = res.copy()
        json.dump(res, f, indent=2)

def func26():
    """Punctuate and sentence-tokenise the downloaded transcripts in 12 processes.

    Each transcript in sampled/raw_caption.json is a list of segments. Their
    "text" fields are joined with spaces, passed through `correct` and split
    with `sent_tokenize`. The id -> sentence-list mapping is written to
    sampled/raw_caption.tok.json.
    """
    import torch.multiprocessing as mp
    from external.autopunct.correct import correct
    import queue
    import spacy
    import time
    import json
    from nltk.tokenize import sent_tokenize
    punct_model = spacy.load("./external/autopunct/punct-model")
    caps_model = spacy.load("./external/autopunct/caps-model")

    manager = mp.Manager()
    task_queue = manager.Queue()
    tk_res = manager.dict()

    with open("./data/howto100m/sampled/raw_caption.json", "r") as f:
        d = json.load(f)
        for k, v in d.items():
            task_queue.put({"id": k, "caption": v})

    def process_caption(d):
        """Worker loop: process transcripts until the queue is empty."""
        while True:
            # Checking qsize() and then calling a blocking get() is a race:
            # another worker can take the last task in between and this one
            # would wait forever. A non-blocking get avoids that.
            try:
                task = task_queue.get_nowait()
            except queue.Empty:
                break
            caption = task["caption"]
            caption = " ".join([x["text"] for x in caption])
            caption = correct(caption, punct_model=punct_model, caps_model=caps_model)
            tk_caption = sent_tokenize(caption)
            d[task["id"]] = tk_caption
            print(f"there are {task_queue.qsize()} left")

    ps = []
    for i in range(12):
        p = mp.Process(target=process_caption, args=(tk_res,))
        p.start()
        time.sleep(1)
        ps.append(p)

    for p in ps:
        p.join()

    with open("./data/howto100m/sampled/raw_caption.tok.json", "w+") as f:
        # Copy the manager proxy into a plain dict so json can serialise it.
        tk_res = tk_res.copy()
        print(f"there are {len(tk_res)} captions")
        json.dump(tk_res, f, indent=2)

def func27():
    """Collect the Kinetics-400 training labels and check the validation labels.

    Asserts that every validation label also occurs in the training set,
    then prints the number of distinct labels and the label set.
    """
    import json
    with open("./data/kinetics400/train.json", "r") as f:
        train = json.load(f)
        actions = {clip["annotations"]["label"] for clip in train.values()}

    with open("./data/kinetics400/validate.json", "r") as f:
        validate = json.load(f)
        for clip in validate.values():
            assert clip["annotations"]["label"] in actions
    print(len(actions))
    print(actions)


def func28():
    """Merge the original and paraphrase-expanded samples and split them.

    Only expanded samples whose caption differs from the original are added.
    The merged list is shuffled; the last 6000 items form the dev split and
    the rest the train split.
    """
    import json
    import random
    # combine data
    with open("./data/wikihow/step_goal.json", "r") as f:
        originals = json.load(f)

    with open("./data/wikihow/step_goal.para.json", "r") as f:
        expanded = json.load(f)

    changed = [sample for sample in expanded if sample["caption"] != sample["original_caption"]]
    print(f"{len(originals)}, {len(changed)}")

    combined = originals + changed
    random.shuffle(combined)

    print(len(combined))

    train_part, dev_part = combined[:-6000], combined[-6000:]
    with open("./data/wikihow/step_goal.comb.train.json", "w", encoding="utf-8") as f:
        json.dump(train_part, f, indent=2)
    with open("./data/wikihow/step_goal.comb.dev.json", "w", encoding="utf-8") as f:
        json.dump(dev_part, f, indent=2)


def func29():
    """Print a win/loss table for two result logs compared line by line.

    The first two lines of each log are skipped. The last tab-separated
    column of every remaining line is the 'True'/'False' outcome. The four
    printed counts are, in order: both correct, only the first correct, only
    the second correct, both wrong.
    """
    ww, wl, lw, ll = 0, 0, 0, 0
    # The original body had a stray statement here that referenced the
    # undefined names `G` and `step`, raising NameError on every call. The
    # logs are now also closed by the context manager.
    with open("./data/result/roberta.wikihow.aug.ep5.raw_cap.cut_head.log", "r") as f1, \
            open("./data/result/roberta.wikihow.ep5.raw_cap.cut_head.log", "r") as f2:
        for log in (f1, f2):
            log.readline()
            log.readline()
        for l1, l2 in zip(f1, f2):
            r1 = l1.strip().split("\t")[-1]
            r2 = l2.strip().split("\t")[-1]
            if r1 == 'True' and r2 == 'True':
                ww += 1
            elif r1 == 'True' and r2 == 'False':
                wl += 1
            elif r1 == 'False' and r2 == 'True':
                lw += 1
            elif r1 == 'False' and r2 == 'False':
                ll += 1
    print(ww, wl, lw, ll)

def func30():
    """Print length statistics for the combined training set.

    Prints the longest task length and longest caption length in
    whitespace-separated words, and the number of captions longer than 520
    words.
    """
    import json
    with open("./data/wikihow/step_goal.comb.train.json", "r") as f:
        samples = json.load(f)
    longest_task = 0
    longest_caption = 0
    over_limit = 0
    for sample in samples:
        caption_words = len(" ".join(sample["caption"]).strip().split())
        task_words = len(sample["task"].strip().split())
        if task_words > longest_task:
            longest_task = task_words
        if caption_words > longest_caption:
            longest_caption = caption_words
        if caption_words > 520:
            over_limit += 1
    print(longest_task, longest_caption, over_limit)

def func31():
    """Measure how many HowTo100M tasks are also wikiHow tasks.

    Prints the hit count, the number of HowTo100M tasks and their ratio.

    Returns:
        A tuple (wk_tasks, ht_tasks) of the two task-name sets.
    """
    # calc task overlap between wikihow training and howto100m
    wk_tasks = set()
    with open("./data/wikihow/task_category.map", "r") as f:
        for row in f:
            name = row.strip().split(" || ")[0]
            if name.startswith("How to "):
                name = name[7:]
            wk_tasks.add(name)

    with open("./data/howto100m/task_ids.csv", "r") as f:
        ht_tasks = {row.strip().split("\t")[1] for row in f}

    tot = len(ht_tasks)
    hit = sum(1 for name in ht_tasks if name in wk_tasks)
    print(hit, tot, hit/tot)

    return wk_tasks, ht_tasks


def func32():
    """Compare a baseline run with the augmented ("comb") run, sample by sample.

    The two filtered result logs are read in lockstep. When exactly one of
    them is correct and the task occurs in both training sets, a block of
    rows is written to wikihow_ep3_comb_ep2_filter.tsv: one summary row,
    then one row per training sample of that task with its original steps,
    its new steps, and the paraphrase goals retrieved for each original step
    when the two differ. Aggregate counters are printed at the end.
    """
    # compare
    # baseline win and our method fail.
    # The task is in the wikihow goal
    # and the steps are extended

    import ast
    import json
    from collections import defaultdict
    # find training data by the task
    with open("./data/wikihow/step_goal.train.json", "r") as f:
        db = json.load(f)
        # task name -> indices into db
        db_map = defaultdict(list)
        for idx, item in enumerate(db):
            db_map[item["task"]].append(idx)


    with open("./data/wikihow/step_goal.comb.train.json", "r") as f:
        da = json.load(f)
        # task name -> indices into da
        da_map = defaultdict(list)
        for idx, item in enumerate(da):
            da_map[item["task"]].append(idx)

    with open("./data/wikihow/para_step_goal_links.json", "r") as f:
        para_d = json.load(f)

    # fb: baseline log, fa: augmented log, fout: TSV report
    fb = open("./data/result/roberta.wikihow.ep3.filtered.log", "r")
    fa = open("./data/result/roberta.wikihow.aug.comb.ep2.filtered.log", "r")
    fout = open("./data/result/wikihow_ep3_comb_ep2_filter.tsv", "w+")
    # skip the first two lines of each log
    fb.readline()
    fb.readline()
    fa.readline()
    fa.readline()

    # bs_loss and our_loss are initialised but never updated below
    bs_win = bs_loss = our_win = our_loss = 0
    bs_in_train = our_in_train = 0
    change_good = change_bad = no_change_good = no_change_bad = 0

    fout.write("tag\tdiff num\ttask\thowto caption\twikihow only pred\tmix pred\toriginal steps\tnew steps\tsame or diff\n")
    for lb, la in zip(fb, fa):
        tkb = lb.strip().split("\t")
        tka = la.strip().split("\t")
        # stop at the first baseline line with fewer than four columns
        if len(tkb) < 4:
            break
        task = tkb[-3]
        howto_cap = tkb[0]
        b_pred = ast.literal_eval(tkb[-2])
        a_pred = ast.literal_eval(tka[-2])
        # tag 1: only the baseline is correct; tag 2: only the augmented run
        # is correct. It stays None unless the task is in both training sets.
        tag = None
        if tkb[-1] == 'True' and tka[-1] == 'False':
            bs_win += 1
            if task in db_map.keys() and task in da_map.keys():
                bs_in_train += 1
                tag = 1
        elif tkb[-1] == 'False' and tka[-1] == 'True':
            our_win += 1
            if task in da_map.keys() and task in db_map.keys():
                our_in_train += 1
                tag = 2

        if tag in [1, 2]:
            # find the caption
            items = [da[x] for x in da_map[task]]
            # samples without "original_caption" fall back to their caption
            b_cap = [x.get("original_caption", x["caption"]) for x in items]
            a_cap = [x["caption"] for x in items]
            # indices of samples whose steps were changed
            diff = [idx for idx, (x, y) in enumerate(zip(b_cap, a_cap)) if x != y]
            diff_num = f"{len(diff)}/{len(b_cap)}"
            fout.write(f"{tag}\t{diff_num}\t{task}\t{howto_cap}\t{b_pred}\t{a_pred}\n")
            for idx in range(len(b_cap)):
                # the six leading blank cells line up with the summary columns
                fout.write(f" \t \t \t \t \t \t{b_cap[idx]}\t{a_cap[idx]}\t{b_cap[idx] == a_cap[idx]}\t")
                if idx in diff:
                    for sent in b_cap[idx]:
                        cur_para = para_d[sent]["retrieved_goals"]
                        if len(cur_para) != 0:
                            fout.write(f"{sent}->{' | '.join(cur_para)}\t")
                fout.write("\n")

            if tag == 1:
                # expand and harm
                if len(diff) != 0:
                    change_bad += 1
                elif len(diff) == 0:
                    no_change_bad += 1
            if tag == 2:
                # expand and benefit
                if len(diff) != 0:
                    change_good += 1
                elif len(diff) == 0:
                    no_change_good += 1

    print(bs_win, our_win, bs_in_train, our_in_train, change_good, change_bad, no_change_good, no_change_bad)
    fa.close()
    fb.close()
    fout.close()

def func33():
    """Print the number of distinct tasks in step_goal.json."""
    import json
    with open("./data/wikihow/step_goal.json", "r") as f:
        records = json.load(f)
    distinct_tasks = {record["task"] for record in records}
    print(len(distinct_tasks))

def func34():
    """Collect the wikiHow samples that match the COIN goals.

    A leading "How to " is removed from each goal before lookup. Unmatched
    goals are printed as skipped. Prints the number of matched goals, the
    number of matched samples and how many of those have an expanded
    caption, then writes the goal -> samples mapping to coin_goal_step.json.
    """
    import pickle
    import json
    from collections import defaultdict

    with open("./data/coin/coin_goals.p", "rb") as f:
        goals = pickle.load(f)

    with open("./data/wikihow/step_goal.para.json", "r") as f:
        samples = json.load(f)

    # task name -> indices into samples
    task_index = defaultdict(list)
    for position, sample in enumerate(samples):
        task_index[sample["task"]].append(position)

    matched = {}
    diff = 0
    tot = 0
    for g in goals:
        g = g[7:] if g[:7] == "How to " else g
        if g not in task_index:
            print(f"skip {g}")
            continue
        info = [samples[position] for position in task_index[g]]
        matched[g] = info
        for sample in info:
            if sample['caption'] != sample['original_caption']:
                diff += 1
            tot += 1

    print(len(matched), tot, diff)
    with open("./data/coin/coin_goal_step.json", "w") as f:
        json.dump(matched, f, indent=2)

def func35():
    """Gather the induced steps of every video into one JSON file.

    Each ``*.p.gz`` file under induced_steps/ is a gzipped pickle of a
    mapping. For every value, element 3 must be a list, and the first
    component of each of its entries is kept. The video id is the file name
    up to its first dot.
    """
    import pickle
    import gzip
    import glob
    import os
    import json
    import tqdm

    smap = {}
    for fname in tqdm.tqdm(glob.glob("./data/howto100m/induced_steps/*.p.gz")):
        with gzip.open(fname, "rb") as f:
            video_id = os.path.basename(fname).split(".")[0]
            clips = pickle.load(f)
            smap[video_id] = []
            for clip in clips.values():
                assert isinstance(clip[3], list)
                smap[video_id].append([entry[0] for entry in clip[3]])

    with open("./data/howto100m/sampled/induced_steps.1k.json", "w") as f:
        json.dump(smap, f, indent=2)
    # match with all_in_one

def func36():
    """Replace captions by induced steps for the videos that have them.

    For each video, the first entry of every clip's step list is taken and
    duplicates are dropped, keeping first-seen order. Matching records of
    all_in_one_filtered.json get that list as "caption" and are written to
    all_in_one_1k_wk_step.json. A TSV with the video URL, the old caption
    and the new caption is written for comparison.
    """
    import json

    with open("./data/howto100m/sampled/induced_steps.1k.json", "r") as f:
        step_d = json.load(f)
        for vid, clips in step_d.items():
            top_steps = []
            for ranked in clips:
                best = ranked[0]
                if best not in top_steps:
                    top_steps.append(best)
            step_d[vid] = top_steps

    small_d = []
    comp = []
    with open("./data/howto100m/sampled/all_in_one_filtered.json", "r") as f:
        records = json.load(f)
        for record in records:
            vid = record["id"]
            if vid not in step_d:
                continue
            old_cap = record["caption"]
            record["caption"] = step_d[vid]
            small_d.append(record)
            comp.append(f"https://www.youtube.com/watch?v={vid}\t{' '.join(old_cap)}\t{' '.join(record['caption'])}")

    with open("./data/howto100m/sampled/all_in_one_1k_wk_step.json", "w+") as f:
        print(len(small_d))
        json.dump(small_d, f, indent=2)

    with open("./data/howto100m/sampled/caption_steps.compare.tsv", "w") as f:
        f.writelines(row + "\n" for row in comp)



if __name__ == "__main__":
    # Manual switchboard: uncomment the step to run.
    # func36()
    # `func9` is not defined in this file, so calling it raised NameError.
    # NOTE(review): func91 (build step_goal.json) and func92 (build the graph
    # from it) look like its two halves — confirm that running both is intended.
    func91()
    func92()
    # func13()
    # func10("./data/wikihow/step_goal.para.json")
    # func22()
    # func30()