import logging
import os
import re
from collections import deque
from itertools import chain

from data_stream import load_stream
from reader import load_data
from reader_fix import is_normal_node, regex_not_word
from span import spans_add, spans_flip, is_span_overlap_span

# 1. 利用 CFG derivation tree 和 EDS 对齐
#
# 2. 之前得到的 data-stream 形成一棵倒过来 spanning tree. 每条边上的关
#    联一个 span-list. 按照一定规则, 自底向上传播合并结点 span-list
#
#    - 每个结点的 data-stream 的出边 span-list 是所有入边的合并
#
#    - 当是普通结点 (`_' 开头) 或在 tree 里有对应的结点, 自己带的 span
#      也对出边的 span-list 有贡献
#
#    - 需要添加必要的 missing-spans (指的是 tree 里有 EDS 里没有的)
#
# 3. 此时每条边的 span-list 的大小应该和导出的规则中边带的变量数一致
#    (变量解出来之后就应该是这个 span), 变量合并的顺序按自然顺序
#

# Sentinel used in place of a neighbor span for a missing 'than' token
THAN_SPAN = -1
# When True, some abnormal situations raise an error
STRICT = False

# Labels of non-normal EDS nodes that are never aligned to a CFG leaf
special_labels_should_not_be_aligned = [
    'udef_q',
    'proper_q',
    'pronoun_q',
    'def_explicit_q',
    'def_implicit_q',
    'number_q',
    'part_of',
    'generic_entity',
]


def get_aligned_nodes(nodes, cfg_leaves):
    '''
    Find the EDS nodes that have a counterpart among the CFG tree leaves.

    nodes maps a node id to a node dict (reads 'span', 'label', 'edges');
    cfg_leaves is a sequence of leaves exposing `beg` and `end`.

    Returns a pair (aligned_nodes, span_set):
      - aligned_nodes maps node id -> index of the first leaf enclosing the
        node span (-1 for a normal node that no leaf encloses)
      - span_set maps (beg, end) -> (node id, edge count, leaf index); the
        edge count is -1 for normal nodes
    '''
    def enclosing_leaf(beg, end):
        # Index of the first leaf whose range encloses [beg, end], else -1
        for position, leaf in enumerate(cfg_leaves):
            if beg >= leaf.beg and end <= leaf.end:
                return position
        return -1

    span_set = {}
    aligned_nodes = {}

    # Pass 1: every normal node is aligned, even when no leaf encloses it
    for nid in nodes:
        node = nodes[nid]
        beg, end = node['span']
        if not is_normal_node(node):
            continue
        leaf_index = enclosing_leaf(beg, end)
        aligned_nodes[nid] = leaf_index
        span_set[beg, end] = nid, -1, leaf_index

    # Pass 2: remaining nodes, except those with a blacklisted label
    for nid in nodes:
        node = nodes[nid]
        beg, end = node['span']
        if is_normal_node(node):
            continue
        if node['label'] in special_labels_should_not_be_aligned:
            continue

        key = (beg, end)
        # Skip a span that overlaps one already recorded
        if any(is_span_overlap_span(known, key) for known in span_set):
            continue

        record = span_set.get(key)
        num = len(node['edges'])
        if record is None:
            # New span: record it only when some leaf encloses it
            leaf_index = enclosing_leaf(beg, end)
            if leaf_index != -1:
                span_set[key] = nid, num, leaf_index
        else:
            _, old_num, old_index = record
            # A node with no more edges than the recorded one replaces it
            if num <= old_num:
                span_set[key] = nid, num, old_index

    for nid, _, leaf_index in span_set.values():
        aligned_nodes[nid] = leaf_index
    return aligned_nodes, span_set


def get_missing_spans(uttr, nodes,
                      aligned_nodes, special_nodes, span_set,
                      cfg_tree, cfg_leaves):
    '''
    Collect the regions of `uttr` not covered by any aligned node, each paired
    with a hint about the neighboring span it may be merged with.

    Side effect: an aligned node's 'span' in `nodes` may be extended in place
    to the end of the following CFG leaf (punctuation / preposition merge).

    Returns a list of (span, neighbor_span) where neighbor_span is None,
    THAN_SPAN (when the uncovered text is exactly 'than'), or a (beg, end)
    tuple derived from the CFG tree.

    NOTE(review): first_cross() presumably returns (the ancestor at the first
    branching point, the child of that ancestor on the path) -- confirm
    against the tree implementation.
    '''
    logger = logging.getLogger('mark_edge_with_label')
    # Collect the spans of all aligned nodes to obtain the covered region
    covered_spans = []
    part_of_nodes = special_nodes['part_of']
    for nid, index in aligned_nodes.items():
        node = nodes[nid]
        span = node['span']

        # Only nodes aligned to a leaf that has a following leaf
        if index < len(cfg_leaves) - 1 and index >= 0:
            leaf = cfg_leaves[index]
            next_leaf = cfg_leaves[index + 1]
            # Punctuation is merged directly
            pred = next_leaf.parent and next_leaf.parent.value == 'pnct'
            pred = pred and \
                next_leaf.first_cross()[0] == leaf.first_cross()[0]
            # Merge a preposition (sometimes the preposition carries
            # punctuation)
            sense = node['sense']
            beg, end = next_leaf.beg, next_leaf.end
            # Case must be preserved; strip some special cases
            value = re.sub(regex_not_word, '', uttr[beg:end])
            # Make sure the preposition is one that is being filled in:
            # the next leaf span is not recorded in span_set and its text
            # occurs in this node's sense
            pred = pred or (beg, end) not in span_set and \
                sense and value != '' and sense.find(value) != -1
            if pred:
                # Extend the node span to the end of the next leaf
                span = span[0], end
                node['span'] = span
        covered_spans = spans_add(covered_spans, span)

    # Uncovered regions (need to be generated by special rules)
    missing_spans = []

    for span in spans_flip(covered_spans, 0, len(uttr)):
        ns = cfg_tree.get_span_cfg_nodes(span[0], span[1])
        neighbor_span = None
        string = uttr[span[0]:span[1]]
        # 'than' should be completed at the comp node; handled later
        if len(ns) == 0 or string == '':
            # logger.warning('%s span aligned to no cfg leaves', span)
            pass
        elif string == 'than':
            neighbor_span = THAN_SPAN
        elif len(ns) == 1:  # Fix
            # A span such as "'s" needs to be updated from the cfg node
            if span[1] - span[0] < ns[0].end - ns[0].beg:
                span = ns[0].beg, ns[0].end

            # Handle a preposition that is separated from its verb
            main_beg = None
            value = re.sub(regex_not_word, '', uttr[span[0]:span[1]])
            for nid in chain(aligned_nodes, part_of_nodes):
                node = nodes[nid]
                sense = node['sense']
                # The missing word was found in the sense of a node that
                # ends before this span
                if node['span'][1] < span[0] and \
                   sense and value != '' and sense.find(value) != -1:
                    main_beg = node['span'][0]
                    break

            p, n = ns[0].first_cross()
            if p:
                p2, n2 = p.first_cross()
                # A second branching point exists above and the path
                # direction is unchanged
                # Binary branching is assumed
                if p2 and n.index == 0 and n2.index == 0:
                    #    / \
                    #  / \  y2
                    # x   y1
                    # x should be merged with y1
                    neighbor_span = (n.end + 1, p.end)
                elif p2 and n.index == 1 and n2.index == 1:
                    #   / \
                    # y1  / \
                    #   y2   x
                    # x should be merged with y2
                    neighbor_span = (p.beg, n.beg - 1)
                elif main_beg is not None:
                    if n.index == 1 and p.beg >= main_beg:
                        # In this case the v_* node is in the left subtree
                        neighbor_span = (p.beg, n.beg - 1)
                elif p2 and n.index == 0 and n2.index == 1:
                    #   / \
                    # y1  / \
                    #    x   y2
                    # x should not be merged with y2
                    pass
                elif n.index == 1:
                    #    / \
                    #  / \  y2    / \
                    # y1  x      y1  x
                    # x should be merged with y1
                    neighbor_span = (p.beg, n.beg - 1)
                elif n.index == 0:
                    # Only one level
                    neighbor_span = (n.end + 1, p.end)
        else:
            # Several cfg nodes cover the span: no neighbor hint is derived
            logger.debug('multi-missing-spans:%s %s',
                         span, uttr[span[0]:span[1]])
        missing_spans.append((span, neighbor_span))

    logger.debug('missing_spans %s', missing_spans)
    return missing_spans


def stream_edge_spans(fn, uttr, nodes, stream, cfg_tree, cfg_leaves):
    '''
    Merge spans along the data stream and label every stream edge with its
    span list.

    Parameters:
      - fn: identifier used only in log / error messages
      - uttr: the sentence text, forwarded to get_missing_spans
      - nodes: mapping node id -> node dict (reads 'label' and 'span')
      - stream: mapping node id -> id of the node its stream edge points to
      - cfg_tree, cfg_leaves: forwarded to the alignment helpers

    Returns a dict mapping (from_nid, to_nid) -> list of spans for that edge.

    Raises Exception when STRICT is set and an edge ends up with more than
    three spans.
    '''
    logger = logging.getLogger('mark_edge_with_label')
    # Incoming stream edges of every node, and how many of them are still
    # unprocessed
    in_edges = {nid: set() for nid in nodes}
    counter = {nid: 0 for nid in nodes}

    special_nodes = {'part_of': []}

    for nid in nodes:
        to_nid = stream.get(nid)
        if to_nid:
            in_edges[to_nid].add(nid)
            counter[to_nid] += 1

        # Remember the nodes that carry a special label
        node_label = nodes[nid]['label']
        if node_label in special_nodes:
            special_nodes[node_label].append(nid)

    edge_spans = {}
    aligned_nodes, span_set = get_aligned_nodes(nodes, cfg_leaves)
    missing_spans = get_missing_spans(uttr, nodes,
                                      aligned_nodes, special_nodes, span_set,
                                      cfg_tree, cfg_leaves)
    logger.debug('aligned_nodes:%s',
                 [nodes[n]['label'] for n in aligned_nodes])
    # Start from every node without incoming stream edges; a node is queued
    # once all of its incoming edges have been labelled
    roots = deque(nid for nid in nodes if len(in_edges[nid]) == 0)
    while len(roots) != 0:
        root = roots.popleft()
        next_node = stream.get(root)
        # Reached the final node of the data stream
        if not next_node:
            continue

        spans = []
        root_label = nodes[root]['label']
        # An aligned node contributes its own span
        if root in aligned_nodes:
            spans.append(nodes[root]['span'])
        # Merge the spans of all incoming stream edges
        for in_node in in_edges[root]:
            prev_spans = edge_spans.get((in_node, root))
            if prev_spans is not None:
                for prev_span in prev_spans:
                    spans = spans_add(spans, prev_span)
        # Add missing spans
        if len(spans) > 0:
            for missing_span, neighbor_span in missing_spans:
                # The missing span lies inside the current extent
                pred = missing_span[0] >= spans[0][0] and \
                    missing_span[1] <= spans[-1][1]
                # Its neighbor span lies inside the current extent, so it
                # may be added at the boundary as well
                pred = pred or (
                    isinstance(neighbor_span, tuple) and
                    neighbor_span[0] >= spans[0][0] and
                    neighbor_span[1] <= spans[-1][1])
                # A comp node may try to add 'than'.  The label test is
                # parenthesised so it applies only to THAN_SPAN entries;
                # without the parentheses `and` binds tighter than `or` and
                # every comp_* node would try to add every missing span.
                pred = pred or (
                    neighbor_span == THAN_SPAN and
                    (root_label == 'comp' or
                     root_label.startswith('comp_')))
                if pred:
                    new_spans = spans_add(spans, missing_span)
                    # Accept only if the number of spans does not grow
                    if len(new_spans) <= len(spans):
                        spans = new_spans

        if STRICT and len(spans) > 3:
            raise Exception('%s:large-spans-count:%d' % (fn, len(spans)))
        elif len(spans) > 2:
            logger.warning('%s:large-spans-count:%s %s',
                           fn, root_label, len(spans))

        edge_spans[root, next_node] = spans
        counter[next_node] -= 1
        # All predecessors of the target have been visited
        if counter[next_node] == 0:
            roots.append(next_node)

    return edge_spans


def draw_all_graphs(output_dir='data/rule-test/'):
    '''
    Render every training graph, with its stream-edge spans, to an SVG file
    under `output_dir` (created if it does not exist).

    Each file is named '<node count>-<data key>.svg'.
    '''
    os.makedirs(output_dir, exist_ok=True)

    logger = logging.getLogger('mark_edge_with_label')
    train_data = load_data()[0]
    streams = load_stream(train_data)

    for name in train_data:
        uttr, cfg_tree, cfg_leaves, graph = train_data[name]
        nodes = graph['nodes']
        stream = streams[name]

        svg_name = ''.join([str(len(nodes)), '-', name, '.svg'])
        out_path = os.path.join(output_dir, svg_name)

        # A stream with n edges is expected to span n + 1 nodes
        if len(nodes) != len(stream) + 1:
            logger.warning('%s:not-connected', out_path)

            # Keep only the nodes that take part in the stream
            nids = set(stream.keys())
            nids.update(stream.values())
            nodes = {n: nodes[n] for n in nids}

        edge_spans = stream_edge_spans(out_path, uttr, nodes, stream,
                                       cfg_tree, cfg_leaves)
        # For every edge, the cfg nodes matching each of its spans
        edge_cfg_nodes = {}
        for edge, spans in edge_spans.items():
            edge_cfg_nodes[edge] = [
                cfg_tree.get_span_cfg_nodes(span[0], span[1])
                for span in spans
            ]

        with open(out_path, 'wb') as out:
            # NOTE(review): draw_complex_graph is neither imported nor
            # defined in the visible part of this file -- confirm where it
            # is supposed to come from.
            drawing = draw_complex_graph(uttr, graph, stream,
                                         edge_spans,
                                         edge_cfg_nodes)
            out.write(drawing.create(format='svg'))


# Script entry point: enable INFO-level logging, then draw every graph into
# the default output directory.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    draw_all_graphs()
