#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: eval.py
#Created Time:
############################

import heapq
import numpy as np
def get_label_by_threshold(scores, threshold=0.5):
    """Convert rows of per-class scores into multi-hot label rows by thresholding.

    Every position whose score is >= ``threshold`` is set to 1. If no position
    in a row reaches the threshold, the single highest-scoring position (the
    first one on ties) is set to 1 instead, so every non-empty row gets at
    least one label. An empty row yields an empty label row.

    Args:
        scores: iterable of rows; each row is an indexable sequence of numeric
            scores (list, tuple, or 1-D numpy array all work).
        threshold: minimum score for a position to be labelled 1.

    Returns:
        list of lists of 0/1 ints, one inner list per row of ``scores``.
    """
    predicted_onehot_labels = []
    for score in scores:
        onehot_labels_list = [1 if s >= threshold else 0 for s in score]
        if onehot_labels_list and not any(onehot_labels_list):
            # Nothing passed the threshold: fall back to the argmax. Ranking
            # positions (rather than calling score.index) also works for
            # numpy rows, which have no .index method; max() returns the
            # first maximal position, matching list.index on ties.
            max_score_index = max(range(len(score)), key=score.__getitem__)
            onehot_labels_list[max_score_index] = 1
        predicted_onehot_labels.append(onehot_labels_list)
    return predicted_onehot_labels

def get_label_by_topk(scores, top_num=1):
    """Convert rows of per-class scores into multi-hot rows marking the top-k.

    For each row, the ``top_num`` highest-scoring positions are set to 1 and
    all others to 0. Duplicate scores are handled by position, so exactly
    ``min(top_num, len(row))`` labels are set; among tied scores the lower
    index is chosen first.

    Args:
        scores: iterable of rows; each row is an indexable sequence of numeric
            scores (list, tuple, or 1-D numpy array all work).
        top_num: number of positions to mark per row.

    Returns:
        numpy array built from the list of 0/1 label rows.
    """
    predicted_onehot_labels = []
    for score in scores:
        onehot_labels_list = [0] * len(score)
        # Rank positions, not values: mapping score.index over the top values
        # collapses duplicate scores onto the same (first) index and sets
        # fewer than top_num labels. nlargest is stable, so ties keep the
        # lower index first.
        for i in heapq.nlargest(top_num, range(len(score)), key=score.__getitem__):
            onehot_labels_list[i] = 1
        predicted_onehot_labels.append(onehot_labels_list)
    return np.array(predicted_onehot_labels)
