import os
import sys
import json
from collections import defaultdict
os.chdir("/home/[USER]/workshop/wikihow")
sys.path.append("/home/[USER]/workshop/wikihow")
import random
import json
import pickle
from collections import Counter
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# import plotille
import shutil
from ir_eval import calc_average_precision, calc_mean_rank, calc_mrr, calc_mean_average_precision

def extract_good_goals():
    """Sample up to 2000 tasks that have enough top-ranked videos, then exit.

    Reads ``task_ids.csv`` (tab-separated ``task_id<TAB>task``) and
    ``HowTo100M_v1.csv`` (comma-separated, first line skipped as header; the
    first column is used as the video id, the second-to-last as an integer
    rank and the last as the task id).

    Writes:
      * ``task_id_task_map.json`` -- task id -> task text.
      * ``video2k/all_task_video_map.json`` -- task text -> list of
        ``[video_id, rank]`` for a random sample of at most 2000 tasks that
        have at least 20 videos with rank <= 50.

    The function terminates the interpreter with ``sys.exit(0)`` once the
    sampled map has been written; it never returns to its caller.
    """
    task_id_task_map = {}
    with open("./data/howto100m/task_ids.csv", "r") as f:
        for line in f:
            tks = line.strip().split("\t")
            task_id_task_map[int(tks[0])] = tks[1]
    with open("./data/howto100m/task_id_task_map.json", "w+") as f:
        json.dump(task_id_task_map, f, indent=2)

    task_video_map = defaultdict(list)
    with open("./data/howto100m/HowTo100M_v1.csv", "r") as f:
        f.readline()  # skip the header row
        for line in f:
            tks = line.strip().split(",")
            rank = int(tks[-2])
            if rank <= 50:
                task = task_id_task_map[int(tks[-1])]
                task_video_map[task].append([tks[0], rank])

    # Keep only tasks with at least 20 qualifying videos.
    task_video_map = {k: v for k, v in task_video_map.items() if len(v) >= 20}

    sampled_k = list(task_video_map.keys())
    random.shuffle(sampled_k)
    # A set gives O(1) membership tests; the output below still follows the
    # iteration order of task_video_map, exactly as before.
    sampled_k = set(sampled_k[:2000])

    _task_video_map = {k: v for k, v in task_video_map.items() if k in sampled_k}
    tot_v = sum(len(v) for v in _task_video_map.values())
    print(len(_task_video_map), tot_v)

    # Make sure the output directory exists before writing into it.
    os.makedirs("./data/howto100m/video2k", exist_ok=True)
    with open("./data/howto100m/video2k/all_task_video_map.json", "w+") as f:
        json.dump(_task_video_map, f, indent=2)
    sys.exit(0)

    # NOTE(review): everything below is unreachable because of the exit above.
    # It is kept as-is since it looks like a deliberately disabled second
    # stage (a per-task train/dev/test split into thirds) -- confirm whether
    # it should be re-enabled or deleted.
    train_tot = 0
    dev_tot = 0
    test_tot = 0
    _task_video_map = {'train': {}, 'dev': {}, 'test': {}}
    for k, v in task_video_map.items():
        if len(v) <= 20:
            continue
        random.shuffle(v)
        v = [x[0] for x in v]
        _train = v[:int(len(v) / 3)]
        _dev = v[int(len(v) / 3): int(len(v) * 2 / 3)]
        _test = v[int(len(v) * 2 / 3):]
        train_tot += len(_train)
        dev_tot += len(_dev)
        test_tot += len(_test)
        _task_video_map['train'][k] = _train
        _task_video_map['dev'][k] = _dev
        _task_video_map['test'][k] = _test
        for x in _task_video_map['dev'][k] + _task_video_map['test'][k]:
            assert x not in _task_video_map['train'][k]

    print(train_tot, dev_tot, test_tot)
    with open("./data/howto100m/video2k/split_task_video_map.json", "w+") as f:
        json.dump(_task_video_map, f, indent=2)

def merge_all_split():
    """Flatten the pickled 1k metadata into a per-split task -> videos map.

    Loads ``onek_meta_data.p`` and, for every entry, files the entry's
    ``videos`` lists under its ``text.task`` name. The pickle's ``val`` split
    is stored as ``dev`` in the output. The result is written to
    ``video1k/split_task_video_map.json``.
    """
    with open("./data/howto100m/onek_meta_data.p", "rb") as src:
        meta = pickle.load(src)

    # (name used in the output JSON, name used inside the pickle)
    split_names = (('train', 'train'), ('dev', 'val'), ('test', 'test'))
    merged = {out_name: {} for out_name, _ in split_names}
    for _, entry in meta.items():
        for out_name, src_name in split_names:
            merged[out_name][entry['text']['task']] = entry['videos'][src_name]

    with open("./data/howto100m/video1k/split_task_video_map.json", "w+") as dst:
        json.dump(merged, dst, indent=2)

def calc_unigram():
    """Print unigram statistics for the video2k captions.

    Each caption is lower-cased and split on whitespace. The vocabulary size
    is printed first, followed by the 1000 most frequent ``(token, count)``
    pairs in descending count order.
    """
    with open("./data/howto100m/video2k/all_task_video_caption.json", "r") as f:
        captions = json.load(f)

    freq = Counter()
    for _, caption in tqdm(captions.items()):
        freq.update(caption.lower().split())
    print(len(freq))

    # Stable sort: tokens with equal counts keep first-seen order.
    ranked = sorted(freq.items(), key=lambda pair: pair[1], reverse=True)
    for pair in ranked[:1000]:
        print(pair)

def calc_len():
    """Print a text histogram of caption lengths for the video2k captions.

    Length is the number of whitespace-separated tokens in each caption.

    The histogram is rendered with ``plotille`` when that optional package is
    installed. The module-level ``import plotille`` is commented out, so the
    package is imported here; without it the original call raised
    ``NameError``. When ``plotille`` is unavailable, a plain 20-bin
    histogram built with ``numpy.histogram`` is printed instead.
    """
    with open("./data/howto100m/video2k/all_task_video_caption.json", "r") as f:
        d = json.load(f)
    lengths = np.array([len(v.split()) for v in tqdm(d.values())])

    if lengths.size == 0:
        print("no captions found")
        return

    try:
        import plotille
    except ImportError:
        plotille = None

    if plotille is not None:
        print(plotille.histogram(lengths))
        return

    # Fallback: one line per bin, bar length scaled to the fullest bin.
    counts, edges = np.histogram(lengths, bins=20)
    peak = counts.max()  # > 0 because lengths is non-empty
    for count, lo, hi in zip(counts, edges[:-1], edges[1:]):
        bar = "#" * int(round(40 * count / peak))
        print(f"[{lo:8.1f}, {hi:8.1f}) {int(count):8d} {bar}")

def resplit():
    """Re-split the video1k data 75% / 12.5% / 12.5% using top-ranked videos.

    Steps:
      1. Load the train/dev/test caption files and merge them; if a video id
         appears in more than one file, the first file loaded wins.
      2. Drop every video whose rank in ``HowTo100M_v1.csv`` (second-to-last
         column) is greater than ``top_n``.
      3. For each task of the original split, pool its surviving videos,
         shuffle them, and cut them into train / dev / test.
      4. Write the new split map and the three per-split caption files into
         ``video1k/t{top_n}_resplit/`` (created if missing).

    Raises:
        KeyError: if a captioned video is missing from the CSV, or a task of
            the original ``train`` split is missing from ``dev``/``test``.
    """
    top_n = 150
    splits = ('train', 'dev', 'test')

    cap = {}
    for s in splits:
        with open(f"./data/howto100m/video1k/video1k.{s}.caption.json", "r") as f:
            # setdefault keeps the caption from the earliest-loaded file.
            for vid, text in json.load(f).items():
                cap.setdefault(vid, text)
    print(len(cap))

    with open("./data/howto100m/HowTo100M_v1.csv", "r") as f:
        f.readline()  # skip the header row
        v_rank = {}
        for line in f:
            tks = line.strip().split(",")
            v_rank[tks[0]] = int(tks[-2])

    cap = {k: v for k, v in cap.items() if v_rank[k] <= top_n}
    print(len(cap))

    with open("./data/howto100m/video1k/split_task_video_map.json", "r") as f:
        org_split = json.load(f)

    new_split = {s: {} for s in splits}
    new_cap = {s: {} for s in splits}
    for task in org_split['train']:
        task_vid = [x for s in splits for x in org_split[s][task] if x in cap]
        random.shuffle(task_vid)
        ll = len(task_vid)
        cut_train, cut_dev = int(ll * 0.75), int(ll * 0.875)
        parts = {
            'train': task_vid[:cut_train],
            'dev': task_vid[cut_train:cut_dev],
            'test': task_vid[cut_dev:],
        }
        for s, vids in parts.items():
            new_split[s][task] = vids
            # In-place update instead of rebuilding the whole dict per task.
            new_cap[s].update((k, cap[k]) for k in vids)

    for k, v in new_split.items():
        tot = sum(len(vv) for vv in v.values())
        print(k, len(v), tot)

    out_dir = f"./data/howto100m/video1k/t{top_n}_resplit"
    # Opening a file for writing does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    with open(f"{out_dir}/split_task_video_map.json", "w+") as f:
        json.dump(new_split, f, indent=2)

    for s in splits:
        with open(f"{out_dir}/video1k.{s}.caption.json", "w+") as f:
            print(len(new_cap[s]))
            json.dump(new_cap[s], f, indent=2)

def merge_results():
    """Merge the 50 sharded result files for ``tag`` into a single JSON file.

    Each shard ``{tag}.{i}.json`` must contain a ``step_map`` and an ``info``
    dict; they are merged key-wise, with later shards overriding earlier
    ones on duplicate keys. After writing ``{tag}.json`` the user is
    prompted, and the shard files are deleted only if the answer is exactly
    ``y``.
    """
    tag = "all_goal_useful_step.mr.exp.m15.3"
    n = 50
    shard_paths = [
        f"./data/howto100m/video1k/t150_resplit/{tag}.{i}.json" for i in range(n)
    ]

    d = {'step_map': {}, 'info': {}}
    for path in shard_paths:
        with open(path, "r") as f:
            _d = json.load(f)
        # Update in place rather than rebuilding the accumulated dicts for
        # every shard.
        d['step_map'].update(_d['step_map'])
        d['info'].update(_d['info'])
    print(len(d))
    print(len(d['step_map']))
    print(len(d['info']))
    with open(f"./data/howto100m/video1k/t150_resplit/{tag}.json", "w+") as f:
        json.dump(d, f, indent=2)

    # Destructive step: only remove the shards after explicit confirmation.
    a = input("everything corret? ")
    if a == 'y':
        for path in shard_paths:
            os.remove(path)


def sanity_check():
    """Check stored per-goal metrics against freshly recomputed ones.

    For every goal in the goal-only retrieval file, the two values stored
    under ``info[goal]`` are compared with the negated result of
    ``calc_average_precision`` on the goal-only and the expanded retrieval
    output (cut-off ``r_size``). Goals whose first stored value is NaN are
    printed and skipped; any other mismatch raises ``AssertionError``.
    """
    r_size = 200
    folder = "./data/howto100m/video1k/t150_resplit"

    with open(f"{folder}/all_goal_useful_step.ap.exp.m15.json", "r") as f:
        stored = json.load(f)['info']
    with open(f"{folder}/retrieved.goal_only.video_1k_t150_resplit_train.json", "r") as f:
        goal_only = json.load(f)
    with open(f"{folder}/retrieved.filter.base.exp.ap.exp.video_1k_t150_resplit_train.1.0.json", "r") as f:
        exp = json.load(f)

    for goal in goal_only:
        only_pred = goal_only[goal]
        exp_pred = exp[goal]
        stored_only, stored_exp = stored[goal][0], stored[goal][1]
        recomputed_only = -calc_average_precision(only_pred, r_size)
        recomputed_exp = -calc_average_precision(exp_pred, r_size)
        if str(stored_only) != 'nan':
            assert stored_only == recomputed_only, (stored_only, recomputed_only)
            assert stored_exp == recomputed_exp, (stored_exp, recomputed_exp)
        else:
            # NaN never compares equal, so report the goal instead.
            print(goal)

def _rank_histogram(results):
    """Count how often each gold rank value occurs across all goals."""
    counter = Counter()
    for v in results.values():
        counter.update(v['gold'].values())
    return counter


def _recall_at_cutoffs(results, cutoffs):
    """Return {cutoff: mean recall@cutoff over goals}, formatted as ' .4f'."""
    totals = dict.fromkeys(cutoffs, 0)
    for v in results.values():
        ranks = v['gold'].values()
        for tk in cutoffs:
            totals[tk] += sum(x <= tk for x in ranks) / len(ranks)
    return {k: f"{v / len(results): .4f}" for k, v in totals.items()}


def metrics():
    """Print MAP, mean rank and recall@k for the base and expanded retrieval runs.

    For each threshold 0.1 .. 0.9 the ``base`` and ``exp`` retrieval result
    files of the test split are loaded and restricted to the goals present
    in the ``step_map`` of ``all_goal_useful_step.{tag}.exp.m15.json``. Each
    retrieval entry is expected to carry a ``gold`` dict whose values are
    ranks.

    Raises:
        AssertionError: if the two runs cover a different number of goals.
        ZeroDivisionError: if a goal has an empty ``gold`` dict or a run is
            empty after filtering.
    """
    tag = 'mr'
    r_size = 200
    folder = "./data/howto100m/video1k/t150_resplit"
    cutoffs = (1, 3, 5, 10, 15, 20, 25, 30)

    with open(f"{folder}/all_goal_useful_step.{tag}.exp.m15.json", "r") as f:
        # All goals are kept; a previously tried variant kept only goals with
        # a non-empty step list. A set makes the filters below O(1) per key.
        tar_goal = set(json.load(f)['step_map'])
    print(len(tar_goal))

    # Integer steps avoid the float-accumulation pitfalls of a fractional
    # arange step; the formatted values are 0.1, 0.2, ..., 0.9.
    for step in range(1, 10):
        i = step / 10
        data = [f'test.{i:.1f}.None', f'test.{i:.1f}.None']
        print(f"==========={data}==========")
        with open(f"{folder}/retrieved.filter.base.exp.{tag}.base.video_1k_t150_resplit_{data[0]}.json", "r") as f:
            base_exp = {k: v for k, v in json.load(f).items() if k in tar_goal}
        with open(f"{folder}/retrieved.filter.base.exp.{tag}.exp.video_1k_t150_resplit_{data[1]}.json", "r") as f:
            exp = {k: v for k, v in json.load(f).items() if k in tar_goal}
        assert len(base_exp) == len(exp)

        base_map = calc_mean_average_precision(base_exp, r_size)
        exp_map = calc_mean_average_precision(exp, r_size)
        print("base map: ", base_map)
        print("exp map: ", exp_map)

        base_mr = calc_mean_rank(_rank_histogram(base_exp), False)
        exp_mr = calc_mean_rank(_rank_histogram(exp), False)
        print("base mr: ", base_mr)
        print("exp mr: ", exp_mr)

        # Each run is normalised by its own goal count.
        print("base recall: ", _recall_at_cutoffs(base_exp, cutoffs))
        print("exp recall: ", _recall_at_cutoffs(exp, cutoffs))


if __name__ == "__main__":
    # Manual driver: each stage is a parameterless function defined above.
    # Uncomment the stage(s) to run; only metrics() is currently enabled.
    # extract_good_goals()
    # calc_unigram()
    # calc_len()
    # resplit()
    # merge_results()
    metrics()