def read_entity_prob_data(file_path):
    """Load a space-separated text file as a list of token lists.

    Lines that are empty after ``str.strip()`` are skipped. Every other line
    is stripped of surrounding whitespace and split on single spaces. Runs of
    consecutive interior spaces therefore produce empty-string tokens.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        list of lists of str, one inner list per non-blank line, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        # Iterate the file object directly instead of materializing readlines().
        return [
            stripped.split(' ')
            for stripped in (raw.strip() for raw in handle)
            if stripped
        ]


def analyze_entity_prob_data(file_path, tag2idx, etas):
    """Count, per tag and per threshold, how many entries fall above/below it.

    Each row returned by ``read_entity_prob_data(file_path)`` must carry an
    integer label index in column 2 and a probability in column 3. For every
    ``eta`` in ``etas``, rows are bucketed under the tag whose index in
    ``tag2idx`` equals the row's label. A row counts as "above" when
    ``prob >= eta`` and as "below" otherwise. The counts are printed (same
    format as before) and also returned.

    Args:
        file_path: Path to the space-separated probability file.
        tag2idx: Mapping from tag name to its integer label index. If several
            tags share an index, the first one in iteration order wins.
        etas: Iterable of probability thresholds to evaluate.

    Returns:
        dict mapping each eta to ``{tag: [above_count, below_count]}``.
        Duplicate etas keep the result of their last occurrence.

    Raises:
        ValueError: if a label in the file is not one of ``tag2idx``'s values,
            or a column cannot be parsed as int/float.
        IndexError: if a row has fewer than four columns.
    """
    # Build the inverse map once instead of a linear list.index() per row.
    # setdefault keeps the FIRST tag for a duplicated index, matching the
    # semantics of the previous list.index() lookup.
    idx2tag = {}
    for tag, idx in tag2idx.items():
        idx2tag.setdefault(idx, tag)

    # Parse and resolve every row once, rather than once per eta.
    records = []
    for line in read_entity_prob_data(file_path):
        dist_label = int(line[2])
        prob = float(line[3])
        try:
            tag = idx2tag[dist_label]
        except KeyError:
            raise ValueError(
                'label {} in {!r} is not a value of tag2idx {}'.format(
                    dist_label, file_path, tag2idx)) from None
        records.append((tag, prob))

    results = {}
    for eta in etas:
        stat = {key: [0, 0] for key in tag2idx}
        for tag, prob in records:
            # Slot 0 counts prob >= eta, slot 1 counts prob < eta.
            stat[tag][0 if prob >= eta else 1] += 1
        print('---------' + ' eta = ' + str(eta) + ' ---------')
        for key, (above, below) in stat.items():
            print(key + '_above: {}, '.format(above), key + '_below: {}'.format(below))
        results[eta] = stat
    return results


def print_entity_prob_data(file_path, etas=None):
    """Print above/below-threshold counts for the fixed 0-4 label set.

    Column 2 of each row from ``read_entity_prob_data(file_path)`` is parsed
    as a float label and matched against indices 0-4, named 'unlabel', 'PER',
    'LOC', 'ORG' and 'MISC' respectively. Rows with any other label are
    ignored. A row counts as "above" when its column-3 probability is
    ``>= eta``.

    Args:
        file_path: Path to the space-separated probability file.
        etas: Thresholds to evaluate. Defaults to ``[0.1]``, the value that
            was previously hard-coded.

    Returns:
        dict mapping each eta to ``{name: [above_count, below_count]}``.
        Duplicate etas keep the result of their last occurrence.

    Raises:
        ValueError: if column 2 or 3 of any row is not parseable as float.
        IndexError: if a row has fewer than four columns.
    """
    if etas is None:
        etas = [0.1]
    # Label index -> name used in the printed counter names. Labels are
    # parsed as floats; 1.0 == 1 and hashes identically, so int keys match.
    label_names = {0: 'unlabel', 1: 'PER', 2: 'LOC', 3: 'ORG', 4: 'MISC'}
    data = read_entity_prob_data(file_path)
    results = {}
    for eta in etas:
        counts = {name: [0, 0] for name in label_names.values()}
        for line in data:
            # Parse both columns before the label check so malformed rows
            # still raise, as in the original per-label branches.
            dist_label = float(line[2])
            prob = float(line[3])
            name = label_names.get(dist_label)
            if name is None:
                continue  # labels outside 0-4 are not counted
            counts[name][0 if prob >= eta else 1] += 1
        print('---------' + ' eta = ' + str(eta) + ' ---------')
        for name, (above, below) in counts.items():
            print(name + '_above_eta: ', above)
            print(name + '_below_eta: ', below)
        results[eta] = counts
    return results


def main():
    """Run the threshold analysis on the BC5CDR_Dict_0.2 training entity-probability file."""
    # Tag maps for other label sets (swap in as needed):
    #   {"O": 0, "PER": 1, "LOC": 2, "ORG": 3, "MISC": 4}
    #   {"O": 0, "PER": 1, "LOC": 2, "ORG": 3}
    label_to_index = {"O": 0, "Chemical": 1, "Disease": 2}
    thresholds = [0.5]
    # Relative path: resolved against the current working directory.
    prob_file = '../data/BC5CDR_Dict_0.2/train.ALL.txt.entity_prob'
    analyze_entity_prob_data(prob_file, label_to_index, thresholds)


if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()
