import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from logging import getLogger

# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)

def accuracy_7(out, labels):
    """Return the share of entries whose rounded prediction equals the rounded label.

    Args:
        out: Predicted scores (array-like).
        labels: Ground-truth scores (array-like); its ``len`` is the denominator.

    Returns:
        Count of positions where ``round(out) == round(labels)`` divided by
        ``len(labels)``.
    """
    rounded_out = np.round(out)
    rounded_labels = np.round(labels)
    hits = np.sum(rounded_out == rounded_labels)
    total = float(len(labels))
    return hits / total

def accuracy_3(out, labels):
    """Return three-way (negative / neutral / positive) accuracy.

    Both inputs are rounded to the nearest integer. An entry counts as
    correct when prediction and label fall in the same bucket:
    negative (``<= -1``), neutral (``== 0``) or positive (``>= 1``).

    Args:
        out: Predicted scores (array-like).
        labels: Ground-truth scores (array-like); its ``len`` is the denominator.

    Returns:
        Number of entries whose buckets agree divided by ``len(labels)``.
    """
    # Round once instead of re-rounding for every comparison.
    rounded_out = np.round(out)
    rounded_labels = np.round(labels)

    same_bucket = (
        ((rounded_out <= -1) & (rounded_labels <= -1))
        | ((rounded_out >= 1) & (rounded_labels >= 1))
        | ((rounded_out == 0) & (rounded_labels == 0))
    )

    # Count matches first and divide once. Dividing every element by n and
    # then summing accumulates floating-point error (a perfect score may not
    # come out as exactly 1.0) and disagrees with how accuracy_7 is computed.
    return np.sum(same_bucket) / float(len(labels))

def calc_result_classify(predict_lists, truth_lists, do_log=True):
    """Compute accuracy and F1 scores for a classification run.

    The class index of every sample is taken as the argmax over the last
    axis of ``predict_lists`` and ``truth_lists``. The first 100 predicted
    indices, true indices and top prediction scores (rounded to two
    decimals) are always printed to stdout.

    Args:
        predict_lists: Per-class prediction scores; reduced with argmax over
            the last axis.
        truth_lists: Per-class targets, reduced the same way (presumably
            one-hot rows; confirm against the caller).
        do_log: When true, also write each metric to the module logger.

    Returns:
        Dict with keys ``'acc'``, ``'F1'`` (weighted average) and
        ``'F1(Macro)'``.
    """
    scores = np.array(predict_lists)
    top_scores = np.round(scores.max(axis=-1), 2)
    predicted = scores.argmax(axis=-1)
    actual = np.array(truth_lists).argmax(axis=-1)

    weighted_f1 = f1_score(truth_lists if False else actual, predicted, average='weighted')
    macro_f1 = f1_score(actual, predicted, average='macro')

    # Debug dump of the leading samples; emitted regardless of do_log.
    for title, values in (
        ("predicts:", predicted),
        ("truth:", actual),
        ("max_scores:", top_scores),
    ):
        print(title)
        print(values[0:100])

    result = {
        'acc': accuracy_score(actual, predicted),
        'F1': weighted_f1,
        'F1(Macro)': macro_f1,
    }

    if not do_log:
        return result

    logger.info("***** result *****")
    for key in sorted(result):
        logger.info("  %s = %s", key, str(result[key]))
    return result


def calc_result_multi(predict_lists, truth_lists, label_names, do_log=True):
    """Compute per-label metrics for multi-label score prediction.

    ``predict_lists`` and ``truth_lists`` are converted to arrays and
    transposed, so row ``i`` of the transposed data holds the values
    belonging to ``label_names[i]``.

    For each label the following are computed:

    * ``acc`` / ``F1``: binary accuracy and weighted F1. If all truth values
      are non-negative, the positive class is ``round(x) >= 1``; otherwise it
      is ``x > 0``.
    * ``mae`` / ``corr``: mean absolute error and Pearson correlation on the
      raw values.
    * ``acc3`` / ``acc7``: bucketed accuracies on values clipped to
      ``[-3, 3]`` (see ``accuracy_3`` and ``accuracy_7``).
    * ``F1(acc7)`` / ``F1(macro, acc7)``: weighted and macro F1 on the
      clipped, rounded values.

    Args:
        predict_lists: Predictions, indexable as ``[sample][label]`` before
            the transpose.
        truth_lists: Ground truth with the same layout.
        label_names: Names of the labels, in column order.
        do_log: When true, write every metric to the module logger.

    Returns:
        Dict mapping each label name to its dict of metrics.
    """
    results = {}

    preds_by_label = np.array(predict_lists).T
    truths_by_label = np.array(truth_lists).T
    for i, name in enumerate(label_names):
        preds = np.array(preds_by_label[i]).reshape(-1)
        truths = np.array(truths_by_label[i])

        # Pick the binarisation rule from the range of the ground truth.
        no_neg = (truths >= 0).all()
        if no_neg:
            preds_bin = preds.round() >= 1
            truths_bin = truths.round() >= 1
        else:
            preds_bin = preds > 0
            truths_bin = truths > 0

        # Clipping to [-3, 3] and rounding yields seven integer classes.
        preds_a7 = np.clip(preds, a_min=-3., a_max=3.)
        truths_a7 = np.clip(truths, a_min=-3., a_max=3.)
        preds_a7_round = preds_a7.round()
        truths_a7_round = truths_a7.round()

        acc7 = accuracy_7(preds_a7, truths_a7)
        acc3 = accuracy_3(preds_a7, truths_a7)

        # scikit-learn expects (y_true, y_pred). With average='weighted' the
        # class weights come from the support of y_true, so the earlier
        # swapped order weighted by predicted-class counts instead.
        # NOTE: weighted F1 values differ from those produced by the swapped
        # call; macro F1 is unaffected by the order.
        f_score = f1_score(truths_bin, preds_bin, average='weighted')
        f_score_7 = f1_score(truths_a7_round, preds_a7_round, average='weighted')
        f_score_7_macro = f1_score(truths_a7_round, preds_a7_round, average='macro')

        acc = accuracy_score(truths_bin, preds_bin)
        corr = np.corrcoef(preds, truths)[0][1]
        mae = np.mean(np.absolute(preds - truths))

        result = {
            'acc': acc,
            'F1': f_score,
            'mae': mae,
            'corr': corr,
            'acc3': acc3,
            'acc7': acc7,
            'F1(acc7)': f_score_7,
            'F1(macro, acc7)': f_score_7_macro,
        }
        if do_log:
            logger.info("***** %s result *****", name)
            logger.info("all positive value? %s", no_neg)
            for key in sorted(result):
                logger.info("  %s = %s", key, str(result[key]))
        results[name] = result

    return results