# Copyright (c) <anonymized for review>

import numpy as np
import pandas as pd
from scipy.optimize import minimize
from sqlitedict import SqliteDict
import logging
from argparse import ArgumentParser, Namespace
import json
from pathlib import Path
from tqdm import tqdm

from scripts.eval_auc import get_risk_coverage_auc, add_rank_scores, add_gap_scores, exposure_score

logger = logging.getLogger(__name__)


def set_default_option_if_not_exist(args, key, value):
    """Give ``args`` the attribute ``key`` = ``value`` only if it is missing.

    An attribute that already exists is never overwritten, even when its
    current value is ``None``.
    """
    if hasattr(args, key):
        return
    setattr(args, key, value)


def get_config_from_file(config_file: str, args=None):
    """Merge the keys of a JSON config file into an argument namespace.

    Attributes that are already set on ``args`` to a non-``None`` value (for
    example explicit command-line options) take precedence: the matching
    config value is skipped and a warning is logged.

    Args:
        config_file: Path to a UTF-8 JSON file whose top-level value is an object.
        args: Namespace to update in place. A fresh ``Namespace`` is created
            when ``None``.

    Returns:
        The updated namespace (``args`` itself when one was given).

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    # JSON text is UTF-8 by specification (RFC 8259); don't depend on the locale.
    with open(config_file, encoding="utf-8") as f:
        config_json = json.load(f)

    if not isinstance(config_json, dict):
        raise ValueError(
            f"Config file {config_file} must contain a JSON object, "
            f"got {type(config_json).__name__}")

    if args is None:
        args = Namespace()

    for k, v in config_json.items():
        # A missing attribute and an attribute holding None are both
        # overwritten by the config value.
        if getattr(args, k, None) is not None:
            logger.warning(
                "Argument %s is specified by the command. "
                "The value in the configuration file will be ignored.", k)
            continue
        setattr(args, k, v)

    return args


def _add_scores(sample, config):
    """Truncate one sample's candidate list and attach the configured derived scores.

    Mutates ``sample`` in place. ``sample["masked_topk"]["topk"]`` is cut to its
    first ``config.prefilter_topk`` entries. Each entry of ``config.rank_metrics``
    and ``config.gap_metrics`` is then passed to the corresponding
    ``scripts.eval_auc`` helper; presumably the helper annotates each candidate
    under ``metric["name"]`` (TODO confirm). Each exposure metric is written to
    ``prediction[metric["name"]]`` for every remaining candidate.
    """
    topk = sample["masked_topk"]["topk"][:config.prefilter_topk]
    sample["masked_topk"]["topk"] = topk

    for metric in config.rank_metrics:
        # Pass a shallow copy so that any reordering the helper might do to
        # the list it receives cannot change the stored candidate order.
        add_rank_scores(list(topk), metric["key"], metric["name"])

    for metric in config.gap_metrics:
        add_gap_scores(list(topk), metric["key"], metric["name"])

    for metric in config.exposure_metrics:
        for prediction in topk:
            prediction[metric["name"]] = exposure_score(
                prediction,
                metric["key"],
                metric["exposure_topk"],
            )


def _load_samples(root_dir, pattern, config):
    """Read and score every sample from the SqliteDict files matching ``pattern``.

    Files under ``root_dir`` are visited in sorted path order, so the result does
    not depend on filesystem listing order. Each sample goes through
    ``_add_scores`` before it is collected.
    """
    samples = []
    result_files = sorted(Path(root_dir).glob(pattern))
    for result_file in tqdm(result_files):
        # The context manager closes the database even if scoring raises.
        with SqliteDict(str(result_file)) as db:
            for _, sample in db.items():
                _add_scores(sample, config)
                samples.append(sample)
    return samples


def _search_optimal_weights(samples, metrics_for_ranking, metrics_for_confidence,
                            n_iter, seed=None):
    """Minimise RC-AUC over convex combinations of the confidence metrics.

    SLSQP is run ``n_iter`` times. Each run starts from a random point drawn with
    ``numpy.random.default_rng(seed)``. Every weight is bounded to ``[0, 1]`` and
    the weights must sum to 1.

    Returns:
        A list of ``(rc_auc, weights)`` pairs, one per restart, sorted by
        ascending RC-AUC. ``weights`` is aligned with ``metrics_for_confidence``.
    """
    def fn_rcauc(weights):
        confidence_metrics = [
            {"name": name, "weight": w}
            for name, w in zip(metrics_for_confidence, weights)
        ]
        # Only the RC-AUC (first element of the returned tuple) is optimised.
        rc_auc, *_ = get_risk_coverage_auc(
            samples,
            metrics_for_ranking,
            confidence_metrics,
            disable_tqdm=True,
        )
        return rc_auc

    n_metrics = len(metrics_for_confidence)
    constraints = [{"type": "eq", "fun": lambda w: 1 - np.sum(w)}]
    bounds = [(0, 1)] * n_metrics
    rng = np.random.default_rng(seed)

    results = []
    for _ in tqdm(range(n_iter)):
        res = minimize(
            fn_rcauc,
            rng.uniform(size=n_metrics),
            constraints=constraints,
            bounds=bounds,
            method="SLSQP",
        )
        if not res.success:
            logger.warning("SLSQP did not converge: %s", res.message)
        results.append((float(res.fun), res.x))

    results.sort(key=lambda pair: pair[0])
    return results


def main(args):
    """Search for confidence-metric weights that minimise RC-AUC, and print them.

    ``args`` is the parsed CLI namespace (``root_dir``, ``glob``, ``config``,
    ``iter``). Keys from the JSON file at ``args.config`` are merged into the same
    namespace; they do not override CLI values that are not ``None``. The merged
    namespace must provide ``metrics_for_ranking`` and ``metrics_for_confidence``.
    ``prefilter_topk``, ``rank_metrics``, ``gap_metrics``, ``exposure_metrics``
    and ``seed`` are optional.

    Prints one JSON line per optimisation restart, lowest RC-AUC first. If no
    samples are found, logs an error and returns without optimising.
    """
    config = get_config_from_file(args.config, args)

    set_default_option_if_not_exist(config, "topk", 1)  # not read by this script
    set_default_option_if_not_exist(config, "prefilter_topk", 100)
    set_default_option_if_not_exist(config, "rank_metrics", [])
    set_default_option_if_not_exist(config, "gap_metrics", [])
    set_default_option_if_not_exist(config, "exposure_metrics", [])
    set_default_option_if_not_exist(config, "seed", None)

    logger.info("Loading data...")
    all_samples = _load_samples(args.root_dir, args.glob, config)
    if not all_samples:
        logger.error(
            "No samples found in %s matching %r; nothing to optimise.",
            args.root_dir, args.glob)
        return

    logger.info("Searching for optimal weights...")
    ranked = _search_optimal_weights(
        all_samples,
        config.metrics_for_ranking,
        config.metrics_for_confidence,
        args.iter,
        seed=config.seed,
    )

    for rc_auc, weights in ranked:
        print(json.dumps({
            "RC-AUC": rc_auc,
            "weights": {
                name: float(weight)
                for name, weight in zip(config.metrics_for_confidence, weights)
            },
        }))


if __name__ == '__main__':
    # Command-line entry point: parse arguments and run the weight search.
    logging.basicConfig(level=logging.INFO)

    parser = ArgumentParser(
        description="Search for confidence-metric weights that minimise RC-AUC.")
    parser.add_argument("root_dir", help="Directory containing results")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to config file"
    )
    parser.add_argument(
        "--glob",
        default="**/result.sqlite",
        help="glob expression for finding result files"
    )
    parser.add_argument(
        "--iter", type=int,
        default=100,
        help="Number of optimization restarts, each from a random starting point"
    )
    args = parser.parse_args()

    main(args)
