import re

import jsonlines
from lib.eval_diagnose import eval_batch

from lib.icd10 import Idc10


def extract_icd10_codes(text):
    """Return every ICD-10-style code found in *text*, in order of appearance.

    A match is one uppercase letter followed by two digits, optionally
    followed by a dot and one to three digits, optionally followed by a
    hyphen, an uppercase letter and one or two digits (for example
    ``A01``, ``A01.001`` or ``A00-B99``). The pattern is not anchored to
    word boundaries, so a code embedded in a longer token is still matched.

    Args:
        text: The string to scan.

    Returns:
        A list of the matched substrings; empty when nothing matches.
    """
    code_re = re.compile(r'[A-Z]\d{2}(?:\.\d{1,3})?(?:-[A-Z]\d{1,2})?')
    # The pattern has no capturing groups, so group(0) is the whole match.
    return [match.group(0) for match in code_re.finditer(text)]


def eval(golden_file, pred_file):
    """Score predicted ICD-10 codes against golden labels and print results.

    Reads the ``'icd-10'`` field of every record in ``golden_file`` and, for
    every record in ``pred_file``, extracts codes from its ``'result'`` text
    with ``extract_icd10_codes``. Both lists are then passed to
    ``eval_batch`` for levels 0, 1 and 2, and each return value is printed.

    Args:
        golden_file: Path to the JSON Lines file holding the golden labels.
        pred_file: Path to the JSON Lines file holding the predictions.

    Returns:
        None. The three ``eval_batch`` results are written to stdout.
    """
    with jsonlines.open(golden_file) as golden_reader:
        goldens = [record['icd-10'] for record in golden_reader]

    with jsonlines.open(pred_file) as pred_reader:
        preds = [
            extract_icd10_codes(record['result']) for record in pred_reader
        ]

    # Lookup resources are loaded from paths relative to the working
    # directory.
    icd10 = Idc10(
        "resources/ICD-10.json",
        "resources/ICD-10-reverse.json",
        "resources/ICD-10.xlsx",
    )

    # NOTE(review): the meaning of each level is defined by eval_batch and is
    # not visible from this file.
    for level in (0, 1, 2):
        print(eval_batch(goldens, preds, icd10, level=level))


if __name__ == '__main__':
    # The pred_file value looks like a placeholder; replace it with the
    # path to the prediction file before running.
    eval(golden_file='testset/diagnose.jsonl', pred_file='your output pathl')