# Copyright (c) <anonymized for review>

from sqlitedict import SqliteDict
from pathlib import Path
from tqdm import tqdm
import json


def _weighted_score(prediction, metrics):
    """Return ``0.0 + sum(weight * prediction[name])`` over the metric specs, in order."""
    score = 0.0
    for metric in metrics:
        score += metric["weight"] * prediction[metric["name"]]
    return score


def get_prediction(
    sample,
    ranking_metrics,
    confidence_metrics,
    topk=1,
    prefilter_topk=100,
):
    """Score, rank and truncate the candidate predictions stored in ``sample``.

    Each metric spec is a dict ``{"name": ..., "weight": ...}``; it contributes
    ``weight * candidate[name]`` to a candidate's score.

    Args:
        sample: record whose ``["masked_topk"]["topk"]`` entry is a list of
            candidate dicts. The sample and its candidates are not modified.
        ranking_metrics: metric specs whose weighted sum is stored as
            ``"ranking_score"``; candidates are ordered by it, highest first.
        confidence_metrics: metric specs whose weighted sum is stored as
            ``"confidence_score"``.
        topk: number of ranked candidates to return (``None`` keeps all).
        prefilter_topk: only the first ``prefilter_topk`` stored candidates
            are considered (``None`` considers all).

    Returns:
        A new list of copies of the selected candidates, each extended with
        ``"ranking_score"``, ``"confidence_score"`` and its 0-based rank ``"i"``.
    """
    candidates = sample["masked_topk"]["topk"]
    if prefilter_topk is not None:
        candidates = candidates[:prefilter_topk]

    predictions = []
    for candidate in candidates:
        # Copy each candidate so the scores and rank written below do not leak
        # into the caller's sample (slicing copies only the list, not the dicts).
        prediction = dict(candidate)
        # ranking_score is stored before confidence_score is computed, so a
        # confidence metric may refer to "ranking_score".
        prediction["ranking_score"] = _weighted_score(prediction, ranking_metrics)
        prediction["confidence_score"] = _weighted_score(
            prediction, confidence_metrics
        )
        predictions.append(prediction)

    # Highest ranking score first; list.sort is stable, so ties keep stored order.
    predictions.sort(key=lambda p: -p["ranking_score"])

    if topk is not None:
        predictions = predictions[:topk]

    # Re-index by the new rank, in case the stored order differed from it.
    for rank, prediction in enumerate(predictions):
        prediction["i"] = rank

    return predictions


def get_optimal_threshold_index(is_correct_all, confidence_scores_all):
    """Find the confidence cut-off that maximises (#correct - #incorrect).

    The (is_correct, confidence) pairs are sorted by confidence, highest first
    (stable, so equal confidences keep their input order). Each prefix of that
    ordering is scored as +1 per correct and -1 per incorrect item.

    Args:
        is_correct_all: per-prediction correctness flags (1/0 or bools).
        confidence_scores_all: matching confidence scores.

    Returns:
        ``(max_diff_ind, max_diff, sorted_tuples)``:

        - ``max_diff_ind`` is the index of the last item of the best-scoring prefix.
        - ``max_diff`` is that prefix's score.
        - ``sorted_tuples`` is the sorted list of ``(is_correct, confidence)`` pairs.

        Only strictly positive scores replace the initial ``(0, 0)``. When no
        prefix scores above zero, ``(0, 0)`` is returned even though index 0
        then names a non-positive prefix. NOTE(review): callers should be
        aware of this.
    """
    sorted_tuples = sorted(
        zip(is_correct_all, confidence_scores_all),
        key=lambda pair: -pair[1],
    )

    best_index, best_margin, running_margin = 0, 0, 0
    for position, (correct, _confidence) in enumerate(sorted_tuples):
        running_margin += 2 * correct - 1
        if running_margin > best_margin:
            best_index, best_margin = position, running_margin

    return best_index, best_margin, sorted_tuples


# Ranking and confidence are both taken from the same per-token metric.
# Each spec names a field of a candidate dict and the weight applied to it
# (consumed by get_prediction).
ranking_metrics = [
    {"name": "metric_token", "weight": 1.0},
]
confidence_metrics = [
    {"name": "metric_token", "weight": 1.0},
]

# Per-group threshold statistics, printed as JSON at the end.
results = {}

# (group name, search root, glob pattern relative to the root). The "all_*"
# patterns match every result.sqlite below a bert_base / bert_large directory
# anywhere under expr/output/.
db_root_and_glob = [
    ("gre_bb", "expr/output/Google_RE/results/bert_base/", "**/result.sqlite"),
    ("trex_bb", "expr/output/TREx/results/bert_base/", "**/result.sqlite"),
    ("cnet_bb", "expr/output/ConceptNet/results/bert_base/", "**/result.sqlite"),
    ("squad_bb", "expr/output/Squad/results/bert_base/", "**/result.sqlite"),
    ("all_bb", "expr/output/", "**/bert_base/**/result.sqlite"),
    ("gre_bl", "expr/output/Google_RE/results/bert_large/", "**/result.sqlite"),
    ("trex_bl", "expr/output/TREx/results/bert_large/", "**/result.sqlite"),
    ("cnet_bl", "expr/output/ConceptNet/results/bert_large/", "**/result.sqlite"),
    ("squad_bl", "expr/output/Squad/results/bert_large/", "**/result.sqlite"),
    ("all_bl", "expr/output/", "**/bert_large/**/result.sqlite"),
]


def _iter_samples(root, pattern):
    """Yield every sample stored in the databases matching ``pattern`` under ``root``.

    Databases are visited in sorted path order because Path.glob order is
    filesystem-dependent. The threshold search breaks confidence ties by input
    order, so an unsorted walk could change results between machines. Each
    database is closed by its context manager even if iteration fails.
    """
    for p_db in sorted(Path(root).glob(pattern)):
        with SqliteDict(str(p_db)) as db:
            yield from tqdm(db.values(), total=len(db), desc=str(p_db))


for name, root, pattern in db_root_and_glob:
    print(name)

    # Score samples while streaming them from disk rather than loading a whole
    # group into memory first. No manual top-100 truncation is needed:
    # get_prediction already keeps only the first prefilter_topk candidates.
    is_correct_all = []
    confidence_scores_all = []
    for sample in _iter_samples(root, pattern):
        prediction = get_prediction(
            sample,
            ranking_metrics,
            confidence_metrics,
            topk=1,
            prefilter_topk=100,
        )
        is_correct_all.extend(
            int(x["token_idx"] in sample["label_index"]) for x in prediction
        )
        confidence_scores_all.extend(x["confidence_score"] for x in prediction)

    # With no predictions there is no threshold to pick, and indexing
    # sorted_tuples below would raise IndexError and abort all remaining groups.
    if not is_correct_all:
        print(f"no predictions found for {root}{pattern}; skipping")
        continue

    max_diff_ind, max_diff, sorted_tuples = get_optimal_threshold_index(
        is_correct_all, confidence_scores_all
    )
    threshold = sorted_tuples[max_diff_ind][1]
    print(f"max_diff_ind: {max_diff_ind} / {len(sorted_tuples)}")
    print(f"max_diff: {max_diff}")
    print(f"threshold: {threshold}")

    results[name] = {
        "max_diff_ind": max_diff_ind,
        "num_samples": len(sorted_tuples),
        "max_diff": max_diff,
        "threshold": threshold,
    }


print(json.dumps(results))
