import os
import sys
sys.path.append("/home/[USER]/workshop/wikihow")
os.chdir("/home/[USER]/workshop/wikihow")
import json
from collections import Counter, defaultdict
import torch
import random
import pickle

def split_para_to_multiple(
    src_path="./data/wikihow/para_base_all_base_0.0_30_links.json",
    chunk_size=140000,
):
    """Split a JSON dictionary into consecutive pickled chunks.

    The top-level JSON object in ``src_path`` is walked in file order and
    written out as ``<stem>_0.pkl``, ``<stem>_1.pkl``, ... where ``<stem>`` is
    ``src_path`` without its ``.json`` suffix. Every chunk except possibly the
    last holds exactly ``chunk_size`` entries. The size of each chunk is
    printed as it is written.

    Args:
        src_path: Path of the JSON file to split. The default is the file the
            script has always used.
        chunk_size: Maximum number of entries per pickle file.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    import pickle

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    with open(src_path, "r") as f:
        d = json.load(f)

    # "<stem>.json" -> "<stem>_{i}.pkl"; the default path yields the original
    # output names.
    stem = src_path[:-len(".json")] if src_path.endswith(".json") else src_path

    def _dump(chunk, index):
        print(len(chunk))
        with open(f"{stem}_{index}.pkl", "wb") as fw:
            pickle.dump(chunk, fw)

    chunk = {}
    i = 0
    for k, v in d.items():
        chunk[k] = v
        if len(chunk) >= chunk_size:
            _dump(chunk, i)
            chunk = {}
            i += 1

    # Write the remainder. Skip it when it is empty, unless nothing has been
    # written at all, so an empty input still produces "<stem>_0.pkl".
    if chunk or i == 0:
        _dump(chunk, i)

def reset_top_dup_step_score(top_k=10700):
    """Zero the rerank scores of the most frequently repeated step captions.

    Counts how often each caption occurs across all entries of
    ``./data/wikihow/step_goal.json`` (a caption ending in "." has that dot
    removed and the remainder stripped before counting). For the ``top_k``
    most frequent captions, every candidate goal's ``rerank_score`` in
    ``./data/wikihow/step_candidate_goal.v2.score.json`` is set to 0 and the
    file is rewritten in place. The previous score is kept under
    ``rerank_score_bck``.

    Prints the number of occurrences belonging to repeated captions next to
    the total number of occurrences, the first ten repeated captions, and
    every selected caption that could not be fully processed because a key
    was missing in the score file.

    Args:
        top_k: How many of the most frequent captions get their scores reset.
    """
    with open("./data/wikihow/step_goal.json", "r") as f:
        tasks = json.load(f)

    dup = Counter()
    for task in tasks:
        for caption in task['caption']:
            step = caption[:-1].strip() if caption.endswith(".") else caption
            dup[step] += 1

    repeated = [step for step, count in dup.items() if count > 1]
    tot = sum(dup[step] for step in repeated)
    print(tot, sum(dup.values()))
    print(repeated[:10])

    # Stable sort: captions with equal counts stay in first-seen order.
    ranked = sorted(dup.items(), key=lambda x: x[1], reverse=True)

    score_path = "./data/wikihow/step_candidate_goal.v2.score.json"
    with open(score_path, "r") as f:
        dd = json.load(f)

    for step, _count in ranked[:top_k]:
        try:
            for info in dd[step].values():
                # setdefault keeps the first backup, so running this function
                # again does not overwrite the original score with the 0
                # written by the previous run.
                info.setdefault('rerank_score_bck', info['rerank_score'])
                info['rerank_score'] = 0
        except KeyError:
            print(step)

    with open(score_path, "w+") as f:
        json.dump(dd, f, indent=2)


def create_score_file(prefix, num_shards=10):
    """Merge sharded rerank results and normalise their scores.

    Reads ``./data/wikihow/{prefix}.{i}.result`` for ``i`` in
    ``range(num_shards)``, concatenates them, and for every item:

    * sets ``gold`` to the empty string;
    * sets ``goal`` from ``./data/wikihow/step2goal.json``, falling back to
      the placeholder string ``"goal"`` when the step is not listed;
    * drops the ``'x y z'`` candidate, replaces the remaining scores by their
      softmax, and orders the candidates by descending probability;
    * zeroes all candidates for the steps "finished", "finish", "done" and
      "serve" (compared case-insensitively).

    The merged list is written to ``./data/wikihow/{prefix}.all.result``.
    Turning that file into a ``*.score.json`` file is done by ``all_process``.

    Args:
        prefix: Common file-name prefix of the shard files and of the output.
        num_shards: Number of ``{prefix}.{i}.result`` shard files to merge.
    """
    with open("./data/wikihow/step2goal.json", "r") as f:
        step2goal = json.load(f)

    rerank_d = []
    for i in range(num_shards):
        with open(f"./data/wikihow/{prefix}.{i}.result", "r") as f:
            rerank_d += json.load(f)
    print(len(rerank_d))

    # Test membership directly instead of comparing against the "goal"
    # placeholder, so a step whose real goal is literally "goal" is not
    # reported as missing.
    miss = set()
    for item in rerank_d:
        item['gold'] = ""
        step = item['step']
        if step in step2goal:
            item['goal'] = step2goal[step]
        else:
            item['goal'] = "goal"
            miss.add(step)
    print(len(miss))
    print(miss)

    trivial_steps = ('finished', 'finish', 'done', 'serve')
    for item in rerank_d:
        pred = [[k, v] for k, v in item['pred'].items() if k != 'x y z']
        k_list = [x[0] for x in pred]
        v_list = torch.softmax(torch.tensor([x[1] for x in pred]), dim=0).tolist()
        ranked = sorted(zip(k_list, v_list), key=lambda x: x[1], reverse=True)
        if item['step'].lower() in trivial_steps:
            # The candidate order is kept; only the scores are cleared.
            item['pred'] = {k: 0 for k, _ in ranked}
        else:
            item['pred'] = dict(ranked)

    with open(f"./data/wikihow/{prefix}.all.result", "w+") as f:
        json.dump(rerank_d, f, indent=2)

def all_process(data, tag, para_file):
    """Convert a ``.result`` file into a per-step ``.score.json`` file.

    Input is ``./data/wikihow/{data}.{tag}.result`` together with the
    retrieval file ``./data/wikihow/{para_file}``, which is unpickled when its
    name contains ``pkl`` and parsed as JSON otherwise. Output is
    ``./data/wikihow/{data}.{tag}.score.json``, mapping
    ``step -> candidate -> {'rerank_score', 'para_score'}``.

    When ``tag`` contains ``train_null`` and the ``[unused2]`` candidate holds
    the maximum score of an item, every candidate of that item gets a rerank
    score of 0. The ``[unused2]`` candidate itself is never written out.
    Prints the number of steps written and the number of such zeroed items.
    """
    null_key = '[unused2]'
    base_dir = "./data/wikihow"

    para_path = f"{base_dir}/{para_file}"
    if 'pkl' not in para_file:
        with open(para_path, "r") as fin:
            para_d = json.load(fin)
    else:
        with open(para_path, "rb") as fin:
            para_d = pickle.load(fin)

    with open(f"{base_dir}/{data}.{tag}.result", "r") as fin:
        reranked = json.load(fin)

    has_null_class = 'train_null' in tag
    scores = {}
    null_num = 0
    for record in reranked:
        step = record['step']
        pred = record['pred']
        scores[step] = {}

        unlinkable = has_null_class and pred[null_key] == max(pred.values())
        if unlinkable:
            null_num += 1

        for goal, raw_score in pred.items():
            if goal == null_key:
                continue
            entry = {'rerank_score': 0 if unlinkable else float(raw_score)}
            # Look the candidate up in the retrieval list to fetch the
            # similarity stored at the same position.
            position = para_d[step]['retrieved_goals'].index(goal)
            entry['para_score'] = float(para_d[step]['retrieved_goals_similarity'][position])
            scores[step][goal] = entry

    print(len(scores), null_num)
    with open(f"{base_dir}/{data}.{tag}.score.json", "w+", encoding="utf-8") as fout:
        json.dump(scores, fout, indent=2)

def video2k_process(tag):
    """Build the video2k score file from the merged rerank result.

    Reads ``./data/wikihow/all.org.t30.test.deterta.goal.t30.{tag}.all.result``
    and keeps only the items whose ``goal`` is one of the values in
    ``./data/howto100m/video2k/task_list.json``. For each kept step, every
    candidate except ``[unused2]`` is written with its ``rerank_score`` and
    the ``para_score`` found at the same position in
    ``./data/wikihow/para_base_all_base_0.0_30_links.json``.

    When ``tag`` equals ``'train_null'`` and ``[unused2]`` holds the maximum
    score of an item, all of that item's rerank scores are written as 0.

    The result goes to
    ``./data/wikihow/video2k.step_candidate_goal.rerank.{tag}.score.json``.
    Prints the number of steps written and the number of zeroed items.

    Args:
        tag: Name of the scoring variant; it selects the input and output
            file names.
    """
    with open("./data/wikihow/para_base_all_base_0.0_30_links.json", "r") as f:
        para_d = json.load(f)

    with open("./data/howto100m/video2k/task_list.json") as f:
        # A set makes the per-item membership test O(1); the original list
        # was scanned once per rerank item.
        # NOTE(review): assumes the task names are hashable (strings).
        task_id = set(json.load(f).values())

    with open(f"./data/wikihow/all.org.t30.test.deterta.goal.t30.{tag}.all.result", "r") as f:
        rerank_d = json.load(f)

    _d = {}
    null_num = 0
    for item in rerank_d:
        if item['goal'] not in task_id:
            continue
        s = item['step']
        pred = item['pred']
        _d[s] = {}
        is_null = tag == 'train_null' and pred['[unused2]'] == max(pred.values())
        null_num += int(is_null)
        for k, v in pred.items():
            if k == '[unused2]':
                continue
            _d[s][k] = {'rerank_score': 0 if is_null else float(v)}
            idx = para_d[s]['retrieved_goals'].index(k)
            _d[s][k]['para_score'] = float(para_d[s]['retrieved_goals_similarity'][idx])

    print(len(_d), null_num)
    with open(f"./data/wikihow/video2k.step_candidate_goal.rerank.{tag}.score.json", "w+", encoding="utf-8") as f:
        json.dump(_d, f, indent=2)

def _rank_by_softmax(labels, logits):
    """Return ``(label, probability)`` pairs sorted by descending probability."""
    if not labels:
        return []
    probs = torch.softmax(torch.tensor(logits), dim=0).tolist()
    return sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)


def create_score_for_video1k(data_tag, tag, para_file, null_threshold=0.0):
    """Normalise rerank scores and write them with their retrieval scores.

    Reads ``./data/wikihow/{data_tag}.{tag}.result`` and the retrieval file
    ``./data/wikihow/{para_file}`` (unpickled when the name contains ``pkl``,
    parsed as JSON otherwise). For every item the ``'x y z'`` candidate is
    dropped and the remaining scores are replaced by their softmax, ordered
    by descending probability. Then:

    * If ``[unused2]`` is the top candidate with probability of at least
      ``null_threshold``, the step counts as unlinkable and every candidate's
      score becomes 0.
    * If ``[unused2]`` is on top but below the threshold, the step is printed,
      ``[unused2]`` is removed, and the softmax is recomputed over the rest.
    * The steps "finished", "finish", "done" and "serve" (case-insensitive)
      always get all-zero scores.

    The result, ``step -> candidate -> {'rerank_score', 'para_score'}``, is
    written to ``./data/wikihow/{data_tag}.{tag}.score.json``. Prints the
    sizes of both inputs, then the number of steps written and the number of
    unlinkable steps.

    Args:
        data_tag: First part of the input/output file name.
        tag: Second part of the input/output file name.
        para_file: File name of the retrieval results inside
            ``./data/wikihow``.
        null_threshold: Minimum probability ``[unused2]`` needs, when ranked
            first, for the step to be treated as unlinkable. With the default
            of 0.0 a top-ranked ``[unused2]`` always makes the step
            unlinkable, which is what the previously hard-coded ``>= 0.0``
            comparison did.
    """
    null_token = '[unused2]'
    trivial_steps = ('finished', 'finish', 'done', 'serve')

    with open(f"./data/wikihow/{data_tag}.{tag}.result", "r") as f:
        rerank_d = json.load(f)
    print(len(rerank_d))

    if 'pkl' in para_file:
        with open(f"./data/wikihow/{para_file}", "rb") as f:
            para_d = pickle.load(f)
    else:
        with open(f"./data/wikihow/{para_file}", "r") as f:
            para_d = json.load(f)
    print(len(para_d))

    unlinkable = 0
    for item in rerank_d:
        labels, logits = [], []
        for k, v in item['pred'].items():
            if k != 'x y z':
                labels.append(k)
                logits.append(v)
        ranked = _rank_by_softmax(labels, logits)

        # `ranked` is empty for an item without candidates; indexing it
        # unconditionally used to raise IndexError.
        if ranked and ranked[0][0] == null_token:
            if ranked[0][1] >= null_threshold:
                ranked = [(k, 0) for k, _ in ranked]
                unlinkable += 1
            else:
                # Not confident enough: redistribute the probability mass
                # over the real candidates.
                print(item['step'])
                i = labels.index(null_token)
                labels.pop(i)
                logits.pop(i)
                ranked = _rank_by_softmax(labels, logits)

        # Applied last so the recomputation above cannot undo it.
        if item['step'].lower() in trivial_steps:
            ranked = [(k, 0) for k, _ in ranked]

        item['pred'] = dict(ranked)

    _d = {}
    for item in rerank_d:
        s = item['step']
        _d[s] = {}
        for k, v in item['pred'].items():
            _d[s][k] = {'rerank_score': float(v)}
            # The null label is not part of the retrieval list; it is given
            # the similarity stored last (index -1).
            # NOTE(review): presumably the lowest similarity - confirm that
            # the retrieval list is sorted.
            idx = para_d[s]['retrieved_goals'].index(k) if k != null_token else -1
            para_score = para_d[s]['retrieved_goals_similarity'][idx]
            _d[s][k]['para_score'] = float(para_score)

    print(len(_d), unlinkable)
    with open(f"./data/wikihow/{data_tag}.{tag}.score.json", "w+", encoding="utf-8") as f:
        json.dump(_d, f, indent=2)


def softmax_score(item):
    """Replace ``item['pred']`` scores by their softmax, highest first.

    The ``'x y z'`` candidate is discarded before normalising. ``item`` is
    modified in place and also returned.
    """
    kept = {goal: raw for goal, raw in item['pred'].items() if goal != 'x y z'}
    probs = torch.softmax(torch.tensor(list(kept.values())), dim=0).tolist()
    ranked = sorted(zip(kept.keys(), probs), key=lambda pair: pair[1], reverse=True)
    item['pred'] = dict(ranked)
    return item


def sample_human_data():
    """Export aligned base / para_score / train_null predictions for howto1k.

    Loads the retrieval results
    (``para_base_all_base_0.0_30_links_howto1k.json``) and the two rerank
    result files (``howto1k.deberta.para_score.goal.c1.result`` and
    ``howto1k.deberta.train_null.goal.c1.result``), keeps the steps present in
    all three, and writes three JSON lists with the same step order to
    ``./data/wikihow/human/``. Rerank scores are passed through
    ``softmax_score``; the retrieval entry is reshaped into the same
    ``{'step', 'gold', 'pred', 'goal'}`` layout.

    Earlier revisions read the full-dataset files
    ``para_base_all_base_0.0_30_links.json`` and
    ``all.org.t30.test.deterta.goal.t30.{para_score,train_null}.all.result``
    instead.

    Raises:
        AssertionError: If the goals of the three sources disagree for a step,
            or a reranked candidate is not among the retrieved goals.
    """
    import random
    # The global RNG is seeded as before. Nothing in this function draws from
    # it: every common step is exported.
    seed = 9976
    random.seed(seed)

    with open("./data/wikihow/para_base_all_base_0.0_30_links_howto1k.json", "r") as f:
        db_para = json.load(f)

    with open("./data/wikihow/howto1k.deberta.para_score.goal.c1.result", "r") as f:
        dr_para = {x['step']: x for x in json.load(f)}

    with open("./data/wikihow/howto1k.deberta.train_null.goal.c1.result", "r") as f:
        dr_null = {x['step']: x for x in json.load(f)}

    print(len(db_para), len(dr_para), len(dr_null))

    # Sort the shared steps: set iteration order changes between interpreter
    # runs, which made the order of the exported files non-reproducible.
    select_ks = sorted(set(dr_para) & set(db_para) & set(dr_null))
    print(len(select_ks))

    sb_para = []
    sr_para = []
    sr_null = []
    for k in select_ks:
        r1 = db_para[k]
        r2 = softmax_score(dr_para[k])
        r3 = softmax_score(dr_null[k])

        goals_agree = r1['corresponding_goal'] == r2['goal'] and r2['goal'] == r3['goal']
        goals_missing = r2['goal'] is None and r3['goal'] is None
        assert goals_agree or goals_missing, (r1['corresponding_goal'], r2['goal'], r3['goal'])

        retrieved = set(r1['retrieved_goals'])
        for p in list(r2['pred']) + list(r3['pred']):
            assert p in retrieved or p == '[unused2]' or p == 'x y z', (k, p)

        # Reshape the retrieval entry into the layout of the rerank entries,
        # and make all three records carry the retrieval file's goal.
        base = {
            'step': k,
            'gold': None,
            'pred': dict(zip(r1['retrieved_goals'], r1['retrieved_goals_similarity'])),
            'goal': r1['corresponding_goal'],
        }
        r2['goal'] = base['goal']
        r3['goal'] = base['goal']

        sb_para.append(base)
        sr_para.append(r2)
        sr_null.append(r3)

    print(len(sb_para), len(sr_para), len(sr_null))

    with open("./data/wikihow/human/howto1k.base.json", "w+", encoding="utf-8") as f:
        json.dump(sb_para, f, indent=2)

    with open("./data/wikihow/human/howto1k.deterta.para_score.goal.c1.json", "w+", encoding="utf-8") as f:
        json.dump(sr_para, f, indent=2)

    with open("./data/wikihow/human/howto1k.deterta.train_null.goal.c1.json", "w+", encoding="utf-8") as f:
        json.dump(sr_null, f, indent=2)


def sample_annotation_data(subset="random", sample_size=20000):
    """Write a pickled subset of the paraphrase-link dictionary.

    The source is ``./data/wikihow/para_base_all_base_0.0_30_links.json``.

    * ``subset="random"`` (default): draw ``sample_size`` distinct entries
      uniformly at random and write them to
      ``./data/wikihow/para_base_all_base_0.0_30_links_<size>.pkl``, where
      ``<size>`` is e.g. ``20k`` for 20000. Fewer entries are written only
      when the source itself is smaller than ``sample_size``.
    * ``subset="howto1k"``: keep the entries whose key is a caption (with a
      trailing "." removed) of a task listed under ``tid_task_map`` in
      ``./data/howto100m/video1k/video1k.meta_map.json``, and write them to
      ``./data/wikihow/para_base_all_base_0.0_30_links_howto1k.pkl``. Tasks
      missing from ``./data/wikihow/step_goal.json`` are printed and skipped.

    Differences from the previous implementation: the function returns
    instead of calling ``exit(0)`` (which also made the howto1k part
    unreachable), and the random subset is drawn without replacement, so it
    really contains ``sample_size`` entries.

    Args:
        subset: ``"random"`` or ``"howto1k"``.
        sample_size: Number of entries for the random subset.

    Raises:
        ValueError: If ``subset`` is not one of the two supported values.
    """
    import pickle

    if subset not in ("random", "howto1k"):
        raise ValueError(f"unknown subset {subset!r}; expected 'random' or 'howto1k'")

    with open("./data/wikihow/para_base_all_base_0.0_30_links.json", "r") as f:
        d = json.load(f)
    print(len(d))

    if subset == "random":
        # random.sample draws distinct indices; random.choices drew with
        # replacement and so produced fewer than `sample_size` unique entries.
        s_idx = set(random.sample(range(len(d)), min(sample_size, len(d))))
        print(len(s_idx))
        _d = {k: v for idx, (k, v) in enumerate(d.items()) if idx in s_idx}
        size_tag = f"{sample_size // 1000}k" if sample_size % 1000 == 0 else str(sample_size)
        out_path = f"./data/wikihow/para_base_all_base_0.0_30_links_{size_tag}.pkl"
    else:
        with open("./data/wikihow/step_goal.json", "r") as f:
            gs = {x['task']: x['caption'] for x in json.load(f)}
        print(len(gs))
        with open("./data/howto100m/video1k/video1k.meta_map.json", "r") as f:
            gg = list(json.load(f)['tid_task_map'].values())
        print(len(gg))

        select_s = []
        for g in gg:
            try:
                captions = gs[g]
            except KeyError:
                print(g)
                continue
            select_s += [x if not x.endswith(".") else x[:-1].strip() for x in captions]
        print(len(select_s))

        # Set membership keeps the filter linear in the size of `d`.
        wanted = set(select_s)
        _d = {k: v for k, v in d.items() if k in wanted}
        out_path = "./data/wikihow/para_base_all_base_0.0_30_links_howto1k.pkl"

    print(len(_d))
    with open(out_path, "wb") as fw:
        pickle.dump(_d, fw)

def check_diff():
    """Report how often two score files disagree on the top-ranked goal.

    Compares ``videok1.step_candidate_goal.train_null00.score.json`` with
    ``howto1k.deberta.train_null.goal.c1.score.json`` (both under
    ``./data/wikihow``). Prints the sizes of both files, a ``miss <step>``
    line for every step of the first file that is absent from the second,
    and finally the number of disagreeing steps followed by the number of
    steps compared.
    """
    def best_goal(candidates):
        # Candidate with the highest rerank score; ties keep file order.
        ranked = sorted(candidates.items(), key=lambda kv: kv[1]['rerank_score'], reverse=True)
        return ranked[0][0]

    with open("./data/wikihow/videok1.step_candidate_goal.train_null00.score.json", "r") as fin:
        d1 = json.load(fin)
    with open("./data/wikihow/howto1k.deberta.train_null.goal.c1.score.json", "r") as fin:
        d2 = json.load(fin)

    print(len(d1), len(d2))
    diff = 0
    tot = 0
    for step, candidates in d1.items():
        if step not in d2:
            print(f"miss {step}")
            continue
        if best_goal(candidates) != best_goal(d2[step]):
            diff += 1
        tot += 1
    print(diff, tot)
if __name__ == "__main__":
    # Script entry point. Exactly one pipeline step is enabled at a time;
    # the commented calls below record earlier runs and are kept for
    # reference. Several of them predate the current function signatures.
    #
    # create_score_file("all.org.t30.test.deterta.goal.t30.train_null")
    # create_score_file("all.org.t30.test.deterta.goal.t30.para_score")
    # all_process('para_score')
    # all_process('train_null')
    # create_score_for_video1k()
    # create_score_file()
    # sample_data()
    # sample_human_data()
    # video2k_process('para_score')
    # video2k_process('train_null')
    # create_score_for_video1k('sample5k.deberta', 'para_score.goal.c1', 'para_base_all_base_0.0_30_links_5k.pkl')
    # create_score_for_video1k('sample5k.deberta', 'train_null.goal.c1', 'para_base_all_base_0.0_30_links_5k.pkl')
    # create_score_for_video1k('howto1k.deberta', 'para_score.goal.c1', 'para_base_all_base_0.0_30_links_howto1k.pkl')
    # create_score_for_video1k('howto1k.deberta', 'train_null.goal.c1', 'para_base_all_base_0.0_30_links_howto1k.pkl')
    # check_diff()
    # create_score_file("all.org.t30.test.deterta.t30.train_null.goal.c1")

    # Current step: turn the merged train_null rerank result into a
    # score.json file.
    all_process(
        data='all.org.t30.test.deterta.t30',
        tag='train_null.goal.c1.all',
        para_file='para_base_all_base_0.0_30_links.json',
    )