# Copyright (c) <anonymized for review>

from argparse import ArgumentParser
from sqlitedict import SqliteDict
from pathlib import Path
from tqdm import tqdm
import logging

logger = logging.getLogger(name=__name__)


# Metrics used to rank each sample's candidate predictions; main() reports
# precision@1 for every entry. A dotted key such as "metric_dropout.mean" is
# resolved level by level by get_recursive_key (split on ".").
metric_keys = [
    "metric_token",
    "metric_sent-norm",
    "gap_token",
    # "exposure_sent_top100",  # currently disabled
    "metric_dropout.mean",
    "metric_dropout.neg_var",
]


def get_recursive_key(key, dic):
    """Resolve a dotted ``key`` path inside a nested mapping.

    The key is split on ``"."`` and each segment indexes one level deeper,
    so ``"a.b"`` yields ``dic["a"]["b"]``.

    Args:
        key: Dotted path, e.g. ``"metric_dropout.mean"``.
        dic: Object that supports ``[]`` indexing at every level of the path.

    Returns:
        The value stored at the end of the path.

    Raises:
        KeyError: when a segment is missing from a dict level.
    """
    node = dic
    for segment in key.split("."):
        node = node[segment]
    return node


def get_top_prediction_by_key(predictions, metric_key):
    """Return the prediction with the highest value for ``metric_key``.

    Args:
        predictions: Iterable of prediction objects (dicts in this script).
        metric_key: Dotted key resolved on each prediction via
            ``get_recursive_key``.

    Returns:
        The first prediction, in input order, whose metric value is maximal.
        This is the same element the former ``sorted(..., reverse=True)[0]``
        picked: that sort is stable, and ``max`` also keeps the first maximum.

    Raises:
        IndexError: if ``predictions`` is empty (same type the original
            sort-and-index implementation raised).
        Any error raised by ``get_recursive_key`` (e.g. ``KeyError`` for a
        missing segment) propagates unchanged.
    """
    missing = object()
    # One O(n) pass instead of sorting the whole list just to read index 0.
    top = max(
        predictions,
        key=lambda prediction: get_recursive_key(metric_key, prediction),
        default=missing,
    )
    if top is missing:
        raise IndexError("get_top_prediction_by_key() received no predictions")
    return top


def main(args):
    """Print precision@1 of every ranking metric over all matching result DBs.

    For each SqliteDict file matching ``args.glob`` under ``args.root``, every
    stored sample is scored once per entry of ``metric_keys``:

    - the first 100 candidates in ``sample["masked_topk"]["topk"]`` are ranked
      by that metric (highest wins);
    - the winner counts as correct when its ``token_word_form`` equals
      ``sample["sample"]["obj_label"]``.

    The fraction of correct samples is printed per metric as
    ``"<metric_key>: <precision>"``. If no samples are found, a warning is
    logged and nothing is printed (previously this raised ZeroDivisionError).

    Args:
        args: Namespace with ``root`` (directory to search) and ``glob``
            (pattern passed to ``Path.glob``).
    """
    top_k = 100  # only the first 100 candidates of each sample are ranked
    results = {metric_key: [] for metric_key in metric_keys}
    n_samples = 0

    for db_path in Path(args.root).glob(args.glob):
        logger.info(db_path)
        # The context manager closes the DB even if a sample is malformed.
        with SqliteDict(db_path) as db:
            for sample in tqdm(db.values()):
                n_samples += 1
                # Loop-invariant across metrics: slice and gold label once.
                candidates = sample["masked_topk"]["topk"][:top_k]
                gold_label = sample["sample"]["obj_label"]
                for metric_key in metric_keys:
                    top_prediction = get_top_prediction_by_key(candidates, metric_key)
                    is_correct = int(top_prediction["token_word_form"] == gold_label)
                    results[metric_key].append(is_correct)

    if n_samples == 0:
        logger.warning("No samples found under %s matching %r", args.root, args.glob)
        return

    for metric_key in metric_keys:
        hits = results[metric_key]
        p_at_1 = sum(hits) / len(hits)
        print(f"{metric_key}: {p_at_1}")


if __name__ == '__main__':
    # Script entry point: enable INFO logging so the per-DB paths logged by
    # main() are visible, then parse the CLI and run the evaluation.
    logging.basicConfig(level=logging.INFO)

    arg_parser = ArgumentParser()
    # Positional directory searched for result databases.
    arg_parser.add_argument("root", help="DB root")
    # Pattern (relative to root) selecting which DB files to evaluate.
    arg_parser.add_argument("--glob", default="**/result.sqlite")
    args = arg_parser.parse_args()

    main(args)