from sqlitedict import SqliteDict
from pathlib import Path
import numpy as np
import math
from tqdm import tqdm
from sklearn import metrics
import json
import itertools
from argparse import ArgumentParser, Namespace
import logging

logger = logging.getLogger(__name__)


def get_recursive_key(key, dic):
    """Look up a dot-separated key path in a nested mapping.

    ``key`` is split on ``"."`` and every segment is used, in order, to index
    into the value reached so far, starting from ``dic``.  The value found
    after the last segment is returned.  A segment that is missing raises
    whatever the indexed object raises (``KeyError`` for a ``dict``).
    """
    current = dic
    for segment in key.split("."):
        current = current[segment]
    return current


def add_rank_scores(predictions, rank_key, rank_name, lower_is_better=False):
    """Store each prediction's 1-based rank under ``rank_name`` (in place).

    Args:
        predictions: Sequence of prediction dicts; each one is mutated.
        rank_key: Dot-separated key path (resolved with ``get_recursive_key``)
            of the score the ranking is based on.
        rank_name: Key under which the resulting rank is written.
        lower_is_better: When false (default) the highest score gets rank 1;
            when true the lowest score gets rank 1.

    Returns:
        None.  The rank is written into every prediction dict.
    """
    # Assumption: the search space for the rank scores is the whole prediction
    # list (i.e. exposure_topk = len(predictions)).

    # TODO: handle tied items appropriately.  For now ties are broken by the
    # original list position (earlier item gets the better rank), which a
    # stable sort guarantees; the default argsort algorithm does not.
    scores = np.array([get_recursive_key(rank_key, x) for x in predictions])
    if not lower_is_better:
        scores = (-1) * scores
    order = np.argsort(scores, kind="stable")

    # order[r] is the index of the prediction that ends up at 0-based rank r.
    for rank, original_idx in enumerate(order, start=1):
        predictions[original_idx][rank_name] = rank


def add_gap_scores(predictions, gap_key, gap_name):
    """Store the rank-normalised gap to the next-best prediction (in place).

    Predictions are ordered by ``prediction[gap_key]`` from highest to lowest.
    The prediction at 1-based position ``r`` receives
    ``(its value - value of the prediction at position r + 1) / r`` under
    ``gap_name``; the last prediction in that order receives ``0``.

    Args:
        predictions: Sequence of prediction dicts; each one is mutated.
        gap_key: Key of the value the gaps are computed from.
        gap_name: Key under which the gap is written.

    Returns:
        None.  An empty ``predictions`` is left untouched.
    """
    # Nothing to annotate; also avoids indexing order[-1] on an empty array.
    if len(predictions) == 0:
        return

    values = np.array([p[gap_key] for p in predictions])
    # Stable sort so tied values keep their original relative order.
    order = np.argsort((-1) * values, kind="stable")

    # Gap between i-th and i+1-th predictions divided by its rank
    for rank, (idx_cur, idx_next) in enumerate(zip(order[:-1], order[1:]), start=1):
        predictions[idx_cur][gap_name] = (values[idx_cur] - values[idx_next]) / rank

    # The lowest-valued prediction has no successor to compare against.
    predictions[order[-1]][gap_name] = 0


def exposure_score(prediction, rank_name, exposure_topk):
    """Return ``log2(exposure_topk) - log2(prediction[rank_name])``.

    The result is 0 when the stored rank equals ``exposure_topk`` and grows as
    the stored rank gets smaller.
    """
    log_topk = math.log2(exposure_topk)
    log_rank = math.log2(prediction[rank_name])
    return log_topk - log_rank


def create_ranking_confidence_pairs(metrics_for_ranking, metrics_for_confidence):
    """Enumerate every (ranking combination, confidence combination) pair.

    Both arguments are mappings with a ``"metrics"`` sequence and an integer
    ``"combine_level"``.  For each of them, all combinations of 1 up to
    ``combine_level`` metrics are generated (smaller combinations first).  The
    result is the list of ``(ranking_combination, confidence_combination)``
    tuples, iterating ranking combinations in the outer position.
    """
    def _expand(spec):
        # All metric tuples of size 1..combine_level, in increasing size.
        expanded = []
        for size in range(1, spec["combine_level"] + 1):
            expanded.extend(itertools.combinations(spec["metrics"], size))
        return expanded

    ranking_options = _expand(metrics_for_ranking)
    confidence_options = _expand(metrics_for_confidence)
    return list(itertools.product(ranking_options, confidence_options))


def get_prediction(
    sample,
    ranking_metrics,
    confidence_metrics,
    topk=1,
    prefilter_topk=100,
):
    """Rank a sample's candidate predictions and score the best ones.

    The candidates are read from ``sample["masked_topk"]["topk"]``.  Each
    metric in ``ranking_metrics`` / ``confidence_metrics`` is a dict with a
    ``"name"`` (dot-separated key path into a prediction) and a ``"weight"``.

    Steps:
      1. keep only the first ``prefilter_topk`` candidates (all if ``None``);
      2. store the weighted sum of the ranking metrics as ``"ranking_score"``
         and sort by it, highest first;
      3. keep only the first ``topk`` of the sorted candidates (all if
         ``None``);
      4. store the weighted sum of the confidence metrics as
         ``"confidence_score"`` and the position in the result as ``"i"``.

    The prediction dicts are shared with ``sample`` and are modified in place.
    Returns the resulting list of prediction dicts.
    """
    def _weighted_total(weighted_metrics, candidate):
        # Sum of weight * value over the given metrics, starting from 0.0.
        total = 0.0
        for metric in weighted_metrics:
            total += metric["weight"] * get_recursive_key(metric["name"], candidate)
        return total

    candidates = sample["masked_topk"]["topk"][:]
    if prefilter_topk is not None:
        candidates = candidates[:prefilter_topk]

    # Score every remaining candidate, then order them best-first.
    for candidate in candidates:
        candidate["ranking_score"] = _weighted_total(ranking_metrics, candidate)
    ranked = sorted(candidates, key=lambda cand: (-1) * cand["ranking_score"])

    if topk is not None:
        ranked = ranked[:topk]

    for candidate in ranked:
        candidate["confidence_score"] = _weighted_total(confidence_metrics, candidate)

    # Re-index by the new rank, since the input order may have been different.
    for position, candidate in enumerate(ranked):
        candidate["i"] = position

    return ranked


def risk_coverage_auc(is_correct, confidences):
    """Compute the area under the risk-coverage curve.

    Instances are visited from the highest to the lowest confidence.  After
    each instance, coverage is the fraction of instances visited so far and
    risk is the fraction of those visited whose ``is_correct`` entry equals 0.

    Returns:
        ``(auc, risks, coverages)`` where ``auc`` is ``metrics.auc`` applied to
        the coverage (x) and risk (y) lists.
    """
    assert len(is_correct) == len(confidences)
    total = len(confidences)

    # Most confident first; sorted() is stable, so ties keep their input order.
    by_confidence = sorted(zip(is_correct, confidences), key=lambda pair: -pair[1])

    coverages = []
    risks = []
    errors_so_far = 0
    for covered, (correct, _confidence) in enumerate(by_confidence, start=1):
        if correct == 0:
            errors_so_far += 1
        coverages.append(covered / total)
        risks.append(errors_so_far / covered)

    return metrics.auc(coverages, risks), risks, coverages


def get_risk_coverage_auc(
    samples,
    ranking_metrics,
    confidence_metrics,
    topk=1,
    prefilter_topk=100,
    disable_tqdm=False,
):
    """Compute the risk-coverage AUC over the top prediction of every sample.

    For each sample the predictions returned by ``get_prediction`` are marked
    correct when their ``"token_idx"`` is contained in
    ``sample["label_index"]``, and their ``"confidence_score"`` is collected.
    Only ``topk == 1`` is supported (enforced by an assertion).

    Returns:
        ``(auc, risks, coverages, is_correct_all, confidence_scores_all)``:
        the three values from ``risk_coverage_auc`` followed by the collected
        correctness flags and confidence scores.
    """
    assert topk == 1

    correctness_flags = []
    confidence_values = []
    for sample in tqdm(samples, disable=disable_tqdm):
        top_predictions = get_prediction(
            sample,
            ranking_metrics,
            confidence_metrics,
            topk=topk,
            prefilter_topk=prefilter_topk,
        )
        for pred in top_predictions:
            correctness_flags.append(int(pred["token_idx"] in sample["label_index"]))
        for pred in top_predictions:
            confidence_values.append(pred["confidence_score"])

    auc, risks, coverages = risk_coverage_auc(correctness_flags, confidence_values)
    return auc, risks, coverages, correctness_flags, confidence_values


def set_default_option_if_not_exist(args, key, value):
    """Set attribute ``key`` of ``args`` to ``value`` unless it already exists.

    An existing attribute is left untouched, even if its value is ``None``.
    """
    if hasattr(args, key):
        return
    setattr(args, key, value)


def get_config_from_file(config_file: str, args=None):
    """Load a JSON configuration file into a namespace.

    Args:
        config_file: Path of a JSON file whose top level is an object.
        args: Optional namespace (e.g. parsed command-line arguments) to merge
            into.  A fresh ``Namespace`` is created when omitted.

    Returns:
        ``args`` with one attribute per top-level key of the file.  An
        attribute that already exists on ``args`` with a value other than
        ``None`` takes precedence: the file value is skipped and a warning is
        logged.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(config_file, encoding="utf-8") as f:
        config_json = json.load(f)

    if args is None:
        args = Namespace()

    for k, v in config_json.items():
        # Missing attributes and attributes set to None are both overridable.
        if getattr(args, k, None) is not None:
            logger.warning(
                f"Argument {k} is specified by the command. "
                "The value in the configuration file will be ignored.")
            continue
        setattr(args, k, v)

    return args


def _add_derived_scores(sample, config):
    """Attach the configured rank, gap and exposure scores to ``sample`` in place."""
    # The slice is a new list but shares the prediction dicts with the sample,
    # so values written through it show up in sample["masked_topk"]["topk"].
    candidates = sample["masked_topk"]["topk"][:config.prefilter_topk]

    for metric in config.rank_metrics:
        add_rank_scores(candidates, metric["key"], metric["name"])

    for metric in config.gap_metrics:
        add_gap_scores(candidates, metric["key"], metric["name"])

    for metric in config.exposure_metrics:
        for prediction in candidates:
            prediction[metric["name"]] = exposure_score(
                prediction,
                metric["key"],
                metric["exposure_topk"],
            )


def _load_samples(result_file, config):
    """Read all samples of one result database and add the derived scores."""
    db_flag = "c" if config.commit_metrics_to_db else "r"
    samples = []

    # The context manager closes the database even if processing fails.
    with SqliteDict(str(result_file), flag=db_flag) as db:
        for db_id, sample in db.items():
            _add_derived_scores(sample, config)
            samples.append(sample)

            if config.commit_metrics_to_db:
                db[db_id] = sample
        if config.commit_metrics_to_db:
            db.commit()

    # Truncate only the in-memory copies, after any write-back above, so the
    # database keeps the complete prediction lists.
    for sample in samples:
        sample["masked_topk"]["topk"] = sample["masked_topk"]["topk"][:config.prefilter_topk]
    return samples


def evaluate_results(root_dir, config_file, glob_string="**/result.sqlite"):
    """Evaluate every ranking/confidence metric combination on stored results.

    Args:
        root_dir: Directory that is searched for result databases.
        config_file: Path of the JSON configuration file.  It must provide
            ``metrics_for_ranking`` and ``metrics_for_confidence`` (each with
            ``metrics``, ``combine_level`` and ``search_weight``).  Optional
            keys and their defaults: ``topk`` (1), ``prefilter_topk`` (100),
            ``rank_metrics`` / ``gap_metrics`` / ``exposure_metrics`` (empty)
            and ``commit_metrics_to_db`` (False).
        glob_string: Glob pattern, relative to ``root_dir``, selecting the
            result databases.

    Returns:
        A list with one dict per evaluated combination, holding the keys
        ``"ranking_metrics"``, ``"confidence_metrics"``, ``"RC-AUC"``,
        ``"risks"`` and ``"coverages"``.

    Raises:
        RuntimeError: If ``search_weight`` is enabled for the ranking or the
            confidence metrics; weight search is not implemented.
    """
    config = get_config_from_file(config_file)

    set_default_option_if_not_exist(config, "topk", 1)
    set_default_option_if_not_exist(config, "prefilter_topk", 100)
    set_default_option_if_not_exist(config, "rank_metrics", [])
    set_default_option_if_not_exist(config, "gap_metrics", [])
    set_default_option_if_not_exist(config, "exposure_metrics", [])
    set_default_option_if_not_exist(config, "commit_metrics_to_db", False)

    all_samples = []
    logger.info("Loading data...")
    for p_result_file in tqdm(Path(root_dir).glob(glob_string)):
        all_samples.extend(_load_samples(p_result_file, config))

    ranking_confidence_pairs = create_ranking_confidence_pairs(
        config.metrics_for_ranking,
        config.metrics_for_confidence
    )

    logger.debug("ranking_confidence_pairs: %s", ranking_confidence_pairs)

    evaluation_results = []
    for ranking_metrics, confidence_metrics in tqdm(ranking_confidence_pairs):
        # Weight search is unsupported for both kinds of metrics.  Checking
        # the confidence side explicitly avoids using an unbound (or stale)
        # weighted_confidence_metrics.
        if config.metrics_for_ranking["search_weight"]:
            raise RuntimeError("Search_weight is not implemented yet.")
        if config.metrics_for_confidence["search_weight"]:
            raise RuntimeError("Search_weight is not implemented yet.")

        # Without weight search there is exactly one weighting: all ones.
        weighted_ranking_metrics = [
            [{"name": r, "weight": 1} for r in ranking_metrics]
        ]
        weighted_confidence_metrics = [
            [{"name": c, "weight": 1} for c in confidence_metrics]
        ]

        for rm in weighted_ranking_metrics:
            for cm in weighted_confidence_metrics:
                logger.debug("rm: %s", rm)
                logger.debug("cm: %s", cm)
                # Forward the configured prefilter so it matches the
                # truncation applied while loading (the callee's default of
                # 100 would otherwise override larger configured values).
                # NOTE(review): config.topk is not forwarded because
                # get_risk_coverage_auc only supports topk == 1.
                rc_auc, risks, coverages, _, _ = get_risk_coverage_auc(
                    all_samples,
                    rm,
                    cm,
                    prefilter_topk=config.prefilter_topk,
                )

                evaluation_results.append({
                    "ranking_metrics": rm,
                    "confidence_metrics": cm,
                    "RC-AUC": rc_auc,
                    "risks": risks,
                    "coverages": coverages,
                })

    return evaluation_results


if __name__ == '__main__':
    # Command-line entry point: evaluate all result files below root_dir and
    # print one JSON object per evaluated metric combination.
    logging.basicConfig(level=logging.DEBUG)

    parser = ArgumentParser()
    parser.add_argument("root_dir", help="Directory containing results")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to config file"
    )
    parser.add_argument(
        "--glob",
        default="**/result.sqlite",
        help="glob expression for finding result files"
    )
    args = parser.parse_args()

    results = evaluate_results(args.root_dir, args.config, args.glob)

    for result in results:
        print(json.dumps(result))
