def token_level_stat(train_data, tag2idx):
    """Print token-level precision/recall of distant labels against ground truth.

    Each non-blank line of ``train_data`` is split on single spaces. Column 1
    is the ground-truth tag: either ``O`` or a prefixed tag such as ``B-PER``,
    whose prefix up to the first ``-`` is dropped. Column 2 is the integer
    index of the distant label.

    The function prints three summaries: per-type ground-truth counts,
    distant-label counts and correct-label counts. It then prints precision
    and recall for every type in ``tag2idx``. A ratio whose denominator is
    zero is printed as ``n/a`` instead of raising ``ZeroDivisionError``.

    Args:
        train_data: Path to the UTF-8 encoded annotated file.
        tag2idx: Mapping from type name (e.g. ``"PER"``) to the integer index
            used in the distant-label column.

    Returns:
        A tuple ``(correct, total_truth, total_label)`` of dicts keyed by type
        name. The original function returned ``None``; returning the counts
        is purely additive for existing callers.

    Raises:
        KeyError: If a line contains a type or label index not in ``tag2idx``.
    """
    # Invert the mapping explicitly. The original enumerated the keys, which
    # was only correct when insertion order happened to equal the index values.
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    correct = dict.fromkeys(tag2idx, 0)
    total_truth = dict.fromkeys(tag2idx, 0)
    total_label = dict.fromkeys(tag2idx, 0)

    with open(train_data, 'r', encoding='utf-8') as train:
        for line in train:  # stream instead of materialising readlines()
            line = line.strip()
            if not line:
                continue
            cols = line.split(' ')
            # Strip only the BIO-style prefix; keep hyphens inside the type name.
            true_label = cols[1] if cols[1] == 'O' else cols[1].split('-', 1)[1]
            ds_label = int(cols[2])
            total_truth[true_label] += 1
            total_label[idx2tag[ds_label]] += 1
            if tag2idx[true_label] == ds_label:
                correct[true_label] += 1

    def _ratio(name, num, den):
        # A type that is never predicted (or never present) has a zero
        # denominator; report it instead of crashing.
        pct = '{:.2f}'.format(num / den * 100) if den else 'n/a'
        return '{}: {} ({}/{})'.format(name, pct, num, den)

    print('ground-truth:\t\t', total_truth)
    print('distant labeling:\t', total_label)
    print('correct labels:\t\t', correct)
    print('###########')
    for key in tag2idx:
        print(key + ': ')
        print(_ratio('precision', correct[key], total_label[key]) + ', '
              + _ratio('recall', correct[key], total_truth[key]) + '\n')

    return correct, total_truth, total_label


def main(dataset='CoNLL2003'):
    """Run ``token_level_stat`` on one of the known distantly-labelled datasets.

    Args:
        dataset: Name of the configuration to evaluate: ``'CoNLL2003'``
            (the default, matching the previous hard-coded behaviour) or
            ``'BC5CDR'``.

    Raises:
        ValueError: If ``dataset`` is not a known configuration.
    """
    # name -> (training file, tag2idx). Paths are relative to the current
    # working directory, exactly as before.
    configs = {
        'BC5CDR': ('../data/BC5CDR_Dict_0.2/train.ALL.txt',
                   {"O": 0, "Chemical": 1, "Disease": 2}),
        'CoNLL2003': ('../data/CoNLL2003_Dict_1.0/train.ALL.txt',
                      {"O": 0, "PER": 1, "LOC": 2, "ORG": 3, "MISC": 4}),
    }
    try:
        train_data, tag2idx = configs[dataset]
    except KeyError:
        raise ValueError('unknown dataset {!r}; expected one of {}'.format(
            dataset, sorted(configs))) from None
    token_level_stat(train_data, tag2idx)


if __name__ == "__main__":
    # Execute the analysis only when run as a script, never on import.
    main()