import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import re

def evaluate(preds, labels, mode):
    """Compute precision, recall and F1 for one label or one label group.

    Predictions and gold labels are integer ids using the mapping
    ``CT+ -> 0, CT- -> 1, PS+ -> 2, PS- -> 3, Uu -> 4``.

    Args:
        preds: Iterable of predicted label ids.
        labels: Iterable of gold label ids, aligned with ``preds``. Pairs are
            formed with ``zip``, so extra items in the longer input are
            ignored.
        mode: Either a single label name (``"CT+"``, ``"CT-"``, ``"PS+"``,
            ``"PS-"``, ``"Uu"``) or a group name: ``"CT"`` (ids 0, 1),
            ``"PS"`` (ids 2, 3), ``"p"`` (ids 0, 2) or ``"n"`` (ids 1, 3).
            For a group, any id inside the group counts as positive.

    Returns:
        A ``(precision, recall, f1)`` tuple. Each value is ``0`` when its
        denominator would be zero.

    Raises:
        ValueError: If ``mode`` is neither a known label nor a known group.
    """
    gold_label = {"CT+": 0, "CT-": 1, "PS+": 2, "PS-": 3, "Uu": 4}
    gold_label_pair = {"CT": [0, 1], "PS": [2, 3], "p": [0, 2], "n": [1, 3]}

    # Normalise both kinds of mode to "the list of ids treated as positive".
    if mode in gold_label:
        positive_ids = [gold_label[mode]]
    elif mode in gold_label_pair:
        positive_ids = gold_label_pair[mode]
    else:
        raise ValueError("Invalid evaluation mode")

    # Count everything in a single pass so that one-shot iterators
    # (generators, ``iter(...)``) are not exhausted before fp/fn are counted.
    tp = fp = fn = 0
    for pred, gold in zip(preds, labels):
        pred_positive = pred in positive_ids
        gold_positive = gold in positive_ids
        if pred_positive and gold_positive:
            tp += 1
        elif pred_positive:
            fp += 1
        elif gold_positive:
            fn += 1

    precision = tp / (tp + fp) if tp + fp != 0 else 0
    recall = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
    return precision, recall, f1


def postprocess(input_path):
    """Score model outputs stored in a JSON Lines file and print the metrics.

    Every non-blank line of the file is parsed as a JSON object that must
    provide an ``"output"`` entry (the text searched for a label) and a
    ``"label"`` entry (one of ``CT+``, ``CT-``, ``PS+``, ``PS-``, ``Uu``).
    The prediction for a record is the first label token found in
    ``"output"``; if no token is found the prediction falls back to ``"Uu"``.

    Args:
        input_path: Path to the UTF-8 encoded JSONL file.

    Returns:
        A dict with per-class ``<prefix>_precision`` / ``<prefix>_recall`` /
        ``<prefix>_f1`` entries (prefixes ``CTp``, ``CTn``, ``PSp``, ``PSn``,
        ``Uu``) followed by ``"acc"`` and ``"macro_f1"``. Each entry is also
        printed as ``key: value``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"output"`` or ``"label"``, or its
            ``"label"`` is not one of the five known labels.
    """
    label2id = {"CT+": 0, "CT-": 1, "PS+": 2, "PS-": 3, "Uu": 4}
    # Compiled once; only the first label mentioned in an output is used.
    label_re = re.compile(r'(CT\+|CT-|PS\+|PS-|Uu)')

    with open(input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) rather than letting
        # json.loads fail on them.
        data = [json.loads(line) for line in f if line.strip()]

    pred_ids = []
    for d in data:
        found = label_re.search(d["output"])
        # Outputs without any recognisable label default to "Uu".
        pred_ids.append(label2id[found.group(1) if found else "Uu"])
    preds = np.array(pred_ids)
    labels = np.array([label2id[d["label"]] for d in data])

    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="macro")

    # Per-class metrics, inserted in a fixed order so the printed report and
    # the returned dict keep a stable layout.
    res = {}
    per_class = (
        ("CTp", "CT+"),
        ("CTn", "CT-"),
        ("PSp", "PS+"),
        ("PSn", "PS-"),
        ("Uu", "Uu"),
    )
    for prefix, mode in per_class:
        precision, recall, class_f1 = evaluate(preds, labels, mode)
        res[f"{prefix}_precision"] = precision
        res[f"{prefix}_recall"] = recall
        res[f"{prefix}_f1"] = class_f1
    res["acc"] = acc
    res["macro_f1"] = f1

    for k, v in res.items():
        print(f"{k}: {v}")

    return res

if __name__ == "__main__":
    # Script entry point: score the default predictions file.
    postprocess("output/result.jsonl")