import numpy as np
import torch
import datasets
import smtplib

def compute_metrics(eval_preds):
    """Compute macro-averaged F1 from model logits and reference labels.

    Args:
        eval_preds: A ``(logits, labels)`` pair (presumably a
            ``transformers.EvalPrediction`` produced by ``Trainer`` --
            confirm with the caller). ``logits`` is reduced with ``argmax``
            over its last axis to obtain predicted class indices.
            ``labels`` may be either class indices (1-D) or one-hot /
            per-class rows (2-D or higher). In the latter case it is also
            reduced with ``argmax`` over its last axis.

    Returns:
        Whatever ``metric.compute`` returns for the ``"f1"`` metric with
        ``average='macro'`` (presumably a dict such as ``{"f1": ...}``).
    """
    # NOTE(review): the metric is re-loaded on every call; consider caching
    # it if evaluation runs often. ``datasets.load_metric`` is also
    # deprecated in newer ``datasets`` releases in favour of
    # ``evaluate.load`` -- confirm against the pinned version.
    metric = datasets.load_metric("f1")
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    # Decide the label format from its shape rather than trying the
    # computation and retrying inside a bare ``except:``. That pattern also
    # swallowed KeyboardInterrupt/SystemExit. It hid unrelated errors (e.g.
    # a prediction/label length mismatch) behind a second, confusing failure.
    if labels.ndim > 1:
        # One-hot / per-class label rows: collapse to class indices.
        labels = np.argmax(labels, axis=-1)
    return metric.compute(predictions=preds, references=labels, average='macro')


def compute_metrics_custom(eval_preds):
    """Compute macro-averaged F1 when labels are given per class (e.g. one-hot).

    Both the logits and the labels are reduced with ``argmax`` over their
    last axis. This means ``labels`` is expected to carry a trailing class
    axis; plain 1-D class-index labels are not handled here.

    Args:
        eval_preds: A ``(logits, labels)`` pair (presumably a
            ``transformers.EvalPrediction`` -- confirm with the caller).

    Returns:
        Whatever ``compute`` returns for the ``"f1"`` metric with
        ``average='macro'``.
    """
    f1_metric = datasets.load_metric("f1")
    scores, targets = eval_preds
    # Keyword arguments are evaluated left to right, so predictions are
    # reduced before references, matching the original statement order.
    return f1_metric.compute(
        predictions=np.argmax(scores, axis=-1),
        references=np.argmax(targets, axis=-1),
        average='macro',
    )