import os
import logging
import pickle
from copy import deepcopy

from reader import load_data
from stream_rules_base import get_data_stream
import stream_rules

STREAM_FILES = 'data/stream.pki'


def check_stream(stream, nodes):
    """Check that ``stream`` is an acyclic set of edges covering ``nodes``.

    ``stream`` maps each source node to its single target node. The stream
    is accepted when it has exactly ``len(nodes) - 1`` edges, every target
    is a known node, and a topological traversal starting from the nodes
    without incoming edges visits every node (i.e. there is no cycle).

    Args:
        stream: Mapping ``source -> target``.
        nodes: Sized iterable of node identifiers (iterating a dict yields
            its keys).

    Returns:
        ``True`` if the stream is valid, ``False`` otherwise.
    """
    if len(nodes) != len(stream) + 1:
        return False

    in_count = {x: 0 for x in nodes}
    for tar in stream.values():
        # An edge pointing outside the node set makes the stream invalid;
        # report it instead of failing with a KeyError.
        if tar not in in_count:
            return False
        in_count[tar] += 1

    seeds = [x for x in nodes if in_count[x] == 0]
    visited_count = 0

    while seeds:
        src = seeds.pop()
        visited_count += 1
        # Test membership rather than truthiness so that falsy node ids
        # such as 0 or '' are still followed as targets.
        if src not in stream:
            continue
        tar = stream[src]
        in_count[tar] -= 1
        if in_count[tar] == 0:
            seeds.append(tar)

    # Nodes on a cycle never reach in-degree zero and stay unvisited.
    return visited_count == len(nodes)


def load_stream(data, output_file=STREAM_FILES):
    """Load the pickled streams, generating the file first if it is missing.

    Args:
        data: Dataset passed on to ``dump_all_streams`` when ``output_file``
            does not exist yet; unused otherwise.
        output_file: Path of the pickle file to read (and create if absent).

    Returns:
        The object stored in ``output_file``; when the file is produced by
        ``dump_all_streams`` this is its ``streams`` dict.
    """
    if not os.path.exists(output_file):
        logger = logging.getLogger('data_stream')
        logger.warning('stream data file not exists, create one')
        dump_all_streams(data, output_file)

    # NOTE: unpickling can execute arbitrary code; only point this at
    # files produced by this project.
    # Use a context manager so the file handle is closed deterministically.
    with open(output_file, 'rb') as stream_file:
        return pickle.load(stream_file)


def dump_all_streams(data, filename):
    """Compute the stream of every entry in ``data`` and pickle the result.

    Entries whose stream fails ``check_stream`` are logged as errors and
    left out of the output.

    Args:
        data: Mapping keyed by entry name; element ``[3]`` of each value is
            used as the EDS graph and must provide a ``'nodes'`` entry.
        filename: Path of the pickle file to write.
    """
    streams = {}
    logger = logging.getLogger('data_stream')

    for index, fn in enumerate(data):
        eds_graph = data[fn][3]
        # A copy is handed over, presumably because get_data_stream mutates
        # its argument -- TODO confirm; the original graph is still needed
        # for validation below.
        stream = get_data_stream(deepcopy(eds_graph))
        if not check_stream(stream, eds_graph['nodes']):
            logger.error('%s:invalid-stream', fn)
            continue
        streams[fn] = stream
        if index % 1000 == 0:
            logger.info('%d finished', index)

    # Context manager guarantees the data is flushed and the handle closed
    # before returning, so an immediate re-read sees the full file.
    with open(filename, 'wb') as stream_file:
        pickle.dump(streams, stream_file)


if __name__ == '__main__':
    # Script entry point: enable INFO logging so progress messages show up,
    # then regenerate the stream cache from the first element returned by
    # load_data() (presumably the training split -- verify against reader).
    logging.basicConfig(level=logging.INFO)
    dump_all_streams(load_data()[0], STREAM_FILES)
