from stream_rules_base import (
    make_rule,
    conj_c_nodes,
    in_order_to_arg1)


def is_node_can_be_reached_from(nodes, from_nid, to_nid):
    """Return True if ``to_nid`` can be reached from ``from_nid`` via edges.

    Runs an iterative depth-first search. Each ``nodes[nid]['edges']`` is a
    dict mapping an edge label to a target node id; every edge is followed
    regardless of its label. A node is considered reachable from itself.

    :param nodes: mapping from node id to a node dict with an ``'edges'`` dict.
    :param from_nid: id of the start node.
    :param to_nid: id of the node being searched for.
    :returns: ``True`` if a path exists, otherwise ``False`` (the previous
        version fell off the end and returned ``None``).
    """
    if from_nid == to_nid:
        return True
    # Seeding ``visited`` with the start node keeps cycles through it from
    # being re-expanded.
    visited = {from_nid}
    stack = [from_nid]
    while stack:
        src = stack.pop()
        for tar in nodes[src]['edges'].values():
            if tar == to_nid:
                return True
            if tar not in visited:
                visited.add(tar)
                stack.append(tar)
    return False


@make_rule('pre')
def rule_preprocess_init(graph, removed_edges):
    """Rebuild the shared lookup tables and normalise ``graph`` in place.

    Clears ``conj_c_nodes`` and ``in_order_to_arg1`` (imported from
    ``stream_rules_base``) and refills them by scanning ``graph['nodes']``:

    * ``_in+order+to_x`` nodes with both ARG1 and ARG2:
      ``in_order_to_arg1[ARG1] = (node id, ARG2)``.
    * Conjunction nodes (POS ``c`` or label ``implicit_conj``) with both a
      left and a right conjunct (``L-INDEX``/``R-INDEX``, falling back to
      ``L-HNDL``/``R-HNDL``): both conjuncts are mapped to
      ``(left, conjunction node, right)`` in ``conj_c_nodes``. If a conjunct
      belongs to several conjunctions, the last one scanned wins.
    * ``focus_d`` nodes with both ARG1 and ARG2: on the ARG1 node, the first
      ``ARG*`` edge other than ``ARG1`` whose target can reach ARG2 is
      deleted, and ``(target, ARG1 node)`` is appended to ``removed_edges``
      (the edge label is not recorded). If ARG2 is the current
      ``graph['top']``, the top is moved to ARG1.
    * ``parg_d`` nodes: currently a no-op.

    Because ``graph['top']`` can change inside the loop, the result may
    depend on the iteration order of ``graph['nodes']``.

    :param graph: dict with ``'nodes'`` and ``'top'`` keys; modified in place.
    :param removed_edges: list that receives ``(target, source)`` pairs for
        every edge deleted here.
    :returns: the same ``graph`` object.
    """
    nodes = graph['nodes']

    # The tables are module-level state shared with the 'tar' rules, so they
    # must be reset before each graph is processed.
    conj_c_nodes.clear()
    in_order_to_arg1.clear()

    for src in nodes:
        src_node = nodes[src]
        src_label = src_node['label']
        src_edges = src_node['edges']
        if src_label == '_in+order+to_x':
            arg1 = src_edges.get('ARG1')
            arg2 = src_edges.get('ARG2')
            if arg1 and arg2:
                in_order_to_arg1[arg1] = src, arg2
        elif src_node['pos'] == 'c' or src_node['label'] == 'implicit_conj':
            # Prefer the INDEX edges; fall back to the HNDL edges when absent.
            left_node = src_edges.get('L-INDEX')
            right_node = src_edges.get('R-INDEX')
            if not left_node:
                left_node = src_edges.get('L-HNDL')
            if not right_node:
                right_node = src_edges.get('R-HNDL')
            if left_node and right_node:
                conj_c_nodes[left_node] = left_node, src, right_node
                conj_c_nodes[right_node] = left_node, src, right_node
        elif src_label == 'focus_d':
            arg1 = src_edges.get('ARG1')
            arg2 = src_edges.get('ARG2')
            # @@: ignore ARG2 of ARG1 of focus, if v -> focus_d -> v
            # @@: ignore ARG2 of ARG1 of focus, if a_1 -> focus_d -> v
            # (In practice any ARG* edge other than ARG1 is considered, not
            # only ARG2; the first one that reaches ARG2 is removed.)
            if arg1 and arg2:
                arg1_node = nodes[arg1]
                # arg2_node = nodes[arg2]
                edges = arg1_node['edges']
                for elabel, to_nid in edges.items():
                    if not elabel.startswith('ARG') or elabel == 'ARG1':
                        continue
                    if is_node_can_be_reached_from(nodes, to_nid, arg2):
                        # Temporarily remove this edge. Deleting while
                        # iterating is safe only because of the break below.
                        removed_edges.append((to_nid, arg1))
                        del edges[elabel]
                        break

                # @@: if ARG2 of focus is top, make ARG1 of focus top
                if arg2 == graph['top']:
                    graph['top'] = arg1
        elif src_label == 'parg_d':
            # @@: parg_d is similar to focus_d
            # (Disabled: dropping the ARG2 edge of parg_d.)
            # edges = src_node['edges']
            # if 'ARG2' in edges:
            #     del edges['ARG2']
            pass

    return graph


@make_rule('pre')
def rule_preprocess_top(graph, delayed_seeds):
    """Move ``graph['top']`` from wrapper nodes down to the main predicate.

    Two adjustments are applied, in order:

    1. While the top node is labelled ``neg``, ``focus_d`` or
       ``eventuality`` and has an ``ARG1`` edge, the top moves to that
       ``ARG1`` target. A cycle of such nodes stops the descent instead of
       looping forever.
    2. If the resulting top has POS ``a`` or ``p``, or its label contains
       ``subord``, and its ``ARG1`` target has POS ``v`` or ``x``, the top
       moves to that target.

    :param graph: dict with ``'nodes'`` and ``'top'`` keys; ``'top'`` is
        updated in place.
    :param delayed_seeds: not used by this rule. NOTE(review): presumably
        kept for the 'pre' rule signature expected by ``make_rule`` — confirm.
    :returns: the same ``graph`` object.
    """
    nodes = graph['nodes']
    top = graph['top']
    top_node = nodes[top]
    top_label = top_node['label']

    # @@: change top to its only child "neg", "focus_d", "eventuality".
    # ``seen`` guards against an ARG1 cycle among these wrapper nodes
    # (e.g. a self-loop), which would otherwise never terminate.
    seen = {top}
    while top_label in ('neg', 'focus_d', 'eventuality'):
        new_top = top_node['edges'].get('ARG1')
        if new_top is None or new_top in seen:
            break
        seen.add(new_top)
        top = graph['top'] = new_top
        top_node = nodes[top]
        top_label = top_node['label']

    # @@: change top to its ARG1, if it's an 'a', '*subord*' or 'p' node
    # and its ARG1 is a verb
    if top_node['pos'] in ('a', 'p') or 'subord' in top_label:
        new_top = top_node['edges'].get('ARG1')
        if new_top and nodes[new_top]['pos'] in ('v', 'x'):
            graph['top'] = new_top

    return graph


@make_rule('tar',
           # _and_c nodes fall into two kinds:
           # + `n and n' as a subject: the verb node points at the and node
           #   (no extra rule is needed)
           # + `a and a' as adjectives: each conjunct points at the head
           #   (then we need R-INDEX -> and -> L-INDEX)
           lambda _, __, tar, nodes, visited: (
               tar in conj_c_nodes and
               conj_c_nodes[tar][1] not in visited))
def rule_and_index_as_target(src, elabel, tar, nodes, stream):
    """Handle reaching one conjunct of a conjunction recorded in ``conj_c_nodes``.

    Looks up ``(left, conjunction, right)`` for ``tar`` and sets
    ``stream[conjunction] = left`` and ``stream[right] = conjunction``.

    * If ``tar`` is the left conjunct, ``stream[left] = src`` is also set and
      left, right and the conjunction are scheduled.
    * Otherwise the conjunction node's left edge is deleted (``L-INDEX`` if
      present, else ``L-HNDL``), and right and the conjunction are scheduled.

    :returns: list of ``(False, node id)`` pairs.
    """
    left, conj, right = conj_c_nodes[tar]
    stream[conj] = left
    stream[right] = conj

    if tar != left:
        conj_edges = nodes[conj]['edges']
        # Drop only the first left-conjunct edge found, INDEX before HNDL.
        for left_label in ('L-INDEX', 'L-HNDL'):
            if left_label in conj_edges:
                del conj_edges[left_label]
                break
        return [(False, right), (False, conj)]

    stream[left] = src
    return [(False, left), (False, right), (False, conj)]


@make_rule('tar',
           # @@: leaves should be delayed
           lambda _, __, tar, nodes, ___: not nodes[tar]['edges'])
def rule_delayed_leaves(src, elabel, tar, nodes, stream):
    """Attach a leaf node (one with no outgoing edges) and delay it.

    Sets ``stream[tar] = src`` and returns a single ``(True, tar)`` entry.
    """
    stream[tar] = src
    delayed_entry = (True, tar)
    return [delayed_entry]


def focus_as_target_condition(_, __, tar, nodes, ___):
    """Return True when ``tar`` is a ``focus_d`` node with a usable ARG1.

    The node must be labelled ``focus_d`` and have an ``ARG1`` edge whose
    target differs from its ``ARG2`` target. Always returns a real ``bool``;
    the previous version could return ``None`` or the ARG1 node id.

    The first, second and fifth positional arguments are unused. They exist
    to match the condition signature of the other 'tar' rules in this file.
    """
    node = nodes[tar]
    if node['label'] != 'focus_d':
        return False
    edges = node['edges']
    arg1 = edges.get('ARG1')
    # NOTE(review): the original comment said "focus should have two ARGs",
    # but ARG2 is not required to exist here (a missing ARG2 passes because
    # ARG1 != None). Confirm whether a single-ARG focus_d should match.
    return bool(arg1) and arg1 != edges.get('ARG2')


@make_rule('tar',
           # @@: rules for focus_d, ARG2 -> focus_d -> ARG1 or reversed
           focus_as_target_condition)
def rule_focus_as_target(src, elabel, tar, nodes, stream):
    """Handle a ``focus_d`` node ``tar`` reached from ``src``.

    Always sets ``stream[tar] = src`` and schedules ``tar``. Then:

    * if ``src`` is the node's ARG1 and the node has an ARG2, sets
      ``stream[ARG2] = tar`` and schedules ARG2;
    * otherwise, if ``src`` is the node's ARG2 and the node has an ARG1,
      sets ``stream[ARG1] = tar`` and schedules ARG1.

    :returns: list of ``(False, node id)`` pairs. NOTE(review): the flag
        appears to mean "delay this seed" (``rule_delayed_leaves`` returns
        ``True`` for delayed leaves) — confirm against ``make_rule``.
    """
    result = [(False, tar)]
    stream[tar] = src
    edges = nodes[tar]['edges']
    arg1 = edges.get('ARG1')
    arg2 = edges.get('ARG2')

    if arg2 and arg1 == src:
        # Entered focus_d from ARG1: continue to ARG2 through focus_d.
        stream[arg2] = tar
        result.append((False, arg2))
        # Disabled alternative: delay ARG2 instead.
        # result.append((True, arg2))
    elif arg1 and arg2 == src:
        # Entered focus_d from ARG2: continue to ARG1 through focus_d.
        # stream[tar] already equals src from above, so nothing else to set.
        stream[arg1] = tar
        result.append((False, arg1))
    return result


@make_rule('tar',
           # @@: rules for parg_d
           lambda _, __, tar, nodes, ___: nodes[tar]['label'] == 'parg_d')
def rule_parg_as_target(src, elabel, tar, nodes, stream):
    """Handle a ``parg_d`` node ``tar`` reached from ``src``.

    Always sets ``stream[tar] = src`` and schedules ``tar``. If the node has
    both ARG1 and ARG2 and ``src`` is its ARG1:

    * ordinary case (ARG2 not in ``stream`` yet, or ``stream[ARG2]`` is
      ARG1): sets ``stream[ARG2] = tar`` and schedules ARG2 with flag True;
    * clause case (``stream[ARG1]`` is ARG2): sets ``stream[ARG1] = tar``
      and ``stream[tar] = ARG2``, overriding the earlier ``stream[tar]``.

    :returns: list of ``(flag, node id)`` pairs.
    """
    stream[tar] = src
    seeds = [(False, tar)]

    parg_edges = nodes[tar]['edges']
    arg1 = parg_edges.get('ARG1')
    arg2 = parg_edges.get('ARG2')
    # Both cases below require two arguments and entry through ARG1.
    if not (arg1 and arg2) or arg1 != src:
        return seeds

    if arg2 not in stream or stream[arg2] == arg1:
        # Ordinary case
        stream[arg2] = tar
        seeds.append((True, arg2))
    elif stream.get(arg1) == arg2:
        # Clause case
        stream[arg1] = tar
        stream[tar] = arg2

    return seeds


@make_rule('post')
def rule_postpreocess(nodes, stream):
    """Re-attach both arguments of each ``_in+order+to_x`` node to that node.

    A node qualifies when ``stream[ARG1]`` is set (truthy), is not the node
    itself, and equals ``stream[ARG2]``, and the node's own ``stream`` entry
    is neither ARG1 nor ARG2. Then ``stream[ARG1] = stream[ARG2] = node id``.

    Mutates ``stream`` in place and returns None.
    """
    for nid, node in nodes.items():
        if node['label'] != '_in+order+to_x':
            continue
        edges = node['edges']
        arg1 = edges.get('ARG1')
        arg2 = edges.get('ARG2')

        shared = stream.get(arg1)
        if not shared or shared == nid:
            continue
        if shared != stream.get(arg2):
            continue
        own = stream.get(nid)
        if own == arg1 or own == arg2:
            continue

        stream[arg1] = stream[arg2] = nid
