import csv
import os
import numpy as np

# File name of the recall CSV looked up inside each dataset's evaluation directory.
RECALL_FILE = "recall_at_k.csv"
# Sub-dataset name prefixes whose recalls get_recalls_for_model averages when the requested
# dataset name contains "entityquestions"; "-dev" or "-test" is appended to each before lookup.
# NOTE(review): these look like Wikidata property IDs -- confirm against the data source.
ENTITY_QUESTIONS_DATASETS = ["P36", "P407", "P26", "P159", "P276", "P40", "P176", "P20", "P112", "P127", "P19", "P740",
                             "P413", "P800", "P69", "P50", "P170", "P106", "P131", "P17", "P175", "P136", "P264", "P495"]


def get_recall_file_path(base_dir, model_name, setting, dataset, is_hybrid):
    """Build the path of the recall CSV file for one model/setting/dataset.

    The layout is ``base_dir/model_name[/setting]/<eval dir>/dataset/RECALL_FILE``.
    The ``setting`` component is left out when ``setting`` is ``"zero-shot"``.

    Args:
        base_dir: Root directory that contains the per-model directories.
        model_name: Name of the model directory.
        setting: Name of the setting; ``"zero-shot"`` adds no directory level.
        dataset: Name of the dataset directory.
        is_hybrid: If truthy, use the ``eval-hybrid`` directory instead of ``eval``.

    Returns:
        The joined path as a string.
    """
    components = [base_dir, model_name]
    # Only non-zero-shot settings get their own directory level.
    if setting != "zero-shot":
        components.append(setting)
    components.append("eval-hybrid" if is_hybrid else "eval")
    components.extend((dataset, RECALL_FILE))
    return os.path.join(*components)


def extract_recalls(file_path):
    """Read a recall CSV file and return the recalls as percentages.

    The file is expected to hold one ``k,recall`` row per rank, with ranks
    starting at 1 and increasing by one on every row. Completely blank rows
    (for example a trailing empty line) are ignored.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A list of floats in which element ``k - 1`` is the recall from the
        row with rank ``k``, multiplied by 100.

    Raises:
        ValueError: If a row does not have exactly two columns, a value
            cannot be parsed as a number, or the ranks are not consecutive
            starting at 1.
    """
    print("Opening file", file_path)
    # newline="" leaves newline handling to the csv module, as its docs recommend.
    with open(file_path, "r", newline="") as f:
        # csv.reader yields an empty list for a blank line; drop those rows.
        rows = [row for row in csv.reader(f) if row]

    recalls = []
    for expected_k, row in enumerate(rows, start=1):
        if len(row) != 2:
            raise ValueError(
                f"{file_path}: expected 2 columns in row {expected_k}, got {len(row)}"
            )
        k_str, recall = row
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if int(k_str) != expected_k:
            raise ValueError(
                f"{file_path}: expected rank {expected_k}, found {k_str!r}"
            )
        recalls.append(float(recall) * 100)
    return recalls


def get_recalls_for_model(base_dir, model_name, setting, dataset, is_hybrid, ks):
    """Return the recall values of a model at the requested ranks.

    When ``dataset`` contains ``"entityquestions"``, the recall file of every
    entry in ``ENTITY_QUESTIONS_DATASETS`` (suffixed with ``-dev`` if
    ``dataset`` ends with ``"dev"``, otherwise ``-test``) is read and the
    result is the unweighted mean over those files at each rank. Otherwise
    the single recall file for ``dataset`` is read.

    Args:
        base_dir: Root directory that contains the per-model directories.
        model_name: Name of the model directory.
        setting: Setting name passed on to ``get_recall_file_path``.
        dataset: Dataset name.
        is_hybrid: Whether to read from the hybrid evaluation directory.
        ks: Iterable of 1-based ranks to select.

    Returns:
        A list of Python floats, one per entry of ``ks`` and in the same order.

    Raises:
        IndexError: If a rank in ``ks`` is smaller than 1 or larger than the
            number of ranks available in the recall file(s).
    """
    if "entityquestions" in dataset:
        split = "dev" if dataset.endswith("dev") else "test"
        # Use a dedicated loop name so the `dataset` parameter is not overwritten.
        per_subset = [
            extract_recalls(
                get_recall_file_path(base_dir, model_name, setting, f"{subset}-{split}", is_hybrid)
            )
            for subset in ENTITY_QUESTIONS_DATASETS
        ]
        # NOTE(review): assumes every sub-dataset file has the same number of rows.
        # tolist() yields plain Python floats, matching the single-dataset branch.
        recall_k = np.array(per_subset).mean(axis=0).tolist()
    else:
        recall_file_path = get_recall_file_path(base_dir, model_name, setting, dataset, is_hybrid)
        recall_k = extract_recalls(recall_file_path)

    selected = []
    for k in ks:
        # Without this guard k <= 0 would wrap around through negative indexing
        # and silently return the recall of a different rank.
        if not 1 <= k <= len(recall_k):
            raise IndexError(f"k={k} is outside the available range 1..{len(recall_k)}")
        selected.append(recall_k[k - 1])
    return selected
