import os
import sys
# NOTE: import-time side effect. Every data path used in this file is relative
# ("./data/..."), so the working directory is switched to the project checkout
# before anything else runs.
os.chdir("/home/[USER]/workshop/wikihow")
# Presumably needed so the project-local `ir_eval` module imported below can be
# found -- confirm.
sys.path.append("/home/[USER]/workshop/wikihow")
import json
import matplotlib.pyplot as plt
import numpy as np
from ir_eval import calc_average_precision, calc_mean_rank, calc_mrr, calc_mean_average_precision
from collections import defaultdict, Counter

def compare_win_loss():
    """Print a per-gold win/loss comparison between two IR result files.

    Both files are JSON objects mapping a query to a record whose ``'gold'``
    entry maps each gold item to its rank.  ``[1]`` lines are golds for which
    the compared run has the strictly lower rank, ``[2]`` lines are golds for
    which the baseline has the strictly lower rank; ties print nothing.  The
    last line printed is ``<baseline wins> <compared wins>``.
    """
    baseline_path = "./data/howto100m/step_goal.para.d0.sample_base.howto1k.all.ir.results"
    compared_path = "./data/howto100m/step_goal.para.d1.sample_base.sample_expansion.rerank.train_null00.50.howto1k.all.ir.results"
    with open(baseline_path, "r") as f:
        baseline = json.load(f)
    with open(compared_path, "r") as f:
        compared = json.load(f)

    tally = {"baseline": 0, "compared": 0}
    for query, entry in baseline.items():
        for gold, base_rank in entry['gold'].items():
            if query not in compared:
                # Reported once for every gold of the missing query.
                print(f"error {query} not found")
                continue
            other_rank = compared[query]['gold'][gold]
            if other_rank < base_rank:
                print("[1]", query, base_rank, other_rank, gold)
                tally["compared"] += 1
            elif base_rank < other_rank:
                tally["baseline"] += 1
                print("[2]", query, base_rank, other_rank, gold)

    print(tally["baseline"], tally["compared"])

def draw_video_leng():
    """Plot length histograms of the gold videos the expanded run wins or loses.

    A gold video counts as a "win" when its rank in the expanded result file
    is strictly lower than in the base result file and as a "loss" when it is
    strictly higher; ties are ignored.  The lengths of both groups (looked up
    in ``video_length.json``) are binned with 60-wide bins whose edges run
    from 0 to 1980, and drawn as two overlaid, semi-transparent bar charts.
    Bars are positioned by bin index, not by length.
    """
    folder = "./data/howto100m/video1k/t150_resplit"
    base_result = f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_dev.0.6.None.json"
    exp_result = f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_dev.0.5.None.json"
    with open(base_result, "r") as f:
        base_runs = json.load(f)
    with open(exp_result, "r") as f:
        exp_runs = json.load(f)

    improved = []
    worsened = []
    for goal, rank_info in base_runs.items():
        for vid, base_rank in rank_info['gold'].items():
            exp_rank = exp_runs[goal]['gold'][vid]
            if exp_rank < base_rank:
                improved.append(vid)
            elif exp_rank > base_rank:
                worsened.append(vid)
    print(f"find win: {len(improved)}, loss: {len(worsened)}")

    with open("./data/howto100m/video_length.json", "r") as f:
        length_of = json.load(f)
    print(max(length_of.values()))

    win_len = [length_of[vid] for vid in improved]
    loss_len = [length_of[vid] for vid in worsened]

    interval = 60
    bins = np.arange(0, 2000, interval)
    win_hist, _ = np.histogram(win_len, bins=bins)
    loss_hist, _ = np.histogram(loss_len, bins=bins)
    print(win_hist)
    print(loss_hist)

    fig, ax = plt.subplots(nrows=1, ncols=1)
    positions = np.arange(bins.shape[0] - 1)
    for counts, colour, name in ((win_hist, 'b', 'win'), (loss_hist, 'r', 'loss')):
        ax.bar(positions, counts, alpha=0.5, color=colour, label=name)
    plt.legend()
    plt.show()

def print_top1_win():
    """List long videos that enter or leave the top 100 between two runs.

    Videos whose length is at most ``min_len`` are treated as if they did not
    exist: every remaining gold rank is reduced by the number of such short
    golds ranked strictly above it (a rank of 101 is left untouched).  For
    each remaining gold a ``[win]`` line is printed when only the second run
    places it within the top 100 and a ``[loss]`` line when only the first
    run does, followed by a summary and the sets of affected goals.
    """
    with open("./data/howto100m/video_length.json", "r") as f:
        video_len_map = json.load(f)
    base_result = "./data/howto100m/video1k/ir_results/step_goal.para.d1.sample_base.sample_expansion.rerank.train_null00.50.howto1k.all1.0.0.1.0.0.0.ir.results"
    exp_result = "./data/howto100m/video1k/ir_results/step_goal.para.d1.sample_base.sample_expansion.rerank.train_null00.50.howto1k.all1.0.0.1.0.1.0.ir.results"
    with open(base_result, "r") as f:
        d_base = json.load(f)
    with open(exp_result, "r") as f:
        d_exp = json.load(f)

    print(len(d_base), len(d_exp))
    tar_rank = 100
    untouched_rank = 101  # ranks equal to this value are never adjusted

    def drop_short(rank, short_ranks):
        # Rank the video would have if the short golds above it were removed.
        if rank == untouched_rank:
            return rank
        return rank - sum(1 for other in short_ranks if other < rank)

    for min_len in [600]:
        win_vid, loss_vid = [], []
        win_goal, loss_goal = set(), set()
        tot = 0
        for goal, base_rank_info in d_base.items():
            exp_rank_info = d_exp[goal]
            base_gold = base_rank_info['gold']
            exp_gold = exp_rank_info['gold']
            base_short = [rank for vid, rank in base_gold.items() if video_len_map[vid] <= min_len]
            exp_short = [rank for vid, rank in exp_gold.items() if video_len_map[vid] <= min_len]
            assert len(base_short) == len(exp_short)
            for gold, raw_base_rank in base_gold.items():
                raw_exp_rank = exp_gold[gold]
                if video_len_map[gold] <= min_len:
                    continue
                base_rank = drop_short(raw_base_rank, base_short)
                exp_rank = drop_short(raw_exp_rank, exp_short)
                tot += 1
                base_in = base_rank <= tar_rank
                exp_in = exp_rank <= tar_rank
                if exp_in and not base_in:
                    win_vid.append(gold)
                    win_goal.add(goal)
                    print(f"[win] {goal} https://www.youtube.com/watch?v={gold} {base_rank} {exp_rank}")
                if base_in and not exp_in:
                    loss_vid.append(gold)
                    loss_goal.add(goal)
                    print(f"[loss] {goal} https://www.youtube.com/watch?v={gold} {base_rank} {exp_rank}")
        print(f"min: {min_len}, total: {tot}, find win: {len(win_vid)}/{len(win_goal)}, "
              f"loss: {len(loss_vid)}/{len(loss_goal)}")

        print(win_goal)
        print(loss_goal)

def high_rank():
    """Count gold videos that cross the top-50 boundary between two runs.

    Prints ``<win> <loss>``: a win is a gold ranked beyond 50 by the base run
    but within 50 by the expanded run; a loss is the opposite move.
    """
    with open("./data/howto100m/video1k/retrieved.filter.base.v5.json", "r") as f:
        base_runs = json.load(f)
    with open("./data/howto100m/video1k/retrieved.filter.base.exp.v5.json", "r") as f:
        exp_runs = json.load(f)

    cutoff = 50
    entered = 0
    dropped = 0
    for goal, base_info in base_runs.items():
        exp_info = exp_runs[goal]
        base_gold = base_info['gold']
        exp_gold = exp_info['gold']
        for vid, base_vrank in base_gold.items():
            base_in = base_vrank <= cutoff
            exp_in = exp_gold[vid] <= cutoff
            if base_in and not exp_in:
                dropped += 1
            elif exp_in and not base_in:
                entered += 1
    print(entered, dropped)


def filter_good_video():
    """Keep only well-ranked videos in every split and save the filtered map.

    Reads ``HowTo100M_v1.csv`` (the header line is skipped; the first column
    is used as the video id and the second-to-last column as an integer rank)
    and prints the largest rank seen.  Then, for the ``train``, ``dev`` and
    ``test`` entries of ``split_task_video_map.json``, keeps the videos whose
    rank is at most 50, prints the number kept per split and writes the
    result to ``v50_split_task_video_map.json``.

    Blank lines in the CSV (for example a trailing empty line) are skipped.

    Raises:
        KeyError: if a video listed in the split map has no row in the CSV.
    """
    max_keep_rank = 50

    vid_rank_map = {}
    max_r = 0
    with open("./data/howto100m/HowTo100M_v1.csv", "r") as f:
        f.readline()  # header
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would split to [''] and break the [-2] lookup.
                continue
            tks = line.split(",")
            # Indexed from the end, so extra commas in the middle columns do
            # not shift the rank column.
            rank = int(tks[-2])
            vid_rank_map[tks[0]] = rank
            max_r = max(max_r, rank)
    print(max_r)

    with open("./data/howto100m/video1k/split_task_video_map.json", "r") as f:
        split = json.load(f)

    filter_split = {}
    for s in ['train', 'dev', 'test']:
        kept = {
            task: [vid for vid in videos if vid_rank_map[vid] <= max_keep_rank]
            for task, videos in split[s].items()
        }
        filter_split[s] = kept
        print(s, sum(len(videos) for videos in kept.values()))

    # The file is only written, so plain "w" is enough.
    with open("./data/howto100m/video1k/v50_split_task_video_map.json", "w") as f:
        json.dump(filter_split, f, indent=2)


def _print_gold_rank_summary(results):
    """Print ``mean sum count`` and then the median of every gold rank in *results*."""
    ranks = [rank for info in results.values() for rank in info['gold'].values()]
    print(sum(ranks) / len(ranks), sum(ranks), len(ranks))
    print(np.median(ranks))


def look_rank():
    """Print gold-rank statistics for four retrieval runs.

    All four result files are loaded first.  Then, for the filtered base run,
    the filtered expanded run, the goal-only run and the all-step run (in
    that order), two lines are printed: ``mean sum count`` of all gold ranks,
    and their median.

    Raises:
        ZeroDivisionError: if a result file contains no gold ranks at all.
    """
    with open("./data/howto100m/video1k/retrieved.goal_only.json", "r") as f:
        d0 = json.load(f)
    with open("./data/howto100m/video1k/retrieved.all_step.json", "r") as f:
        d1 = json.load(f)
    with open("./data/howto100m/video1k/retrieved.filter.base.mean_rank.json", "r") as f:
        dbase = json.load(f)
    with open("./data/howto100m/video1k/retrieved.filter.base.exp.mean_rank.json", "r") as f:
        dexp = json.load(f)

    # Same reporting order as before the four stanzas were folded into a helper.
    for results in (dbase, dexp, d0, d1):
        _print_gold_rank_summary(results)


def _hit_rate_at_cutoffs(results, cutoffs):
    """Return ``{cutoff: fraction of goals with at least one gold ranked <= cutoff}``."""
    hits = {n: 0 for n in cutoffs}
    for info in results.values():
        gold_ranks = list(info['gold'].values())
        for n in cutoffs:
            if any(rank <= n for rank in gold_ranks):
                hits[n] += 1
    # Normalise by this run's own number of goals rather than another run's.
    return {n: count / len(results) for n, count in hits.items()}


def recalc_rerank():
    """Print, for four retrieval runs, the share of goals hit within the top n.

    For the goal-only, all-step, filtered base and filtered expanded runs (in
    that order) one dict is printed, mapping each cutoff in
    ``[1, 3, 5, 10, 15, 20, 25, 30]`` to the fraction of goals that have at
    least one gold video ranked at or above that cutoff.

    Raises:
        AssertionError: if the four result files do not contain the same
            number of goals.
        ZeroDivisionError: if the result files are empty.
    """
    with open("./data/howto100m/video1k/retrieved.goal_only.json", "r") as f:
        d0 = json.load(f)
    with open("./data/howto100m/video1k/retrieved.all_step.json", "r") as f:
        d1 = json.load(f)
    with open("./data/howto100m/video1k/retrieved.filter.base.mean_rank.json", "r") as f:
        dbase = json.load(f)
    with open("./data/howto100m/video1k/retrieved.filter.base.exp.mean_rank.json", "r") as f:
        dexp = json.load(f)

    assert len(d0) == len(d1) and len(d1) == len(dbase) and len(dbase) == len(dexp)

    tn = [1, 3, 5, 10, 15, 20, 25, 30]
    for results in (d0, d1, dbase, dexp):
        print(_hit_rate_at_cutoffs(results, tn))

def calc_win_loss():
    """Compare the base and expanded runs per video and per goal, then plot.

    Prints ``win loss same`` counted over individual gold videos (a win is a
    strictly lower rank in the expanded run), then ``win loss`` counted over
    goals using the mean gold rank of each goal.  Finally the per-goal gaps
    ``base_mean - exp_mean`` are binned with unit-wide bins whose edges run
    from -20 to 19, the counts and bin edges are printed, and the histogram
    is shown as a bar chart positioned by bin index.
    """
    with open("./data/howto100m/video1k/t150/retrieved.filter.base.mean_rank.video_1k_t150_dev.json", "r") as f:
        base_runs = json.load(f)
    with open("./data/howto100m/video1k/t150/retrieved.filter.base.exp.mean_rank.video_1k_t150_dev.json", "r") as f:
        exp_runs = json.load(f)

    # Pass 1: one comparison per gold video.
    per_video = {"win": 0, "loss": 0, "same": 0}
    for goal, info in base_runs.items():
        for vid, base_rank in info['gold'].items():
            exp_rank = exp_runs[goal]['gold'][vid]
            if exp_rank < base_rank:
                per_video["win"] += 1
            elif base_rank < exp_rank:
                per_video["loss"] += 1
            else:
                per_video["same"] += 1
    print(per_video["win"], per_video["loss"], per_video["same"])

    # Pass 2: one comparison per goal, on the mean gold rank.
    goal_win = 0
    goal_loss = 0
    gap = []
    for goal, info in base_runs.items():
        base_pred = info['gold']
        exp_pred = exp_runs[goal]['gold']
        base_mr = np.mean(list(base_pred.values()))
        exp_mr = np.mean(list(exp_pred.values()))
        if base_mr > exp_mr:
            goal_win += 1
        elif base_mr < exp_mr:
            goal_loss += 1
        gap.append(base_mr - exp_mr)
    print(goal_win, goal_loss)

    interval = 1
    bins = np.arange(-20, 20, interval)
    counts, edges = np.histogram(gap, bins=bins)
    print(counts)
    print(edges)
    # Bins 0-19 hold gaps in [-20, 0); bin 20 ([0, 1)) is left out of both
    # sums; bins 21 onwards hold gaps in [1, 19].
    print(sum(counts[:20]), sum(counts[21:]))

    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.bar(np.arange(bins.shape[0] - 1), counts, alpha=0.5, color='b')
    plt.show()


def oracle():
    """Print recall@n for a per-video oracle over the base and expanded runs.

    For every gold video the oracle takes the smaller of its two ranks.  For
    each cutoff ``n`` the fraction of a goal's golds with oracle rank ``<= n``
    is computed, and the printed dict holds the average of that fraction over
    all goals of the base run.
    """
    with open("./data/howto100m/video1k/t150/retrieved.filter.base.mean_rank.video_1k_t150_dev.json", "r") as f:
        base_runs = json.load(f)
    with open("./data/howto100m/video1k/t150/retrieved.filter.base.exp.mean_rank.video_1k_t150_dev.json", "r") as f:
        exp_runs = json.load(f)

    cutoffs = [1, 3, 5, 10, 15, 25, 50]
    recall_sum = dict.fromkeys(cutoffs, 0)
    for goal, info in base_runs.items():
        base_pred = info['gold']
        exp_pred = exp_runs[goal]['gold']
        best_rank = {vid: min(rank, exp_pred[vid]) for vid, rank in base_pred.items()}
        for n in cutoffs:
            hits = sum(1 for rank in best_rank.values() if rank <= n)
            recall_sum[n] += hits / len(best_rank)

    averaged = {n: total / len(base_runs) for n, total in recall_sum.items()}
    print(averaged)

def compare_large_gap():
    """Show the query changes for goals whose expanded query scored better.

    ``info[goal][1]`` of the two query files is compared (lower is better;
    printed under the label "train").  For every goal where the base value is
    larger than the expanded one, the base steps, the steps added (``+``) or
    removed (``-``) by the expansion, and the mean gold rank of both dev
    retrieval runs (label "test") are printed.  The last two lines are
    ``win loss`` over all goals for the query-file values, and ``win loss``
    over the printed goals for the dev retrieval means.
    """
    folder = "./data/howto100m/video1k/t150_resplit"
    with open(f"{folder}/all_goal_useful_step.mr.base.m15.0.json", "r") as f:
        base_q = json.load(f)
    with open(f"{folder}/all_goal_useful_step.goal.c1.train_null.mr.exp.m15.0.json", "r") as f:
        exp_q = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_dev.0.6.None.json", "r") as f:
        base_d = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.goal.c1.train_null.mr.exp.0.video_1k_t150_resplit_dev.0.5.0.json", "r") as f:
        exp_d = json.load(f)

    rule = "-----------"
    base_info = base_q['info']
    exp_info = exp_q['info']
    train = {"win": 0, "loss": 0}
    test = {"win": 0, "loss": 0}
    for goal, base_entry in base_info.items():
        base_mr = base_entry[1]
        exp_mr = exp_info[goal][1]
        if exp_mr < base_mr:
            train["win"] += 1
        elif base_mr < exp_mr:
            train["loss"] += 1

        if not (base_mr - exp_mr > 0):
            continue

        print("\n===============================================")
        print(goal)
        print(rule)
        exp_steps = exp_q['step_map'][goal]
        base_steps = base_q['step_map'][goal]
        added = [['+'] + step for step in exp_steps if step not in base_steps]
        removed = [['-'] + step for step in base_steps if step not in exp_steps]
        for step in base_steps:
            print(', '.join(step))
        print(rule)
        for step in added + removed:
            print(', '.join(step))
        print(rule)

        base_pred = np.mean(list(base_d[goal]['gold'].values()))
        exp_pred = np.mean(list(exp_d[goal]['gold'].values()))
        print(f"train: {base_mr:.2f}, {exp_mr:.2f}, {base_mr - exp_mr:.2f}")
        print(f"test: {base_pred:.2f}, {exp_pred:.2f}, {base_pred - exp_pred:.2f}")
        if exp_pred < base_pred:
            test["win"] += 1
        elif base_pred < exp_pred:
            test["loss"] += 1

    print(train["win"], train["loss"])
    print(test["win"], test["loss"])

def when_win():
    """Plot the query-file gaps of goals the expanded run wins versus loses.

    A goal is a win when the mean gold rank of the expanded dev retrieval is
    strictly lower than that of the base dev retrieval, and a loss when it is
    strictly higher.  For each group the gap ``base info[goal][1] - exp
    info[goal][1]`` from the query files is binned with unit-wide bins whose
    edges run from -20 to 19, and both histograms are drawn as overlaid bar
    charts positioned by bin index.
    """
    folder = "./data/howto100m/video1k/t150_resplit"
    with open(f"{folder}/all_goal_useful_step.mr.base.m15.json", "r") as f:
        base_q = json.load(f)
    with open(f"{folder}/all_goal_useful_step.mr.exp.m15.json", "r") as f:
        exp_q = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_dev.0.6.None.json", "r") as f:
        base_d = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_dev.0.5.None.json", "r") as f:
        exp_d = json.load(f)

    base_info = base_q['info']
    exp_info = exp_q['info']
    gaps = {"win": [], "loss": []}
    for goal, base_entry in base_info.items():
        base_mr = base_entry[1]
        exp_mr = exp_info[goal][1]
        base_pred = np.mean(list(base_d[goal]['gold'].values()))
        exp_pred = np.mean(list(exp_d[goal]['gold'].values()))
        if base_pred > exp_pred:
            gaps["win"].append(base_mr - exp_mr)
        elif base_pred < exp_pred:
            gaps["loss"].append(base_mr - exp_mr)

    fig, ax = plt.subplots(nrows=1, ncols=1)
    interval = 1
    bins = np.arange(-20, 20, interval)
    win_hist, _ = np.histogram(gaps["win"], bins=bins)
    loss_hist, _ = np.histogram(gaps["loss"], bins=bins)
    positions = np.arange(bins.shape[0] - 1)
    for counts, colour, name in ((win_hist, 'b', 'win'), (loss_hist, 'r', 'loss')):
        ax.bar(positions, counts, alpha=0.5, color=colour, label=name)
    plt.legend()
    plt.show()

def oracle_3():
    """Pick one of three runs per goal on the train split and score it on dev.

    Phase 1 (train files): for every goal, count for how many gold videos the
    goal-only run (index 0), the filtered base run (index 1) or the filtered
    expanded run (index 2) has the strictly lowest rank, and remember the
    index with the most such wins in ``select_model``.

    Phase 2 (dev files): for every goal, take the gold ranks from the run
    selected in phase 1 and accumulate recall@n (fraction of the goal's golds
    ranked ``<= n``), averaged over the goals of the dev base run.

    Both phases also count golds that have rank 201 in exactly one of the
    base/expanded runs (``w``: only base is 201, ``l``: only expanded is 201).
    Everything is reported with ``print``; nothing is returned.

    NOTE(review): phase 2 looks up ``select_model[k]`` for dev goals, which
    raises KeyError if a dev goal is absent from the train base file.
    """
    folder = "./data/howto100m/video1k/t150_resplit"
    with open(f"{folder}/all_goal_useful_step.mr.base.m15.json", "r") as f:
        base_q = json.load(f)
    with open(f"{folder}/all_goal_useful_step.mr.exp.m15.json", "r") as f:
        exp_q = json.load(f)

    base_info = base_q['info']
    exp_info = exp_q['info']
    with open(f"{folder}/retrieved.goal_only.video_1k_t150_resplit_train.json", "r") as f:
        dgoal = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_train.1.0.json", "r") as f:
        dbase = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_train.1.0.json", "r") as f:
        dexp = json.load(f)
    # with open(f"{folder}/retrieved.goal_only.video_1k_t150_resplit_dev.json", "r") as f:
    #     dgoal = json.load(f)
    # with open(f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_dev.0.6.None.json", "r") as f:
    #     dbase = json.load(f)
    # with open(f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_dev.0.5.None.json", "r") as f:
    #     dexp = json.load(f)

    tn = [1, 3, 5, 10, 15, 25, 50]
    recall = {x: 0 for x in tn}
    # Local dict of selected ranks per goal; shadows this function's own name.
    oracle = {}
    # `th` is only referenced by the commented-out gap-threshold variant below.
    th = 0.5
    # goal -> index of the chosen run (0 goal-only, 1 base, 2 expanded)
    select_model = {}
    win_counter = Counter()
    # number of goals where all three runs have the same win count
    dd = 0
    w = 0
    l = 0
    for k, v in dbase.items():
        # base_gap / exp_gap are computed but only used by commented-out code.
        base_gap = base_info[k][0] - base_info[k][1]
        exp_gap = exp_info[k][0] - exp_info[k][1]
        goal_pred = dgoal[k]['gold']
        base_pred = v['gold']
        exp_pred = dexp[k]['gold']
        win = {0:0, 1:0, 2:0}
        ks = list(goal_pred.keys())
        # Debug dump: the three runs' ranks per gold, then their means.
        print(k)
        print([f"{kk}: {goal_pred[kk] :<4}" for kk in ks])
        print([f"{kk}: {base_pred[kk] :<4}" for kk in ks])
        print([f"{kk}: {exp_pred[kk] :<4}" for kk in ks])
        print(np.mean([goal_pred[kk] for kk in ks]), np.mean([base_pred[kk] for kk in ks]), np.mean([exp_pred[kk] for kk in ks]))

        for kk, vv in goal_pred.items():
            # A run only scores when it is strictly better than both others;
            # golds with a tie for the best rank score for nobody.
            if vv < base_pred[kk] and vv < exp_pred[kk]:
                win[0] += 1
            elif base_pred[kk] < vv and base_pred[kk] < exp_pred[kk]:
                win[1] += 1
            elif exp_pred[kk] < vv and exp_pred[kk] < base_pred[kk]:
                win[2] += 1

            # 201 is presumably the "not retrieved" sentinel (MAX_RANK + 1 with
            # MAX_RANK = 200, as used elsewhere in this file) -- confirm.
            if base_pred[kk] == 201 and exp_pred[kk] != 201:
                w += 1
            elif exp_pred[kk] == 201 and base_pred[kk] != 201:
                l += 1

        # np.argmax returns the first maximal index, so ties favour the
        # lower-numbered run (goal-only first).
        index = np.argmax([win[i] for i in range(3)])
        # index = 0
        print(win, index)
        if win[0] == win[1] and win[1] == win[2]:
            dd += 1
        select_model[k] = index
        win_counter[index] += 1
    print(win_counter, dd)
    print(w, l)

        # oracle[k] = {'gold': {}}
        # goal_mr = np.mean(list(goal_pred.values()))
        # base_mr = np.mean(list(base_pred.values()))
        # exp_mr = np.mean(list(exp_pred.values()))
        # index = int(np.argmin([goal_mr, base_mr, exp_mr]))
        # goal_ap = calc_average_precision(dgoal[k], 200)
        # base_ap = calc_average_precision(dbase[k], 200)
        # exp_ap = calc_average_precision(dexp[k], 200)
        # index = int(np.argmax([goal_ap, base_ap, exp_ap]))
        # index = int(np.argmin([base_mr, exp_mr]))

    # Phase 2: apply the per-goal choice made above to the dev-split runs.
    with open(f"{folder}/retrieved.goal_only.video_1k_t150_resplit_dev.json", "r") as f:
        dgoal_test = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_dev.0.6.None.json", "r") as f:
        dbase_test = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_dev.0.5.None.json", "r") as f:
        dexp_test = json.load(f)


    # The 201-vs-not-201 counters are restarted for the dev split.
    w = 0
    l = 0
    for k, v in dbase_test.items():
        goal_pred = dgoal_test[k]['gold']
        base_pred = v['gold']
        exp_pred = dexp_test[k]['gold']
        oracle[k] = {'gold': {}}
        for kk, base_idx in base_pred.items():
            goal_idx = goal_pred[kk]
            exp_idx = exp_pred[kk]
            # print(select_model[k])
            # Use the rank from whichever run was selected for this goal.
            oracle_idx = [goal_idx, base_idx, exp_idx][select_model[k]]
            # oracle_idx = [base_idx, exp_idx][select_model[k]]
            # oracle_idx = min(base_idx, exp_idx, goal_idx)
            oracle[k]['gold'][kk] = oracle_idx
            if base_idx == 201 and exp_idx != 201:
                w += 1
            elif exp_idx == 201 and base_idx != 201:
                l += 1

        # oracle_idx = min(base_idx, exp_idx, goal_idx)
        # for kk, base_idx in base_pred.items():
        #     goal_idx = goal_pred[kk]
        #     exp_idx = exp_pred[kk]
            # oracle_idx = min(base_idx, exp_idx, goal_idx)
            # oracle[k]['gold'][kk] = oracle_idx
            # if exp_gap < base_gap:
            #     if base_gap <= th:
            #         oracle[k]['gold'][kk] = goal_idx
            #     else:
            #         oracle[k]['gold'][kk] = base_idx
            # else:
            #     if exp_gap <= th:
            #         oracle[k]['gold'][kk] = goal_idx
            #     else:
            #         oracle[k]['gold'][kk] = exp_idx

        # Per-goal recall@n: share of this goal's golds ranked within n.
        # (`r` is unused; only the keys of `recall` are needed here.)
        for n, r in recall.items():
            cur_hit = sum([x <= n for x in oracle[k]['gold'].values()])
            recall[n] += cur_hit / len(oracle[k]['gold'])

    # Average the per-goal recall over all dev goals.
    recall = {k: v / len(dbase_test) for k, v in recall.items()}
    print(recall)
    print(w, l)


def _leng_rank_result_paths(folder, sp):
    """Return the (goal-only, all-step, base, exp) result paths for split *sp*."""
    if sp in ('dev', 'test'):
        # Alternative runs that were previously plugged in for base / exp:
        #   retrieved.filter.base.exp.ap.base.video_1k_t150_resplit_{sp}.0.6.None.json
        #   retrieved.filter.base.exp.ap.exp.video_1k_t150_resplit_{sp}.0.4.None.json
        return (
            f"{folder}/retrieved.goal_only.video_1k_t150_resplit_{sp}.json",
            f"{folder}/retrieved.all_step.video_1k_t150_resplit_{sp}.0.1.json",
            f"{folder}/retrieved.filter.base.exp.mr.base.0.video_1k_t150_resplit_{sp}.0.6.0.json",
            f"{folder}/retrieved.filter.base.exp.mr.exp.0.video_1k_t150_resplit_{sp}.0.5.0.json",
        )
    if sp == 'train':
        goal_only = f"{folder}/retrieved.goal_only.video_1k_t150_resplit_{sp}.json"
        # For train the goal-only file is used in the all-step slot as well.
        return (
            goal_only,
            goal_only,
            f"{folder}/retrieved.filter.base.exp.mr.base.video_1k_t150_resplit_{sp}.1.0.json",
            f"{folder}/retrieved.filter.base.exp.mr.exp.video_1k_t150_resplit_{sp}.1.0.json",
        )
    raise NotImplementedError


def leng_rank():
    """Plot mean gold rank against video-length bucket for the dev and test splits.

    For each split the gold ranks of four runs (goal-only, all-step, filtered
    base, filtered expanded) are grouped by ``int(length) // 60 // 3`` (i.e.
    3-minute buckets if the lengths are stored in seconds).  Per run and
    bucket, ``run_index bucket_size mean_rank`` is printed, followed by the
    overall count and mean rank of the base and expanded runs.  Only the two
    filtered runs are drawn: dev in the left panel with ``o`` markers, test in
    the right panel with ``x`` markers.

    Differences from the earlier version of this function:
      * buckets are taken from the keys actually present, so a gap in the
        length distribution no longer inserts empty buckets (NaN means) or
        drops the highest buckets;
      * each panel gets its own legend (``plt.legend()`` only decorates the
        current Axes);
      * the video-length map is loaded once instead of once per split.

    Raises:
        NotImplementedError: for a split other than dev, test or train.
    """
    folder = "./data/howto100m/video1k/t150_resplit"
    with open("./data/howto100m/video_length.json", "r") as f:
        vl_map = json.load(f)

    fig, ax = plt.subplots(nrows=1, ncols=2)
    for sp in ['dev', 'test']:
        print(f"============{sp}============")
        runs = []
        for path in _leng_rank_result_paths(folder, sp):
            with open(path, "r") as f:
                runs.append(json.load(f))
        d_goal, d_all, d_base, d_exp = runs
        print(max(vl_map.values()))

        # One {bucket: [ranks]} map per run, in the order goal / all / base / exp.
        rank_lists = [defaultdict(list) for _ in range(4)]
        for goal, rank_info in d_base.items():
            for gold, base_rank in rank_info['gold'].items():
                ranks = (
                    d_goal[goal]['gold'][gold],
                    d_all[goal]['gold'][gold],
                    base_rank,
                    d_exp[goal]['gold'][gold],
                )
                v_len = int(vl_map[gold]) // 60 // 3
                for per_bucket, rank in zip(rank_lists, ranks):
                    per_bucket[v_len].append(rank)

        # All four maps share the same keys because they are filled together.
        buckets = sorted(rank_lists[0])
        res = [[] for _ in rank_lists]
        for idx, per_bucket in enumerate(rank_lists):
            for mi in buckets:
                res[idx].append(np.mean(per_bucket[mi]))
                print(idx, len(per_bucket[mi]), np.mean(per_bucket[mi]))

        # Overall count and mean rank for the base run, then the expanded run.
        for per_bucket in rank_lists[2:]:
            flat = [rank for ranks in per_bucket.values() for rank in ranks]
            print(len(flat), np.mean(flat))

        ai, marker = (0, 'o') if sp == 'dev' else (1, 'x')
        label = ['L0', 'L1', 'Fil-L1', 'Fil-L2']
        color = ['g', 'y', 'r', 'b']
        for idx, (r, lb, cl) in enumerate(zip(res, label, color)):
            if idx in (0, 1):
                # The goal-only and all-step runs are printed but not drawn.
                continue
            ax[ai].plot(buckets, r, marker, label=f"{lb}-{sp}", color=cl)
        ax[ai].legend()
    plt.show()

    # fig, ax = plt.subplots(nrows=1, ncols=1)
    # interval = 60
    # bins = np.arange(0, 2000, interval)
    # hist1, _ = np.histogram(win_len, bins=bins)
    # hist2, _ = np.histogram(loss_len, bins=bins)
    # print(hist1)
    # print(hist2)
    # xs = np.arange(bins.shape[0] - 1)
    # ax.bar(xs, hist1, alpha=0.5, color='b', label='win')
    # ax.bar(xs, hist2, alpha=0.5, color='r', label='loss')
    # plt.legend()
    # plt.show()

def tmp():
    """Print a hard-coded step list as one ``" || "``-separated line.

    The leading ``"+, "`` marker and the trailing ``", [org]"`` / ``", [exp]"``
    tags are removed from every step before joining.
    """
    a = """+, Take apart the vacuum cleaner, [org]
+, Empty the canister, [org]
+, Cut hair out of the beater bar, [org]
+, Dust the inside of the hose, [org]
+, Dust the inside of your vacuum to get rid of any other debris, [exp]
+, Put the roller back inside your vacuum, [exp]
+, Try running your vacuum to see if it works again, [exp]
+, Unscrew the hose from your vacuum, [exp]
+, Locate and remove the filter between the bag and the main suction tube, [exp]
+, Unscrew the lower hose from the vacuum body, [exp]"""
    steps = []
    for line in a.split("\n"):
        for marker in ("+, ", ", [org]", ", [exp]"):
            line = line.replace(marker, "")
        steps.append(line)
    print(" || ".join(steps))

# Script entry point: one analysis is enabled at a time; the others are kept
# commented out so they can be switched on by hand.
if __name__ == '__main__':
    # compare_win_loss()
    # draw_video_leng()
    # print_top1_win()
    # look_rank()
    # recalc_rerank()
    # calc_win_loss()
    # oracle()
    # compare_large_gap()
    # oracle_3()
    # leng_rank()
    compare_large_gap()