import logging
from collections import deque
from itertools import chain
from reader_fix import preprocess_edges

# Auxiliary node tables consulted by nodes_can_be_reached_from. They are
# not filled in this part of the file (presumably by rule modules -- verify).
# conj_c_nodes: node id -> (left_node, _, right_node)
conj_c_nodes = dict()
# in_order_to_arg1: node id -> (_, right_node)
in_order_to_arg1 = dict()
# parg_d_arg1: not referenced in this part of the file.
parg_d_arg1 = dict()

# Rule registries.
# pre_rules: callables fn(graph, removed_edges) -> graph, run before traversal.
pre_rules = list()
# as_source_rules: not registered to or read in this part of the file.
as_source_rules = list()
# as_target_rules: (pred, fn) pairs; pred(src, elabel, tar, nodes, visited)
# selects the rule, fn(src, elabel, tar, nodes, stream) returns seeds.
as_target_rules = list()
# post_rules: callables fn(nodes, stream), run after traversal.
post_rules = list()


def make_rule(tp, pred=None):
    '''
    Decorator factory that registers a rule function in a module registry.

    tp picks the registry: 'post' -> post_rules, 'pre' -> pre_rules,
    any other value -> as_target_rules (note: as_source_rules is never
    selected here). When pred is truthy the pair (pred, fn) is registered,
    otherwise fn alone. The decorated function is returned unchanged.
    '''
    registry = (post_rules if tp == 'post'
                else pre_rules if tp == 'pre'
                else as_target_rules)

    def register(fn):
        registry.append((pred, fn) if pred else fn)
        return fn
    return register


def nodes_can_be_reached_from(nodes, from_nid):
    '''
    Return the set of node ids reachable from from_nid by a breadth-first
    walk, not including from_nid itself.

    Besides every outgoing edge in nodes[nid]['edges'], two extra links are
    followed:
    - conj_c nodes: data of a conj_c-type node flows out through ARG1, so
      when the recorded left node is the node itself, left_node -> right_node
      is treated as an edge;
    - _in+order+to_x nodes: the recorded right node in in_order_to_arg1 is
      treated as a successor in the same way.
    '''
    reached = set()
    frontier = deque([from_nid])

    def visit(nid):
        # queue each node at most once and never re-enter the start node
        if nid != from_nid and nid not in reached:
            reached.add(nid)
            frontier.append(nid)

    while frontier:
        cur = frontier.popleft()
        if cur in conj_c_nodes:
            left, _, right = conj_c_nodes[cur]
            if left == cur:
                visit(right)
        if cur in in_order_to_arg1:
            _, right = in_order_to_arg1[cur]
            visit(right)
        for tar in nodes[cur]['edges'].values():
            visit(tar)
    return reached


def select_toplevel_nodes(nodes, ns):
    '''
    Select the top-level nodes among ns.

    ns is a sequence of (edge_label, node_id) pairs. A pair is kept only if
    its node is not reachable (per nodes_can_be_reached_from) from any node
    listed in ns. The input order is preserved.
    '''
    reachable = set()
    for _, nid in ns:
        reachable.update(nodes_can_be_reached_from(nodes, nid))
    return [(el, nid) for el, nid in ns if nid not in reachable]


def is_node_connection_point(nodes, in_map, node, visited):
    '''
    Check whether node's outgoing edges form a bridge to an unvisited region.

    Starting from the targets of node's outgoing edges, walk the graph
    ignoring edge direction (outgoing edges plus the (label, source) pairs
    in in_map) without stepping back onto node. Return False as soon as a
    node already in visited is dequeued; return True if the walk ends
    without meeting one.
    '''
    queue = deque(nodes[node]['edges'].values())
    seen = set(queue)
    while queue:
        cur = queue.popleft()
        if cur in visited:
            return False
        out_targets = nodes[cur]['edges'].values()
        in_sources = (src for _, src in in_map[cur])
        for nxt in chain(out_targets, in_sources):
            if nxt == node or nxt in seen:
                continue
            seen.add(nxt)
            queue.append(nxt)
    return True


def should_node_be_delayed(node):
    '''
    Return True for conjunction ('c') and noun ('n') nodes, and for nodes
    labelled 'nominalization'. The label is only consulted for other
    parts of speech.
    '''
    if node['pos'] in ('c', 'n'):
        return True
    return node['label'] == 'nominalization'


def rule_default_as_source(src, nodes, tars, stream, visited):
    '''
    Default expansion of a source node's neighbours.

    :param src: id of the node being expanded.
    :param nodes: node id -> node dict (passed through to the rules).
    :param tars: pair (out_edges, in_edges); each half is an iterable of
        (edge_label, node_id) pairs. Any iterables are accepted (lists,
        tuples, generators), not just lists.
    :param stream: node id -> source node id; updated in place.
    :param visited: ids that must not be expanded again (read only here).
    :returns: list of (is_delayed, node_id) seeds.

    For every neighbour not in visited, the first (pred, rule) pair in
    as_target_rules whose predicate accepts it supplies the seeds. If no
    rule matches, the neighbour becomes a non-delayed seed (False, tar),
    and src is recorded as its stream source unless it already has one.
    '''
    result = []
    # iterate both halves lazily instead of concatenating them first
    for elabel, tar in chain(tars[0], tars[1]):
        if tar in visited:
            continue
        for test, as_tar_rule in as_target_rules:
            if test(src, elabel, tar, nodes, visited):
                result.extend(as_tar_rule(src, elabel, tar, nodes, stream))
                break
        else:
            # no specialised rule applied: default, non-delayed expansion
            result.append((False, tar))
            stream.setdefault(tar, src)
    return result


def get_data_stream(graph):
    '''
    Compute, for each node of a semantic graph, the node its data flows from.

    Overview of what the code below does:
    1. run every registered pre_rule as ``graph = pre_rule(graph, removed_edges)``;
    2. for every node, call ``preprocess_edges`` on its edges (to remove
       parallel edges) and drop the 'R-HNDL' / 'L-HNDL' edge when the
       matching 'R-INDEX' / 'L-INDEX' edge is also present, recording it in
       ``removed_edges`` as ``(target, source)``;
    3. build ``in_map``: node id -> list of ``(edge_label, source_id)``;
    4. walk the graph breadth-first from ``graph['top']``, following both
       outgoing edges and in-edges. Some nodes are pushed into a second
       "delayed" queue that is drained only after the ordinary queue is
       empty. Nodes not reached that way but targeted by a removed edge are
       seeded afterwards, with the removed edge's source as stream source;
    5. run every registered post_rule as ``post_rule(nodes, stream)``.

    :param graph: dict with at least 'nodes' (node id -> node dict with
        'edges' (edge_label -> target id), 'pos' and 'label') and 'top'
        (the start node id).
    :returns: the ``stream`` dict, node id -> id of the node it takes its
        data from.

    Side effect: the nodes' 'edges' dicts are modified in place (the
    R-HNDL / L-HNDL deletions, and presumably ``preprocess_edges``, whose
    return value is ignored).
    '''
    logger = logging.getLogger('stream_rules_base')
    nodes = graph['nodes']
    # NOTE(review): `nodes` (and `in_map` below) come from the graph *before*
    # the pre_rules run. If a pre_rule returns a new graph whose 'nodes' dict
    # is a different object, they would be stale -- confirm pre_rules mutate
    # the graph in place.
    in_map = {x: [] for x in nodes}

    # node id -> id of the node its data comes from
    stream = {}
    # nodes moved from delayed_seeds into seeds in the last round; lets the
    # traversal recognise a delayed node when it is processed again
    current_delayed_nodes = set()

    # (target, source) pairs of edges removed before the traversal
    removed_edges = deque([])
    for pre_rule in pre_rules:
        graph = pre_rule(graph, removed_edges)

    top = graph['top']
    # ordinary BFS frontier
    seeds = deque([top])
    # frontier processed only after `seeds` has been exhausted
    delayed_seeds = deque([])
    # every node ever queued in either frontier
    visited = set(seeds)

    # @@: edge with no data stream carry empty variables
    for src in nodes:
        src_node = nodes[src]
        edges = src_node['edges']
        # Remove all parallel edges
        preprocess_edges(edges)
        # @@: skip some edges, 'R-HNDL', 'L-HNDL'
        if 'R-INDEX' in edges and 'R-HNDL' in edges:
            # May disconnect the graph; such nodes are handled later
            # if edges.get('R-INDEX') != edges.get('R-HNDL'):
            removed_edges.append((edges['R-HNDL'], src))
            del edges['R-HNDL']
        if 'L-INDEX' in edges and 'L-HNDL' in edges:
            # if edges.get('L-INDEX') != edges.get('L-HNDL'):
            removed_edges.append((edges['L-HNDL'], src))
            del edges['L-HNDL']
        # reverse index: target -> [(edge_label, source), ...]
        for elabel, tar in edges.items():
            in_map[tar].append((elabel, src))

    logger.debug('removed_edges: %s',
                 [(nodes[f]['label'], nodes[t]['label'])
                  for f, t in removed_edges])

    def add_to_seeds(ss):
        # ss: iterable of (is_delayed, node_id); queue each unvisited node
        # into the delayed or the ordinary frontier
        for is_delayed, seed in ss:
            if seed not in visited:
                visited.add(seed)
                if is_delayed:
                    delayed_seeds.append(seed)
                else:
                    seeds.append(seed)

    # Repeat the traversal until no node is left that is only reachable
    # through a removed edge.
    while True:
        while len(seeds) != 0 or len(delayed_seeds) != 0:
            while len(seeds) != 0:
                src = seeds.popleft()
                src_node = nodes[src]
                # (sorted outgoing edges, sorted incoming edges)
                tars = sorted(src_node['edges'].items()), sorted(in_map[src])
                found = False
                # is_conj = src_node['pos'] == 'c'
                # is_noun = src_node['pos'] == 'n'
                should_be_delayed = should_node_be_delayed(src_node)
                # The ARG2 of an _in+order+to_x node should be delayed
                # @@ data_stream.6 @@
                if src_node['label'] == '_in+order+to_x':
                    arg2 = src_node['edges'].get('ARG2')
                    arg1 = src_node['edges'].get('ARG1')

                    # expects at most ARG1 and ARG2 as outgoing edges
                    # (note: assert is stripped under `python -O`)
                    assert(len(src_node['edges']) <= 2)

                    if arg2 and arg1 and \
                       arg2 not in visited and \
                       arg1 not in visited:
                        # Ignore the other outgoing edges
                        stream[arg1] = src
                        stream[arg2] = src
                        # If arg2 is itself an in-order-to node it should
                        # not be delayed. Despite its name, arg2_label is a
                        # bool: True when arg2 is NOT an _in+order+to_x node.
                        arg2_label = nodes[arg2]['label'] != '_in+order+to_x'
                        arg2_should_be_delayed = arg2_label
                        visited.add(arg1)
                        visited.add(arg2)
                        seeds.append(arg1)
                        if arg2_should_be_delayed:
                            # Slightly higher priority than other delayed nodes
                            delayed_seeds.appendleft(arg2)
                        else:
                            seeds.append(arg2)
                        # Incoming edges are still processed below
                        tars = ([], tars[1])
                # Delayed nominal nodes often carry a clause, so find the
                # top of that clause. Currently this applies to nodes with
                # no outgoing edges and to delayed nodes seen again.
                # Precedence: `A or (B and C)`.
                # TODO: a statistical approach
                elif len(tars[0]) == 0 or \
                        should_be_delayed and src in current_delayed_nodes:
                    # Verb-like in-neighbours (pos 'v' or 'u') come first;
                    # both `if` clauses must hold

                    v_in_nodes = [(el, n) for el, n in tars[1]
                                  if nodes[n]['pos'] in ['v', 'u']
                                  if n not in visited]

                    if len(v_in_nodes) == 0:
                        v_in_nodes = [(el, n) for el, n in tars[1]
                                      if n not in visited]

                    if len(v_in_nodes) > 0:
                        # First traverse the topologically highest v nodes
                        selected = select_toplevel_nodes(nodes, v_in_nodes)

                        assert(len(selected) > 0)

                        found = True
                        add_to_seeds(rule_default_as_source(
                            src, nodes, ([], selected),
                            stream, visited))
                        # Delay this node again so its remaining neighbours
                        # are handled in a later round
                        delayed_seeds.append(src)
                # A conjunction node acting as one nominal unit,
                # @see 21162012.gz, 20765040.gz
                elif should_be_delayed and \
                        is_node_connection_point(nodes, in_map, src, visited):
                    found = True
                    delayed_seeds.append(src)
                if not found:
                    # Default expansion of all (remaining) neighbours
                    add_to_seeds(rule_default_as_source(
                        src, nodes, tars, stream, visited))
            # Ordinary frontier exhausted: promote the delayed nodes and
            # remember them so they are recognised on this second pass
            current_delayed_nodes.clear()
            current_delayed_nodes.update(delayed_seeds)
            seeds.extend(delayed_seeds)
            delayed_seeds.clear()
        if len(removed_edges) == 0:
            break
        # Vertices left isolated because of the edge removal: seed them
        # from the source of their removed edge
        for nid, out_nid in removed_edges:
            if nid not in visited:
                visited.add(nid)
                seeds.append(nid)
                stream[nid] = out_nid
        removed_edges.clear()

    for post_rule in post_rules:
        post_rule(nodes, stream)
    return stream
