import os
import json
import numpy as np
import matplotlib.pyplot as plt

def calc_recall(n, tag, null_th, org_file=None, res_file=None):
    """Print reranking accuracy and null-detection statistics for one run.

    Two JSON files are read:

    * ``org_file`` -- a mapping from step to a record that carries a
      ``'retrieved_goal_rank'`` entry.
    * ``res_file`` -- a list of records with ``'step'``, ``'gold'`` and
      ``'pred'`` entries, where ``'pred'`` maps a candidate label to its score.

    A step whose ``retrieved_goal_rank`` is ``-1`` is treated as a "null" step
    and only contributes to the null true-positive / false-negative counts.
    Every other step contributes to the top-1 / top-10 counts, to the returned
    score lists and, when it is predicted as null, to the false-positive count.

    How "predicted as null" is decided depends on the result file path: if it
    contains ``"train_null"`` the top-ranked label is compared with
    ``'[unused2]'``; otherwise the top score is compared with ``null_th``.

    Args:
        n: Value interpolated into the default file names.
        tag: Run tag interpolated into the default result file name.
        null_th: Score threshold. A top-1 hit only counts when its score is
            ``>= null_th``, and (outside "train_null" runs) a step is predicted
            null when its top score is ``<= null_th``. ``None`` disables both
            threshold tests instead of raising ``TypeError``.
        org_file: Optional path overriding the default retrieval file.
        res_file: Optional path overriding the default rerank result file.

    Returns:
        A tuple ``(score, neg_score, pos_score)``: the top score of every
        non-null step, and the same scores split by whether the top label
        differs from / equals the gold label.
    """
    if org_file is None:
        org_file = f"/projects/[SERVER]2/users/[USER]/para/wikihow/para_base_sz_all_base_0.0_-1.0_10.0_{n}_links.json"
    if res_file is None:
        res_file = f"/projects/[SERVER]2/users/[USER]/wikihow/wikihow/gold.rerank.org.t{n}.test.{tag}.result"

    with open(org_file, "r") as f:
        org_res = json.load(f)

    print(org_file)
    print(res_file)

    # Load everything up front so the file is closed before scoring starts.
    with open(res_file, "r") as f:
        rerank_res = json.load(f)

    # The evaluation mode is inferred from the result file path.
    null_trained = "train_null" in res_file

    def below_threshold(value):
        # With no threshold configured, nothing is predicted null by score.
        return null_th is not None and value <= null_th

    score = []
    neg_score = []
    pos_score = []
    acc = 0
    r10 = 0
    org_ub = 0  # upper bound
    org_acc = 0
    null_tp = 0
    null_fn = 0
    null_fp = 0

    for item in rerank_res:
        step = item['step']
        gold = item['gold']
        ranked = sorted(item['pred'].items(), key=lambda kv: kv[1], reverse=True)
        top_label, top_score = ranked[0]

        if org_res[step]['retrieved_goal_rank'] == -1:
            if null_trained:
                assert gold == '[unused2]'
                predicted_null = top_label == gold
            else:
                predicted_null = below_threshold(top_score)
            if predicted_null:
                null_tp += 1
            else:
                null_fn += 1
            continue

        org_ub += 1
        org_acc += int(org_res[step]['retrieved_goal_rank'] == 0)

        hit = top_label == gold
        if hit and (null_th is None or top_score >= null_th):
            acc += 1
        if gold in [label for label, _ in ranked[:10]]:
            r10 += 1

        score.append(top_score)
        if hit:
            pos_score.append(top_score)
        else:
            neg_score.append(top_score)

        if null_trained:
            if top_label == '[unused2]':
                null_fp += 1
        elif below_threshold(top_score):
            null_fp += 1

    total = len(rerank_res)
    assert null_tp + null_fn + org_ub == total

    # All counters are zero when there are no results, so dividing by 1
    # reports 0.0 instead of raising ZeroDivisionError.
    denom = total or 1
    print("original top1 and top30: ", org_acc, org_acc / denom, org_ub, org_ub / denom)
    print("rerank top1: ", acc, total, acc / denom)
    print("rerank top10: ", r10, total, r10 / denom)
    print("null info: ", null_tp, null_fn, null_fp)
    print("null recall/precision: ", null_tp / (null_tp + null_fn + 1e-10) * 100, null_tp / (null_tp + null_fp + 1e-10) * 100)

    return score, neg_score, pos_score

def sanity_check(n, tag, org_file=None, res_file=None, shard_files=None):
    """Cross-check a rerank result file against a sharded reference run.

    The reference run is read from ``shard_files`` (JSON lists of records with
    ``'step'`` and ``'pred'`` entries) and indexed by step; when a step occurs
    more than once the last record wins. The run under test is read from
    ``res_file`` and the retrieval records from ``org_file``.

    Steps whose ``retrieved_goal_rank`` is ``-1`` are skipped. For the others,
    a hit is counted only when the top-ranked label equals the gold label with
    a non-negative score AND the reference run also gives the gold label its
    maximum score for that step.

    Args:
        n: Value interpolated into the default file names.
        tag: Run tag interpolated into the default result file name.
        org_file: Optional path overriding the default retrieval file.
        res_file: Optional path overriding the default rerank result file.
        shard_files: Optional iterable of paths overriding the ten default
            reference shards.

    Returns:
        A tuple ``(score, neg_score, pos_score)``: the top score of every
        non-skipped step, and the same scores split by whether the top label
        differs from / equals the gold label.

    Raises:
        KeyError: If a counted step, or its gold label, is missing from the
            reference run.
    """
    if shard_files is None:
        shard_files = [
            f"/projects/[SERVER]2/users/[USER]/wikihow/wikihow/all.org.t30.test.bert.goal.t30.{i}.result"
            for i in range(10)
        ]
    if org_file is None:
        org_file = f"/projects/[SERVER]2/users/[USER]/para/wikihow/para_base_sz_all_base_0.0_-1.0_10.0_{n}_links.json"
    if res_file is None:
        res_file = f"/projects/[SERVER]2/users/[USER]/wikihow/wikihow/gold.rerank.org.t{n}.test.{tag}.result"

    # Index the reference predictions by step; later shards overwrite earlier
    # ones on duplicate steps.
    rerank_d = {}
    for shard in shard_files:
        with open(shard, "r") as f:
            for entry in json.load(f):
                rerank_d[entry['step']] = entry
    print(len(rerank_d))

    with open(org_file, "r") as f:
        org_res = json.load(f)

    # Load everything up front so the file is closed before scoring starts.
    with open(res_file, "r") as f:
        rerank_res = json.load(f)

    score = []
    neg_score = []
    pos_score = []
    acc = 0
    org_ub = 0  # upper bound
    org_acc = 0

    for item in rerank_res:
        step = item['step']
        if org_res[step]['retrieved_goal_rank'] == -1:
            continue
        org_ub += 1
        org_acc += int(org_res[step]['retrieved_goal_rank'] == 0)

        gold = item['gold']
        ranked = sorted(item['pred'].items(), key=lambda kv: kv[1], reverse=True)
        top_label, top_score = ranked[0]

        if top_label == gold and top_score >= 0:
            reference = rerank_d[step]['pred']
            if reference[gold] == max(reference.values()):
                acc += 1

        score.append(top_score)
        if top_label != gold:
            neg_score.append(top_score)
        else:
            pos_score.append(top_score)

    total = len(rerank_res)
    # All counters are zero when there are no results, so dividing by 1
    # reports 0.0 instead of raising ZeroDivisionError.
    denom = total or 1
    print(acc, total, acc / denom)
    print(org_acc, org_acc / denom, org_ub, org_ub / denom)

    return score, neg_score, pos_score


def plot(nscore, pscore):
    """Show, per score bin, the share of positive and negative scores.

    Both score collections are histogrammed over the same edges, built with
    ``np.arange(0, 1 + 0.02, 0.02)``. Each bin's counts are divided by the
    combined count of that bin (plus a tiny epsilon so empty bins give 0
    rather than a division by zero). The positive share is drawn in red and
    the negative share in green, as two half-transparent bar series overlaid
    on one axis.

    Bars are placed at integer bin indices, not at the score values.

    Args:
        nscore: Scores for the negative series.
        pscore: Scores for the positive series.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1)

    step = 0.02
    edges = np.arange(0, 1 + step, step)
    positions = np.arange(edges.shape[0] - 1)

    neg_counts = np.histogram(nscore, bins=edges)[0]
    pos_counts = np.histogram(pscore, bins=edges)[0]
    # Epsilon keeps empty bins from dividing by zero.
    bin_totals = neg_counts + pos_counts + 1e-10

    ax.bar(positions, pos_counts / bin_totals, color='r', alpha=0.5)
    ax.bar(positions, neg_counts / bin_totals, color='g', alpha=0.5)

    plt.show()

if __name__ == "__main__":
    # Each active row is (n, tag, null_th) and is evaluated in order with
    # calc_recall. Rows that are commented out record earlier experiments.
    runs = [
        # Oldest runs, made when calc_recall took only (n, tag) and returned
        # two values:
        #   (10, 'base'), (10, 'goal'), (10, 'goal.t30'), (30, "bert.goal")
        #   -- the last one was followed by plot(ns2)
        #
        # (30, "bert.goal.t30", 0),   # followed by plot(ns2, ps2)
        # (30, "bert.goal", 0),
        #
        #######################################################
        # (30, "bert.base", 0),
        # (30, "deberta.para_score.base", 0),
        # (30, "deberta.goal.t30.para_score", 0),
        # (30, "deberta.para_score.c1", 0),
        (30, "deberta.para_score.goal.c1", 0.9),
        # (30, "deberta.para_score.train_null.goal.c1", 0),
        (30, "deberta.train_null.goal.c1", 0),
        # (30, "deberta.goal.t30.train_null.ep2", None),
        #######################################################
        #
        # (30, "deberta.goal.t30.para_score", 0.95),
        # (10, "deberta.goal.t30", 0.95),
    ]

    for n, tag, null_th in runs:
        print(f"***{n}, {tag}***")
        s2, ns2, ps2 = calc_recall(n, tag, null_th)

    # sanity_check(n, tag)