import logging
from collections import deque
from copy import deepcopy

from reader_fix import is_prefix_node, is_normal_node, preprocess_edges
from extract_labels import spans_labels
from mark_edge_with_label import stream_edge_spans
from rule import (make_equations, make_rule,
                  print_equations, print_rule,
                  equations_lemma_pos)

# Label prefix of "empty" edge states.  extract_rules() uses it as the default
# edge label, and setdefault_states() routes every label starting with it into
# the reserved slots at the front of the state table.
EMPTY_STATE = 'EMPTY'
# Number of slots reserved at the front of the state table by init_states().
# A state whose index is below this value is treated as empty
# (see is_empty_state()).
EMPTY_STATE_COUNT = 16


def is_empty_state(state):
    """Tell whether ``state`` occupies one of the reserved EMPTY slots.

    ``state`` is an indexable whose first element is the state's index in
    the state table; indices below ``EMPTY_STATE_COUNT`` are the reserved
    ones.
    """
    state_index = state[0]
    return state_index < EMPTY_STATE_COUNT


def init_states():
    """Create a fresh state table and its (initially empty) reverse index.

    The table starts with ``EMPTY_STATE_COUNT`` placeholder names
    (``'#0'``, ``'#1'``, ...) that mark the slots reserved for EMPTY states.

    Returns:
        A ``(states, state2index)`` pair: the list of placeholder names and
        an empty dict that will map state names to their list positions.
    """
    placeholders = []
    for slot in range(EMPTY_STATE_COUNT):
        placeholders.append('#%d' % slot)
    return placeholders, {}


def setdefault_states(state, states, state2index):
    """Return the index of ``state``, registering it first if it is new.

    Labels that start with ``EMPTY_STATE`` are stored in the first free
    reserved slot (a slot still holding a ``'#...'`` placeholder) among the
    first ``EMPTY_STATE_COUNT`` entries of ``states``.  Every other label is
    appended to the end of ``states``.

    Args:
        state: The state label (a string).
        states: The state table; mutated in place.
        state2index: Reverse mapping from label to index; mutated in place.

    Returns:
        The integer index of ``state`` inside ``states``.

    Raises:
        RuntimeError: If ``state`` is a new EMPTY label but every reserved
            slot is already taken.  Previously this case silently returned
            ``None``, which only failed later when the index was compared
            or sorted.
    """
    known_index = state2index.get(state)
    if known_index is not None:
        return known_index

    if not state.startswith(EMPTY_STATE):
        # Ordinary state: grows the table at the end.
        new_index = len(states)
        states.append(state)
        state2index[state] = new_index
        return new_index

    # EMPTY state: claim the first reserved slot still holding a placeholder.
    for slot, occupant in enumerate(states[:EMPTY_STATE_COUNT]):
        if occupant.startswith('#'):
            states[slot] = state
            state2index[state] = slot
            return slot

    raise RuntimeError(
        'all %d reserved EMPTY state slots are in use; cannot register %r '
        '(increase EMPTY_STATE_COUNT)' % (EMPTY_STATE_COUNT, state))


def extract_rules(fn, uttr, graph, stream,
                  cfg_tree, cfg_leaves, states, state2index,
                  level=0):
    """Generate one rule (plus its equations) for every reachable graph node.

    Nodes are handled in dependency order: a node is processed only after all
    of its incoming edges have been assigned a state, starting from the nodes
    that have no incoming edges.  A node whose incoming-edge counter never
    drops to zero is never yielded.

    Args:
        fn: Identifier of the current item; forwarded to stream_edge_spans().
        uttr: The utterance; forwarded to stream_edge_spans() and
            make_equations().
        graph: Dict with a ``'nodes'`` mapping (node id -> node dict) and a
            ``'top'`` node id.  Each node dict is read for ``'edges'``
            (edge label -> target node id), ``'label'``, ``'sense'``,
            ``'pos'`` and ``'span'``.  preprocess_edges() is called on every
            node's ``'edges'`` dict of this graph.
        stream: Mapping queried with ``stream.get(nid)``; the result is the
            neighbour flagged as the "output" side of a node's edges.
        cfg_tree: Forwarded to stream_edge_spans() and spans_labels().
        cfg_leaves: Forwarded to stream_edge_spans().
        states: Shared state table, extended through setdefault_states().
        state2index: Reverse index of ``states``, extended the same way.
        level: Controls label coarsening through ``level // 3``:
            ``>= 1`` turns edges without variables into EMPTY edges,
            ``== 2`` cuts span-derived labels at their first ``'-'``,
            ``== 3`` replaces every label starting with an uppercase letter
            (other than EMPTY) by ``'X'``.

    Yields:
        ``(rule, equations, (nid, node))`` where ``rule`` is the result of
        make_rule() for the node, ``equations`` is a tuple (empty when no
        equations are built) and ``node`` is the node dict of ``nid``.
    """
    nodes = graph['nodes']

    def add_to_states(state):
        # Register the label in the shared state table and return its index.
        return setdefault_states(state, states, state2index)

    def visit_node(out_nid):
        # One more incoming edge of out_nid has been assigned a state; once
        # none are left, the node is queued for processing.
        in_count[out_nid] -= 1
        if in_count[out_nid] == 0:
            seeds.append(out_nid)

    # Spans attached to each edge, made addressable in both directions.
    # The helper receives a deep copy of the nodes.
    edge_spans = {}
    for (f, t), spans in stream_edge_spans(fn, uttr, deepcopy(nodes), stream,
                                           cfg_tree, cfg_leaves).items():
        edge_spans[f, t] = edge_spans[t, f] = spans

    # (source id, target id) -> (state index, variable count), filled in as
    # the source node of each edge is processed.
    edge_states = {}

    # Adjacency lists and incoming-edge counters for every node.
    in_map = {x: [] for x in nodes}
    out_map = {x: [] for x in nodes}
    in_count = {x: 0 for x in nodes}
    for src in nodes:
        src_node = nodes[src]
        edges = src_node['edges']
        preprocess_edges(edges)
        for elabel, tar in edges.items():
            in_count[tar] += 1
            in_map[tar].append((elabel, src))
            out_map[src].append((elabel, tar))

    # Work queue, seeded with the nodes that have no incoming edges.
    seeds = deque(x for x in in_count if in_count[x] == 0)

    while len(seeds) != 0:
        nid = seeds.popleft()
        output_nid = stream.get(nid)
        node = nodes[nid]
        node_label = node['label']
        sense = node['sense']
        in_edges = []
        out_edges = []

        # Sum of the variable counts over all in- and out-edges of the node.
        total_var_count = 0

        after_lemma_var = None
        for _, in_nid in in_map[nid]:
            # In-edges marked as EMPTY are not added to in_edges (their
            # variable count is still added to the total).
            state = edge_states[in_nid, nid]
            total_var_count += state[1]
            if state[0] < EMPTY_STATE_COUNT:
                continue
            in_edges.append((state,
                             edge_spans.get((in_nid, nid)),
                             in_nid == output_nid))
            in_node = nodes[in_nid]
            # A 'neg' predecessor covering the same span as a node whose
            # sense is 'modal' sets the '<NEG>' marker handed to
            # make_equations().
            if in_node['label'] == 'neg' and \
               sense == 'modal' and \
               in_node['span'] == node['span']:
                after_lemma_var = '<NEG>'

        pos = node['pos']
        # Labels that do not start with '_' are treated as special nodes.
        is_special = not node_label.startswith('_')

        # The node label is simplified as well: non-special nodes are reduced
        # to '_<pos>' (or '_#' when pos is falsy), plus '_<sense>' if present.
        if not is_special:
            new_label = '_' + (node['pos'] or '#')
            if sense:
                new_label += '_' + sense
        else:
            new_label = node_label

        # TODO: special handling
        # ...

        # Default
        for elabel, out_nid in out_map[nid]:
            spans = edge_spans.get((nid, out_nid))
            label, var_count = EMPTY_STATE, 0  # (1, 0)
            if spans is not None:
                var_count = len(spans)

            # Label derived from the node
            if is_special:
                # udef_q, proper_q
                if pos == 'q':
                    label = 'DET'
                else:
                    label = new_label
            elif is_prefix_node(node_label):
                # Prefix node
                label = 'prefix'
            elif spans is not None and len(spans) > 0:
                # NP_VP ...
                labels = spans_labels(cfg_tree, spans, True)
                if level // 3 == 2:
                    def merge_label(label):
                        # Keep only the part of the label before the first
                        # '-'.
                        pos = label.find('-')
                        if pos != -1:
                            label = label[:pos]
                        return label

                    label = '_'.join(merge_label(l) for l in labels)
                else:
                    label = '_'.join(labels)
            elif spans is not None:
                # Special case: the edge has a span list but it is empty.
                # NOTE(review): the original comment said the label is set
                # to EMPTY_LABEL here, but the code assigns 'NP' -- confirm
                # which one is intended.
                logger = logging.getLogger('extract_rules')
                logger.debug('span-is-empty:%s', nodes[out_nid]['label'])
                label = 'NP'

            # From level 3 upwards, edges without variables become EMPTY.
            if (level // 3 >= 1) and var_count == 0:
                label = EMPTY_STATE

            # At level // 3 == 3 every label starting with an uppercase
            # letter, except EMPTY, collapses to 'X'.
            if (level // 3 == 3) and \
                    label[0].isupper() and \
                    label != EMPTY_STATE:
                label = 'X'

            # The state name is the label combined with the edge label.
            label += ':' + elabel

            index = add_to_states(label)
            state = (index, var_count)
            total_var_count += var_count

            out_edges.append((state, spans, output_nid == out_nid))
            edge_states[nid, out_nid] = state

            visit_node(out_nid)
        # Order both edge lists by their (state index, variable count) pair.
        in_edges.sort(key=lambda x: x[0])
        out_edges.sort(key=lambda x: x[0])

        # No equations for a node without any listed edges, or for a non-top
        # node whose edges carry no variables at all.
        if not in_edges and not out_edges or \
           (total_var_count == 0 and graph['top'] != nid):
            equations = ()
        else:
            equations = tuple(make_equations(in_edges, out_edges,
                                             uttr, node, after_lemma_var))

        # make matching hard
        # if node.get('sense-used') and not is_special:
        #     new_label = pos + '_S'

        # The rule itself only sees the states of the edges.
        yield (make_rule(new_label,
                         map(lambda x: x[0], in_edges),
                         map(lambda x: x[0], out_edges)),
               tuple(equations),
               (nid, node))


def should_node_have_lemma_default(node):
    """Default predicate deciding whether ``node`` is expected to have a lemma.

    Nodes labelled ``'_and_c'`` or ``'_per_p'`` are never expected to have
    one; for every other node the answer is whatever is_normal_node() says.
    """
    if node['label'] in ('_and_c', '_per_p'):
        return False
    return is_normal_node(node)


def extract_all_rules(data, streams,
                      should_node_have_lemma=should_node_have_lemma_default,
                      level=0, max_rules=40000):
    """Collect the rules of every item in ``data`` into one table.

    Args:
        data: Mapping from an item name to a
            ``(uttr, cfg_tree, cfg_leaves, graph)`` tuple.
        streams: Mapping from an item name to its stream.  Items without a
            stream, or whose stream length is not exactly one less than the
            number of graph nodes, are skipped.
        should_node_have_lemma: Predicate on a node dict; when it is true but
            equations_lemma_pos() finds nothing in the node's equations, a
            warning is logged.
        level: Forwarded unchanged to extract_rules().
        max_rules: Stop processing further items once more than this many
            distinct rules have been collected.  ``None`` disables the limit.
            Defaults to the previously hard-coded 40000.

    Returns:
        ``(rules, states)``.  ``rules`` maps each rule to a dict from an
        equations tuple to ``(fn, count)``, where ``count`` is the number of
        times the pair was seen and ``fn`` is the most recent item that
        produced it.  ``states`` is the state table built along the way.
    """
    logger = logging.getLogger('extract_rules')
    rules = {}
    states, state2index = init_states()
    for fn in data:
        uttr, cfg_tree, cfg_leaves, graph = data[fn]
        nodes = graph['nodes']
        stream = streams.get(fn)

        if not stream or len(stream) + 1 != len(nodes):
            # Leave a trace of dropped items instead of skipping silently.
            logger.debug('%s:skipped: stream missing or size mismatch', fn)
            continue

        for rule, eqs, (nid, node) in extract_rules(fn, uttr, graph, stream,
                                                    cfg_tree, cfg_leaves,
                                                    states, state2index,
                                                    level):
            if should_node_have_lemma(node) and \
               not equations_lemma_pos(eqs):
                span = node['span']
                logger.warning('%s:no-lemma: %s ("%s")',
                               fn, node['label'], uttr[span[0]:span[1]])

            rule_dict = rules.setdefault(rule, {})
            previous = rule_dict.get(eqs)
            seen = previous[1] + 1 if previous else 1
            rule_dict[eqs] = fn, seen

        # The limit is checked once per item, after all of its rules have
        # been recorded.
        if max_rules is not None and len(rules) > max_rules:
            # Use the same named logger as the rest of this module rather
            # than the root logger.
            logger.warning('Too many rules')
            break
    return rules, states


if __name__ == '__main__':
    # Script entry point: extract the rules from the training data and dump
    # them to ../data/tagged-rules.txt.  The optional first command-line
    # argument is the integer extraction level (default 3).
    import sys
    from data_stream import load_stream
    from reader import load_data

    logging.basicConfig(level=logging.INFO)

    logger = logging.getLogger('extract_rules')
    level = 3
    if len(sys.argv) > 1:
        try:
            level = int(sys.argv[1].strip())
        except ValueError:
            # Keep the default, but tell the user the argument was ignored
            # instead of swallowing the problem silently.
            logger.warning('invalid level %r, using default %s',
                           sys.argv[1], level)
    logger.info('level = %s', level)

    train_data = load_data()[0]
    streams = load_stream(train_data)

    rules, states = extract_all_rules(train_data, streams, level=level)

    print('rules:', len(rules))

    with open('../data/tagged-rules.txt', 'w') as out:
        for rule, rule_dict in rules.items():
            print_rule(rule, states, out)
            for eqs, (fn, count) in rule_dict.items():
                print('Count:', count, 'File:', fn, file=out)
                print_equations(eqs, out)
            out.write('\n')
