import numpy as np
import os
import json
from collections import Counter, OrderedDict
import logging
logger = logging.getLogger(__name__)


def get_labels(path, name, negative_label="no_relation"):
    """Load one parsed record per non-blank line of the file ``name`` in ``path``.

    Each non-blank line is parsed as JSON first. If that fails, it is parsed
    as a Python literal with ``ast.literal_eval``, so files written with
    ``repr``-style dicts (single quotes, ``True``/``None``) still load.
    Unlike ``eval``, neither parser executes code found in the data file.

    Args:
        path: Directory that contains the file.
        name: File name inside ``path``.
        negative_label: Unused; kept so existing callers keep working.

    Returns:
        A list holding the parsed value of every non-blank line, in file order.

    Raises:
        ValueError: If a line is neither valid JSON nor a Python literal.
    """
    # Local import: the module top level does not import ``ast``.
    import ast

    records = []
    with open(os.path.join(path, name), "r") as f:
        for line_no, raw_line in enumerate(f, start=1):
            text = raw_line.strip()
            if not text:
                # Blank or whitespace-only lines carry no record.
                continue
            try:
                records.append(json.loads(text))
                continue
            except ValueError:
                # Not JSON (json.JSONDecodeError is a ValueError); try a
                # Python literal next.
                pass
            try:
                records.append(ast.literal_eval(text))
            except (ValueError, SyntaxError, TypeError) as err:
                raise ValueError(
                    f"{name}:{line_no}: line is neither JSON nor a Python literal"
                ) from err
    return records


def main():
    """Write a per-label subsample of a test split to ``test_samples.txt``.

    Reads ``dataset/<data_dir>/test.txt`` through ``get_labels``, shuffles the
    records with a fixed NumPy seed, groups them by their ``'relation'`` key,
    and writes a prefix of each group as one JSON object per line to
    ``test_samples.txt`` in the same directory.

    From each label the script keeps ``max(count // ratio, min(count, n_least))``
    records, where ``ratio = len(records) // n_sample`` (at least 1). Because
    of the ``n_least`` floor the output can hold more than ``n_sample`` records.
    """
    dataset_root = 'dataset'
    data_names = ['semeval', 'retacred', 'tacred', 'tacrev']
    data_dir = data_names[1]
    data_file = 'test.txt'
    path = os.path.join(dataset_root, data_dir)

    n_sample = 1000  # target size of the subsample
    n_least = 5      # minimum records kept per label (or all, if fewer exist)
    seed = 10

    records = get_labels(path, data_file)

    # Fixed seed so the same records are selected on every run.
    np.random.seed(seed)
    np.random.shuffle(records)

    # Group records by relation label; dict order follows first occurrence
    # in the shuffled list.
    by_label = {}
    for record in records:
        by_label.setdefault(record['relation'], []).append(record)

    # Clamp to 1: with fewer than n_sample records the integer division
    # yields 0, which would raise ZeroDivisionError below.
    ratio = max(1, len(records) // n_sample)

    with open(os.path.join(path, 'test_samples.txt'), "w") as f:
        for label_records in by_label.values():
            count = len(label_records)
            k = max(count // ratio, min(count, n_least))
            for record in label_records[:k]:
                f.write(json.dumps(record))
                f.write('\n')


# Run the sampling script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


