# Copyright (c) <anonymized for review>

from argparse import ArgumentParser, Namespace
import logging
from pathlib import Path
import json
import copy
import os
from sqlitedict import SqliteDict
from tqdm import tqdm

from confidence import get_confidence_by_name
from lama.modules import build_model_by_name

# Module-level logger named after this module; the logging level is set up
# by the script entry point via logging.basicConfig.
logger = logging.getLogger(__name__)


def set_default_option_if_not_exist(args, key, value):
    """Give ``args`` the attribute ``key`` with ``value`` if it lacks one.

    An attribute that already exists is never overwritten, even when its
    current value is ``None``.
    """
    if hasattr(args, key):
        return
    setattr(args, key, value)


def get_config_from_file(config_file: str, args=None):
    """Merge the top-level entries of a JSON config file into ``args``.

    Args:
        config_file: Path to a JSON file whose top level is an object.
        args: Object to receive the options as attributes. A fresh
            ``Namespace`` is created when this is ``None``.

    Returns:
        The ``args`` object (the one passed in, or the newly created one).
        An attribute that already holds a non-``None`` value takes
        precedence: the file's value is skipped and a warning is logged.
    """
    if args is None:
        args = Namespace()

    with open(config_file) as f:
        config_json = json.load(f)

    for k, file_value in config_json.items():
        # A missing attribute and an attribute set to None are both treated
        # as "not specified", so the file's value is applied.
        if getattr(args, k, None) is None:
            setattr(args, k, file_value)
            continue
        logger.warning(
            f"Argument {k} is specified by the command. "
            "The value in the configuration file will be ignored.")

    return args


def is_already_processed(prediction_item, metric_log_name):
    """Tell whether a metric is already recorded on the item's gold answer.

    Returns ``True`` only when ``prediction_item`` has a non-``None``
    ``"gold_answer"`` entry that contains ``metric_log_name``; otherwise
    ``False``.
    """
    stored_answer = prediction_item.get("gold_answer")
    return stored_answer is not None and metric_log_name in stored_answer
    

def main(args):
    """Compute confidence metrics for stored predictions and write them back.

    Options from the JSON file at ``args.config`` are merged into ``args``
    and defaults are filled in for options that are still missing. Every
    database selected by ``args.db`` (optionally expanded with ``args.glob``)
    is then processed: for each configured metric a scorer is built and, for
    each stored item, the top-k predictions followed by the gold answer are
    scored. Each score is stored under the metric's log name on the
    corresponding prediction and on the gold answer, and the item is written
    back to the database.

    Args:
        args: Namespace holding at least ``config`` and ``db``. ``args.LM``
            is read without a default, so it has to be supplied by the
            command line or by the configuration file.
    """
    get_config_from_file(args.config, args)

    set_default_option_if_not_exist(args, "metrics", [{
        "name": "token",
        "use_prediction_value": True
    }])
    set_default_option_if_not_exist(args, "topk", 100)
    set_default_option_if_not_exist(
        args,
        "common_vocab_filename",
        "pre-trained_language_models/common_vocab_cased.txt")
    set_default_option_if_not_exist(args, "glob", None)
    set_default_option_if_not_exist(args, "n_commit", 1000)
    set_default_option_if_not_exist(args, "bert_model_dir", None)
    set_default_option_if_not_exist(args, "bert_vocab_name", "vocab.txt")

    # Without --glob, args.db is the database itself; with it, args.db is the
    # root directory that the glob pattern is evaluated against.
    db_root = Path(args.db)
    if args.glob is None:
        db_paths = [db_root]
    else:
        db_paths = list(db_root.glob(args.glob))

    for db_path in db_paths:
        logger.info(f"Processing {db_path}")

        db = SqliteDict(str(db_path))
        # try/finally so the database handle is released even when scoring
        # fails part-way through.
        try:
            for metric_config in args.metrics:
                metric_name = metric_config["name"]
                metric_log_name = metric_config.get("log_name")
                if metric_log_name is None:
                    metric_log_name = f"metric_{metric_name}"

                # NOTE: metrics are always recomputed, even when the database
                # already holds values under metric_log_name. Skipping
                # already-processed databases (see is_already_processed) is
                # not enabled.

                # Prepare scorer
                metric_args = Namespace(**metric_config)
                gen_args = Namespace(**args.LM)
                gen_args.common_vocab_filename = args.common_vocab_filename
                gen_args.bert_model_dir = args.bert_model_dir
                gen_args.bert_vocab_name = args.bert_vocab_name
                metric_args.generation = gen_args
                scorer = get_confidence_by_name(metric_name, args=metric_args)

                for n_done, (db_key, prediction_item) in enumerate(
                        tqdm(db.items()), start=1):
                    sample = prediction_item["sample"]
                    predictions = prediction_item["masked_topk"]["topk"]

                    gold_answer = prediction_item.get("gold_answer")
                    if gold_answer is None:
                        # First pass over this item: build the gold answer
                        # entry from the sample and its label index.
                        gold_answer = {
                            "token_word_form": sample["obj_label"],
                            "token_idx": prediction_item["label_index"][0]
                        }

                    # The slice is a new list, so appending the gold answer
                    # does not grow the stored prediction list; the entries
                    # themselves are shared with ``predictions``.
                    topk_predictions = predictions[:args.topk]
                    topk_predictions.append(gold_answer)
                    scores = scorer.get_metrics(sample, topk_predictions)

                    # The last score belongs to the appended gold answer; the
                    # preceding ones line up with the top-k predictions.
                    for i, score in enumerate(scores[:-1]):
                        predictions[i][metric_log_name] = score
                    gold_answer[metric_log_name] = scores[-1]

                    prediction_item["gold_answer"] = gold_answer

                    db[db_key] = prediction_item
                    # Commit once every n_commit items. This must be a modulo
                    # test: the quotient of a division is never zero here, so
                    # a division test would never trigger a commit.
                    if n_done % args.n_commit == 0:
                        db.commit()

                db.commit()
        finally:
            db.close()


def get_parser():
    """Build the command-line parser for this script.

    The parser accepts ``-c/--config`` and ``--db`` (both required),
    ``--glob`` (optional) and ``--n_commit`` (integer, default 1000).
    """
    # One (flags, keyword options) entry per command-line argument, in the
    # order they are registered.
    option_table = (
        (("-c", "--config"), {
            "required": True,
            "help": "Path to config file",
        }),
        (("--db",), {
            "required": True,
            "help": "Path to the database (result.sqlite). "
            "If --glob is given, this argument is regarded as the root directory for glob search.",
        }),
        (("--glob",), {
            "help": "Glob expression to find and process multiple dbs.",
        }),
        (("--n_commit",), {
            "type": int,
            "default": 1000,
            "help": "Number of iteration for a commit to DB.",
        }),
    )

    cli = ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli



if __name__ == "__main__":
    # Show INFO-level records (e.g. the per-database progress message).
    logging.basicConfig(level=logging.INFO)

    # Parse the command line and hand the resulting namespace to main().
    main(get_parser().parse_args())
