import logging
from collections import Counter

from mark_edge_with_label import stream_edge_spans

# Output path for the label-frequency table written by this module's
# __main__ block. The path is relative, so it resolves against the current
# working directory.
DEFAULT_LABELS_FILE = 'data/labels.txt'


def spans_labels(cfg_tree, spans, clean=False):
    """Lazily yield one constituent label per span.

    For every span, the CFG nodes covering it are fetched with
    ``cfg_tree.get_span_cfg_nodes(span[0], span[1])``. Their ``value``
    attributes are upper-cased and joined with ``'+'``. When ``clean`` is
    true, any span covered by more than one node is labelled ``'X'``
    instead. Afterwards an empty label becomes ``'X'``, and the bare
    categories ``N``/``V``/``P`` are widened to ``NP``/``VP``/``PP``.
    Combined labels such as ``'N+V'`` are left untouched.

    Args:
        cfg_tree: object exposing ``get_span_cfg_nodes(a, b)``, which returns
            a collection of nodes carrying a string ``value`` attribute.
        spans: iterable of indexable spans. Only ``span[0]`` and ``span[1]``
            are read (presumably start/end positions; confirm with caller).
        clean: collapse multi-node (ambiguous) spans to ``'X'``.

    Yields:
        str: the label for each span, in input order.
    """
    # Exact-match renames applied to the raw joined label.
    rename = {'': 'X', 'N': 'NP', 'V': 'VP', 'P': 'PP'}
    for span in spans:
        covering = cfg_tree.get_span_cfg_nodes(span[0], span[1])
        if clean and len(covering) > 1:
            raw = 'X'
        else:
            raw = '+'.join(node.value.upper() for node in covering)
        yield rename.get(raw, raw)


def extract_labels(fn, uttr, nodes, stream, cfg_tree, cfg_leaves):
    """Yield a label for every span attached to every edge of one file.

    All arguments are forwarded unchanged to
    ``mark_edge_with_label.stream_edge_spans``. That call returns a mapping,
    presumably from edge to that edge's spans (confirm against
    ``stream_edge_spans``). Each value is labelled with ``spans_labels`` in
    its default (non-clean) mode, using the same ``cfg_tree``.

    Yields:
        str: one label per span, in the mapping's iteration order.
    """
    edge_spans = stream_edge_spans(fn, uttr, nodes, stream,
                                   cfg_tree, cfg_leaves)
    # First part of the state. Only the spans are needed here; the edge keys
    # are never used, so iterate the mapping's values directly.
    for spans in edge_spans.values():
        yield from spans_labels(cfg_tree, spans)


def extract_all_labels(data, streams):
    """Count how often each label occurs across a dataset.

    Args:
        data: mapping from file name to a 4-tuple
            ``(uttr, cfg_tree, cfg_leaves, graph)``, where ``graph['nodes']``
            supports ``len``.
        streams: mapping from file name to that file's stream. It must
            contain every key of ``data``, otherwise ``KeyError`` is raised.

    Returns:
        dict: label -> occurrence count, in first-seen order. Files whose
        stream length differs from ``len(nodes) - 1`` are skipped and
        contribute nothing. The number skipped is logged at INFO level.
    """
    counts = Counter()
    skipped = 0
    for fn, (uttr, cfg_tree, cfg_leaves, graph) in data.items():
        nodes = graph['nodes']
        stream = streams[fn]

        # Only files with exactly one stream entry fewer than nodes are used.
        # Presumably that is one transition per node after the first
        # (confirm against data_stream). Record the skip instead of dropping
        # the file silently.
        if len(stream) + 1 != len(nodes):
            skipped += 1
            logging.debug('skipping %s: %d stream steps for %d nodes',
                          fn, len(stream), len(nodes))
            continue

        counts.update(extract_labels(fn, uttr, nodes, stream,
                                     cfg_tree, cfg_leaves))

    if skipped:
        logging.info('skipped %d of %d files whose stream length does not '
                     'match their node count', skipped, len(data))
    # Return a plain dict so missing labels still raise KeyError, as before.
    return dict(counts)


if __name__ == '__main__':
    import os

    from data_stream import load_stream
    from reader import load_data

    logging.basicConfig(level=logging.INFO)
    train_data = load_data()[0]
    streams = load_stream(train_data)
    labels = extract_all_labels(train_data, streams)

    # Create the output directory if needed; open() does not.
    out_dir = os.path.dirname(DEFAULT_LABELS_FILE)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Write "<label> <count>" lines, most frequent first. sorted() is stable
    # even with reverse=True, so tied labels keep first-seen order. The
    # encoding is explicit so the output does not depend on the locale.
    with open(DEFAULT_LABELS_FILE, 'w', encoding='utf-8') as out:
        ranked = sorted(labels.items(), key=lambda kv: kv[1], reverse=True)
        for label, count in ranked:
            # Defensive: spans_labels already maps empty labels to 'X'.
            if label != '':
                print(label, count, file=out)
