import numpy as np
from lib.icd10 import Idc10


def eval_single(golden, pred, parser: "Idc10", level=1):
    """Count true positives, false positives and false negatives for one sample.

    Both ``golden`` and ``pred`` are run through ``parser.get_diseases`` and
    then ``parser.get_des_by_level(..., level)``; the resulting entries are
    de-duplicated into sets and compared by equality. Entries that map to
    ``None`` never count as a match: a ``None`` prediction is a false
    positive and a ``None`` gold entry is a false negative.

    Args:
        golden: Gold labels for the sample, in whatever form
            ``parser.get_diseases`` accepts.
        pred: Ranked predictions for the sample; must support slicing.
        parser: Object providing ``get_diseases`` and ``get_des_by_level``
            (presumably an ``Idc10`` instance -- confirm against callers).
        level: Hierarchy level passed to ``parser.get_des_by_level``.

    Returns:
        ``np.array([tp, fp, fn])`` of integers.
    """
    golden = set(parser.get_des_by_level(parser.get_diseases(golden), level))

    # Only the first len(golden) raw predictions are scored.
    # NOTE(review): len(golden) is the number of *distinct* level-mapped gold
    # entries, while the slice is taken on the raw predictions -- confirm
    # this cut-off is intended.
    pred = set(parser.get_des_by_level(
        parser.get_diseases(pred[:len(golden)]), level))

    # Both sides are sets compared by equality, so every prediction matches at
    # most one gold entry and vice versa; the number of matched predictions
    # equals the number of matched gold entries. None is excluded so that a
    # missing mapping on both sides is not rewarded as a hit.
    tp = len((pred & golden) - {None})
    fp = len(pred) - tp
    fn = len(golden) - tp

    return np.array([tp, fp, fn])


def micro_average(counts):
    """Format micro-averaged precision, recall and F1 from summed counts.

    Args:
        counts: A 3-item sequence ``(tp, fp, fn)`` of accumulated true
            positives, false positives and false negatives.

    Returns:
        A human-readable string with precision, recall and F1 rounded to
        three decimals. A tiny epsilon in the denominators avoids division by
        zero, and F1 falls back to 0 when precision and recall are both 0.
    """
    tp, fp, fn = counts
    eps = 1e-10

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)

    denom = precision + recall
    if denom != 0:
        f1 = 2 * precision * recall / denom
    else:
        f1 = 0

    template = "micro average: p = {:.3f}, r = {:.3f}, f = {:.3f}"
    return template.format(precision, recall, f1)


def eval_batch(goldens, preds, parser, level=2):
    """Score every (golden, pred) pair and report the micro average.

    Pairs are formed with ``zip``, so if one iterable is longer its extra
    items are ignored. Per-sample ``[tp, fp, fn]`` counts from
    ``eval_single`` are summed into a float array and passed to
    ``micro_average``.

    NOTE(review): the default ``level`` here is 2, whereas ``eval_single``
    defaults to 1 -- confirm the difference is intentional.

    Returns:
        The formatted string produced by ``micro_average``.
    """
    per_sample = (
        eval_single(sample_golden, sample_pred, parser, level)
        for sample_golden, sample_pred in zip(goldens, preds)
    )
    totals = sum(per_sample, np.zeros(3))
    return micro_average(totals)