import torch
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report


@torch.no_grad()
def accuracy_fn(y_true, y_pred):
    """Plain accuracy, computed as the micro-averaged F1 score.

    With exactly one label per sample, micro precision, micro recall and
    micro F1 all equal the fraction of correct predictions.
    """
    n_true, n_pred = len(y_true), len(y_pred)
    assert n_pred == n_true
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    return micro_f1

@torch.no_grad()
def precision_fn(y_true, y_pred, average='macro'):
    """Precision of ``y_pred`` against ``y_true``.

    Args:
        y_true: true labels.
        y_pred: predicted labels, same length as ``y_true``.
        average: forwarded to ``sklearn.metrics.precision_score``. Defaults to
            'macro' (the previous hard-coded value) and accepts the same
            options as ``recall_fn`` (e.g. 'micro', 'binary').

    Classes that are never predicted get precision 0 (``zero_division=0``)
    instead of triggering a warning.
    """
    assert len(y_pred) == len(y_true)

    return precision_score(y_true, y_pred, average=average, zero_division=0)

@torch.no_grad()
def recall_fn(y_true, y_pred, average='macro'):
    """Recall of ``y_pred`` against ``y_true``.

    ``average`` is forwarded to ``sklearn.metrics.recall_score`` (e.g.
    'macro', 'micro', 'binary'). Classes with no true samples get recall 0
    (``zero_division=0``) instead of triggering a warning.
    """
    assert len(y_true) == len(y_pred)
    options = dict(average=average, zero_division=0)
    return recall_score(y_true, y_pred, **options)

@torch.no_grad()
def macro_f1_fn(y_true, y_pred):
    """Macro-averaged F1: the unweighted mean of the per-class F1 scores.

    Undefined per-class scores count as 0 (``zero_division=0``).
    """
    n_samples = len(y_true)
    assert len(y_pred) == n_samples
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    return macro_f1

@torch.no_grad()
def HM_fn(x, y, eps=1e-6):
    """Harmonic mean ``2xy / (x + y)`` of two scores.

    ``eps`` is added to the denominator so that two zero scores give 0
    instead of a division by zero; it slightly shrinks every result.
    """
    numerator = 2 * (x * y)
    denominator = x + y + eps
    return numerator / denominator

@torch.no_grad()
def top_k_acc_fn(y_true, output, k=3):
    """Top-k accuracy from ranked label predictions.

    Args:
        y_true: length-N tensor (or sequence) of true labels.
        output: 2-D tensor of predicted *labels* (not scores), one row per
            sample; column ``i`` is compared with ``y_true`` as the i-th
            guess. Presumably something like ``scores.topk(k, dim=1).indices``
            -- confirm with callers.
        k: number of leading columns of ``output`` to consider.

    Returns:
        Fraction of samples whose true label occurs among the first ``k``
        columns of their row. Each sample counts at most once, even if its
        row repeats the true label.

    Raises:
        IndexError: if ``k`` exceeds the number of columns in ``output``.
    """
    assert len(output) == len(y_true)
    n_cols = output.size(1)
    if k > n_cols:
        raise IndexError(f'k={k} exceeds the number of prediction columns ({n_cols})')
    # Column vector so the comparison broadcasts across each row's guesses.
    targets = torch.as_tensor(y_true, device=output.device).reshape(-1, 1)
    # A non-positive k selects no columns, giving an accuracy of 0.
    hits = (output[:, :max(k, 0)] == targets).any(dim=1)
    return hits.sum().item() / len(y_true)

@torch.no_grad()
def seen_metric(y_true, y_pred, num_seen_class):
    """Accuracy and macro precision/recall/F1 over labels ``0 .. num_seen_class - 1``.

    Per-class scores come from ``classification_report``. Labels in the range
    that have no entry in the report are left out of the macro averages
    (they shrink the denominator rather than counting as 0).

    Returns:
        dict with keys ``seen_accuracy``, ``seen_precision``, ``seen_recall``
        and ``seen_macro_f1``.

    Raises:
        ZeroDivisionError: if none of the labels in the range are in the report.
    """
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    rows = [report[key] for key in map(str, range(num_seen_class)) if key in report]
    n_present = len(rows)

    def _mean(field):
        return sum(row[field] for row in rows) / n_present

    return {
        'seen_accuracy': accuracy_fn(y_true, y_pred),
        'seen_precision': _mean('precision'),
        'seen_recall': _mean('recall'),
        'seen_macro_f1': _mean('f1-score'),
    }


@torch.no_grad()
def unseen_metric(y_true, y_pred, num_seen_class):
    """Accuracy and macro precision/recall/F1, reported under ``unseen_*`` keys.

    The computation is exactly ``seen_metric``'s: macro averages over the
    labels ``0 .. num_seen_class - 1`` that appear in the classification
    report. This function therefore delegates to it and only renames the keys.

    NOTE(review): despite its name, ``num_seen_class`` is just the size of the
    label range averaged over. Presumably callers pass unseen labels shifted
    to start at 0 together with the number of unseen classes -- confirm. The
    parameter name is kept for keyword-argument compatibility.

    Returns:
        dict with keys ``unseen_accuracy``, ``unseen_precision``,
        ``unseen_recall`` and ``unseen_macro_f1``.

    Raises:
        ZeroDivisionError: if none of the labels in the range are in the report.
    """
    scores = seen_metric(y_true, y_pred, num_seen_class)
    return {
        'unseen_accuracy': scores['seen_accuracy'],
        'unseen_precision': scores['seen_precision'],
        'unseen_recall': scores['seen_recall'],
        'unseen_macro_f1': scores['seen_macro_f1'],
    }


def _sum_report_scores(report, labels):
    """Sum per-class precision, recall and F1 from a classification report.

    Args:
        report: dict returned by ``classification_report(..., output_dict=True)``;
            per-class entries are keyed by ``str(label)``.
        labels: iterable of integer labels to include. Labels without an entry
            in ``report`` are skipped.

    Returns:
        Tuple ``(precision_sum, recall_sum, f1_sum, n_present)``, where
        ``n_present`` is how many of ``labels`` were found in ``report``.
    """
    precision_sum = recall_sum = f1_sum = 0.0
    n_present = 0
    for label in labels:
        scores = report.get(str(label))
        if scores is None:
            continue
        precision_sum += scores['precision']
        recall_sum += scores['recall']
        f1_sum += scores['f1-score']
        n_present += 1
    return precision_sum, recall_sum, f1_sum, n_present


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is 0.

    Mirrors the ``zero_division=0`` convention used for the sklearn reports,
    so an empty label range yields 0 instead of raising ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


@torch.no_grad()
def all_metric(y_true, pred, n_seen_test, num_seen_class, num_unseen_class):
    """Compute seen, unseen and generalized zero-shot (GZSL) metrics.

    Expected layout (inferred from the slicing below -- confirm with callers):
    the first ``n_seen_test`` entries of ``y_true`` / rows of ``pred`` are
    seen-class samples and the rest are unseen-class samples. ``pred`` holds
    per-class scores, with the ``num_seen_class`` seen classes in the leading
    columns followed by the ``num_unseen_class`` unseen classes. ``y_true``
    uses the same global label indices.

    Metric groups in the returned dict:
        only_seen_*: seen samples, argmax over the seen columns only.
        only_unseen_*: unseen samples, argmax over the unseen columns only
            (labels shifted down by ``num_seen_class``).
        seen_* / unseen_*: GZSL, argmax over all columns. Accuracy uses the
            seen/unseen slice of the samples. Precision/recall/F1 are macro
            averages over the seen/unseen label range of one report built from
            all samples.
        total_seen_recall / total_unseen_recall: fraction of seen (unseen)
            samples whose GZSL prediction lies in the seen (unseen) range.
        total_acc_hm / total_f1_hm: harmonic means of GZSL seen/unseen
            accuracy and macro F1.

    Macro averages skip labels that have no entry in the report, except the
    GZSL unseen averages, which divide by ``num_unseen_class`` (see the note
    in the body). An empty average yields 0.0.

    Returns:
        ``(res, report)``: the metric dict and the GZSL
        ``classification_report`` dict.
    """
    res = {}
    seen_labels = range(num_seen_class)
    unseen_labels = range(num_seen_class, num_seen_class + num_unseen_class)

    # Only seen: seen samples, scored against the seen classes only.
    y = y_true[:n_seen_test]
    yp = pred[:n_seen_test, :num_seen_class].max(dim=1)[1]
    report = classification_report(y, yp, output_dict=True, zero_division=0)
    p, r, f1, n = _sum_report_scores(report, seen_labels)
    res['only_seen_accuracy'] = accuracy_fn(y, yp)
    res['only_seen_precision'] = _safe_div(p, n)
    res['only_seen_recall'] = _safe_div(r, n)
    res['only_seen_macro_f1'] = _safe_div(f1, n)

    # Only unseen (ZSL): unseen samples, scored against the unseen classes
    # only; labels are shifted so the unseen range starts at 0.
    y = y_true[n_seen_test:] - num_seen_class
    yp = pred[n_seen_test:, num_seen_class:].max(dim=1)[1]
    report = classification_report(y, yp, output_dict=True, zero_division=0)
    p, r, f1, n = _sum_report_scores(report, range(num_unseen_class))
    res['only_unseen_accuracy'] = accuracy_fn(y, yp)
    res['only_unseen_precision'] = _safe_div(p, n)
    res['only_unseen_recall'] = _safe_div(r, n)
    res['only_unseen_macro_f1'] = _safe_div(f1, n)

    # GZSL: argmax over all classes; one report over every test sample.
    y_pred = pred.max(dim=1)[1]
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    y = y_true[:n_seen_test]
    yp = y_pred[:n_seen_test]
    p, r, f1, n = _sum_report_scores(report, seen_labels)
    res['seen_accuracy'] = accuracy_fn(y, yp)
    res['seen_precision'] = _safe_div(p, n)
    res['seen_recall'] = _safe_div(r, n)
    res['seen_macro_f1'] = _safe_div(f1, n)
    # Fraction of seen samples predicted as *some* seen class.
    seen_recall = recall_fn(torch.ones_like(y), yp < num_seen_class, average='binary')

    y = y_true[n_seen_test:]
    yp = y_pred[n_seen_test:]
    p, r, f1, _ = _sum_report_scores(report, unseen_labels)
    # NOTE(review): unlike every other average here, the GZSL unseen scores
    # divide by num_unseen_class, so unseen classes missing from the report
    # count as 0. Kept unchanged so reported numbers stay comparable; confirm
    # the asymmetry is intended.
    res['unseen_accuracy'] = accuracy_fn(y, yp)
    res['unseen_precision'] = _safe_div(p, num_unseen_class)
    res['unseen_recall'] = _safe_div(r, num_unseen_class)
    res['unseen_macro_f1'] = _safe_div(f1, num_unseen_class)
    # Fraction of unseen samples predicted as *some* unseen class.
    unseen_recall = recall_fn(torch.ones_like(y), yp >= num_seen_class, average='binary')

    res['total_seen_recall'] = seen_recall
    res['total_unseen_recall'] = unseen_recall
    res['total_acc_hm'] = HM_fn(res['seen_accuracy'], res['unseen_accuracy'])
    res['total_f1_hm'] = HM_fn(res['seen_macro_f1'], res['unseen_macro_f1'])
    return res, report



@torch.no_grad()
def gzsl_metric(y_true, y, n_seen_test, num_seen_class, num_unseen_class):
    """Seen, unseen and GZSL metrics: ``all_metric`` without the report.

    ``y`` is the per-class score matrix (``all_metric`` calls it ``pred``).
    The other arguments, and every returned key and value, are the same as
    for ``all_metric``.

    Returns:
        The metric dict produced by ``all_metric``.
    """
    # Delegate rather than repeat the computation, so the two functions
    # cannot drift apart.
    res, _ = all_metric(y_true, y, n_seen_test, num_seen_class, num_unseen_class)
    return res


if __name__ == '__main__':
    # Smoke test for top_k_acc_fn. ``output`` must be 2-D: one row of ranked
    # label guesses per sample, with at least k (default 3) columns.
    y_true = torch.tensor([1, 2, 3, 4, 5])
    y_pred = torch.tensor([
        [4, 1, 0],  # hit: 1 at rank 2
        [1, 2, 0],  # hit: 2 at rank 2
        [2, 0, 1],  # miss
        [4, 0, 1],  # hit: 4 at rank 1
        [1, 0, 2],  # miss
    ])

    print(top_k_acc_fn(y_true, y_pred))  # expected 0.6