import json
import os
import csv
import numpy as np
from response_selection.utils import recall_x_at_k, mrr


def _parse_result_name(fname):
    """Derive the train/test labels encoded in a result file path.

    The path is split on "-": the first piece (with the literal prefix
    "./result/dd/" stripped) is the train label and the second piece is the
    test label.

    Args:
        fname: Path of a result file whose name has at least two
            "-"-separated pieces.

    Returns:
        Tuple ``(train_d_type, test_d_type)``.

    Raises:
        IndexError: If ``fname`` contains no "-".
    """
    parts = fname.split("-")
    train_d_type = parts[0].replace("./result/dd/", "")
    test_d_type = parts[1]

    last_part = parts[-1]
    suffix = "-" + last_part.replace(".json", "")

    # Variant markers found in the last piece are folded into the train label.
    if any(tag in last_part for tag in ("sep", "ord", "compact", "ntex")):
        train_d_type += suffix
    # NOTE: a four-piece name also appends the last piece; when both
    # conditions hold the suffix is appended twice (kept as-is on purpose).
    if len(parts) == 4:
        train_d_type += suffix

    return train_d_type, test_d_type


def main_script(dirname, num_candidates):
    """Aggregate recall and MRR for every experiment directory in ``dirname``.

    Each sub-directory of ``dirname`` whose name does not contain
    "performance" is scanned for files with ".json" in their name. Files whose
    path contains ``str(num_candidates)`` are evaluated with ``main`` (recall)
    and ``get_mrr`` (MRR).

    Side effects:
        * Prints every inspected file path, the collected results and the
          model names written to the CSV.
        * Appends a table to
          ``result/performance/all_result_<num_candidates>.csv`` (one header
          row, then the recall rows, then the MRR rows).
        * Overwrites ``result/performance/dump_result_<num_candidates>.json``
          with the per-file metrics.

    Args:
        dirname: Directory that contains one sub-directory per experiment.
        num_candidates: Number of candidates per example; also used to select
            which result files are evaluated.

    Raises:
        AssertionError: If a selected file path lacks "candi", or lacks both
            "test" and "dev".
    """
    # Only real directories are experiments; stray files in ``dirname`` would
    # otherwise make ``os.listdir`` below fail with NotADirectoryError.
    exp_dirs = []
    for entry in os.listdir(dirname):
        path = os.path.join(dirname, entry)
        if "performance" in entry or not os.path.isdir(path):
            continue
        exp_dirs.append(path)
    exp_dirs.sort()

    result = {}

    # Per-model rows for the CSV table, one column per entry of ``test_list``.
    models_w_recall = {}
    models_w_mrr = {}
    test_list = [
        "random",
        "human",
        "direct_wo_ans",
        "meta",
    ]

    for exp_dir in exp_dirs:
        result[exp_dir] = {}

        flist = sorted(
            os.path.join(exp_dir, name)
            for name in os.listdir(exp_dir)
            if ".json" in name
        )

        for fname in flist:
            print(fname)
            if str(num_candidates) not in fname:
                continue

            assert "candi" in fname
            assert "test" in fname or "dev" in fname

            train_d_type, test_d_type = _parse_result_name(fname)

            recall = main(fname, num_candidates)
            # Named ``mrr_score`` so the imported ``mrr`` function is not
            # shadowed inside this function.
            mrr_score = get_mrr(fname, num_candidates)

            if train_d_type not in models_w_recall:
                models_w_recall[train_d_type] = [0] * len(test_list)
                models_w_mrr[train_d_type] = [0] * len(test_list)

            # Substring match: one file can fill several columns.
            for i, test in enumerate(test_list):
                if test in test_d_type:
                    models_w_recall[train_d_type][i] = recall
                    models_w_mrr[train_d_type][i] = mrr_score

            result[exp_dir][fname] = {
                "recall": recall,
                "mrr": mrr_score,
            }

    print("result: ", result)

    # Append mode: repeated runs add a new header and new rows to the file.
    # The ``with`` block closes the file, so no explicit close() is needed.
    with open(
        "result/performance/all_result_{}.csv".format(num_candidates),
        "a",
        newline="",
    ) as f_csv:
        wr = csv.writer(f_csv)
        wr.writerow(["models"] + test_list)
        for key in models_w_recall:
            print(key)
            wr.writerow([key] + models_w_recall[key])
        for key in models_w_mrr:
            print(key)
            wr.writerow([key] + models_w_mrr[key])

    with open(
        "result/performance/dump_result_{}.json".format(num_candidates), "w"
    ) as f:
        json.dump(result, f, indent=2)


def softmax_np(logits):
    """Return the softmax of ``logits`` as a list of Python floats.

    The maximum logit is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the result from becoming ``nan``) for large logits.

    Args:
        logits: One-dimensional sequence or array of unnormalized scores.

    Returns:
        List of floats summing to 1, or an empty list for empty input.
    """
    logits_arr = np.asarray(logits, dtype=float)
    if logits_arr.size == 0:
        # np.max would raise on empty input; keep returning an empty list.
        return []

    exp_logits = np.exp(logits_arr - np.max(logits_arr))
    probs = exp_logits / np.sum(exp_logits)

    return [float(el) for el in probs]


def main(fname, num_candidates, k=1):
    """Compute the averaged recall for one prediction file.

    The file is read line by line; every non-blank line is parsed as a JSON
    document and the resulting records are passed to ``run_origianl_recall``.

    Args:
        fname: Path of the prediction file; must contain ".json".
        num_candidates: Number of candidates considered per record.
        k: Cut-off forwarded to ``run_origianl_recall`` (defaults to 1).

    Returns:
        Whatever ``run_origianl_recall`` returns for the parsed records.
    """
    assert ".json" in fname

    records = []
    with open(fname, "r") as handle:
        for line in handle.readlines():
            if line.strip() == "":
                continue
            records.append(json.loads(line))

    return run_origianl_recall(records, num_candidates, k)


def get_mrr(fname, x):
    """Average the ``mrr`` score over all records of a prediction file.

    Every non-blank line of the file is parsed as JSON. For each record the
    first ``x`` entries of its "pred" list are scored with ``mrr(scores, 0)``.

    Args:
        fname: Path of the prediction file (one JSON document per line).
        x: Number of leading "pred" entries to keep per record.

    Returns:
        Arithmetic mean of the per-record scores.

    Raises:
        ZeroDivisionError: If the file holds no records.
    """
    records = []
    with open(fname, "r") as handle:
        for line in handle.readlines():
            if line.strip() == "":
                continue
            records.append(json.loads(line))

    # NOTE(review): the literal 0 is presumably the index of the correct
    # candidate -- confirm against response_selection.utils.mrr.
    per_record = [mrr(record["pred"][:x], 0) for record in records]
    return sum(per_record) / len(per_record)


def run_origianl_recall(prediction_list, x: int, k=1):
    """Average ``recall_x_at_k`` over a list of prediction records.

    Only the "pred" entry of each record is used: its first ``x`` scores are
    passed to ``recall_x_at_k(scores, x, k, 0)``.

    Args:
        prediction_list (List[Dict[str,Union[List[float], bool]]]): {
                    "pred": [list of unnormalized score],
                    "uncertainty": [list of unnormalized uncertainty],
                    "is_uw": bool,
                }
            Keys other than "pred" are not read here.
        x: Number of leading candidates kept per record.
        k: Cut-off forwarded to ``recall_x_at_k`` (defaults to 1).

    Returns:
        Arithmetic mean of the per-record recall values.

    Raises:
        ZeroDivisionError: If ``prediction_list`` is empty.
    """
    # NOTE(review): the trailing 0 is presumably the index of the correct
    # candidate -- confirm against response_selection.utils.recall_x_at_k.
    recalls = [
        recall_x_at_k(item["pred"][:x], x, k, 0) for item in prediction_list
    ]
    return sum(recalls) / len(recalls)


if __name__ == "__main__":
    # Script entry point: evaluate everything under ./result/ with 6
    # candidates per example.
    dirname = "./result/"
    num_candidates = 6

    # main_script writes its CSV/JSON reports into this directory, so make
    # sure it exists before the run starts.
    os.makedirs("./result/performance", exist_ok=True)

    main_script(dirname, num_candidates)
