import logging

from reader_fix import preprocess_edges
from extract_rules import extract_rules, init_states
from rule import print_equations, print_rule


def compile_equations(eqs, literals, literal2index):
    """Encode symbolic equations as nested tuples of integers.

    Every term of every equation in ``eqs`` is translated as follows:

    * a ``str`` is looked up in ``literal2index``; if it is new it is
      appended to ``literals`` and registered under the next free index.
      The term is then encoded as ``-index - 32``;
    * a ``tuple`` is packed from its first two items as ``(a << 4) | b``;
    * any other value is passed through unchanged.

    ``literals`` and ``literal2index`` are updated in place.

    Returns a tuple holding one tuple of encoded terms per equation.
    """

    def encode(term):
        if isinstance(term, str):
            slot = literal2index.get(term)
            if slot is None:
                # First time this literal is seen: give it the next index.
                slot = len(literals)
                literal2index[term] = slot
                literals.append(term)
            return -slot - 32
        if isinstance(term, tuple):
            high, low = term[0], term[1]
            return (high << 4) | low
        return term

    return tuple(tuple(encode(term) for term in eq) for eq in eqs)


def compile_states(states):
    """Pack each state pair ``(a, b)`` into the integer ``(a << 4) | b``.

    Returns the packed values as a tuple, in the order of ``states``.
    """
    return tuple((pair[0] << 4) | pair[1] for pair in states)


def print_compiled_rules(rule, eqs, out):
    """Write one compiled rule and its equations to the text stream ``out``.

    ``rule`` is a ``(node_label, in_states, out_states)`` triple.  The output
    is space-separated and laid out as:

    * header line: label, number of in-states followed by the in-states,
      number of out-states followed by the out-states;
    * a line with the number of equations;
    * one line per equation: its length followed by its terms.
    """
    label, incoming, outgoing = rule
    header = [label, len(incoming), *incoming, len(outgoing), *outgoing]
    print(*header, file=out)
    print(len(eqs), file=out)
    for equation in eqs:
        print(len(equation), *equation, file=out)


def process_sentences(data):
    """Copy each entry's utterance onto its graph as ``graph['sentence']``.

    Every value of ``data`` must unpack into exactly four items; the first
    is the utterance and the last is the graph mapping that gets updated in
    place.  The two middle items are not used here.
    """
    for entry in data.values():
        uttr, _, _, graph = entry
        graph['sentence'] = uttr


def write_graph_with_rules(graph, out):
    """Write one graph, including each node's rule index, to ``out``.

    Output layout:

    * the graph's ``'sentence'`` (an empty line if the key is absent);
    * the number of nodes;
    * one line per node, with ``graph['top']`` first (index 0) and the other
      nodes following in the iteration order of ``graph['nodes']``;
    * a terminating blank line.

    A node line holds, space-separated: label, lemma, the code point of
    ``pos``, sense, carg, the node's properties, the edge count, a
    ``label target_index`` pair per edge, and the node's ``'rule'`` value
    (``-1`` if the node has none).  Missing or falsy pos/sense/carg are
    written as ``'#'`` (``pos`` as ``ord('#')``), and missing properties as
    five ``'#'`` fields.
    """
    top = graph['top']
    nodes = graph['nodes']

    # Number the nodes: the top node is 0, all others follow in iteration
    # order.  setdefault leaves already-numbered ids (i.e. the top) alone.
    index_of = {top: 0}
    for nid in nodes:
        index_of.setdefault(nid, len(index_of))
    ordered_ids = list(index_of)  # insertion order == index order

    print(graph.get('sentence', ''), file=out)
    print(len(ordered_ids), file=out)

    for nid in ordered_ids:
        node = nodes[nid]
        label = node['label']
        lemma = '#'
        if label[0] == '_':
            # "_<lemma>_<rest...>" is split into lemma and "_<rest...>".
            _, lemma, *rest = label.split('_')
            label = '_' + '_'.join(rest)

        head = [
            label,
            lemma,
            ord(node.get('pos') or '#'),
            node.get('sense') or '#',
            node.get('carg') or '#',
        ]
        head.extend(node.get('properties', ['#'] * 5))
        print(*head, end=' ', file=out)

        edges = node['edges']
        # NOTE(review): preprocess_edges comes from reader_fix; it is called
        # for its effect on `edges` before they are counted and written.
        preprocess_edges(edges)

        tail = [len(edges)]
        for elabel, target in edges.items():
            tail += (elabel, index_of[target])
        tail.append(node.get('rule', -1))
        print(*tail, file=out)
    print(file=out)


def extract_compiled_rules(data, streams,
                           states=None,
                           state2index=None,
                           literals=None,
                           literal2index=None,
                           level=0):
    """Extract rules for every graph in ``data`` and compile them.

    For each entry ``data[fn] = (uttr, cfg_tree, cfg_leaves, graph)`` whose
    stream ``streams[fn]`` is non-empty and has exactly one element fewer
    than the graph has nodes, ``extract_rules`` is run.  Every distinct
    ``(rule, eqs)`` pair it yields is compiled once and appended to the
    rule list, and the yielded node gets ``node['rule']`` set to that
    pair's index.  Entries with an unusable stream are skipped with a
    warning; an exception raised while processing one entry is logged and
    the remaining entries are still processed.

    Parameters:
        data: mapping from name to ``(uttr, cfg_tree, cfg_leaves, graph)``.
        streams: mapping from name to the stream for that graph.
        states, state2index: state table and its index; both are taken from
            ``init_states()`` when ``states`` is None.  When ``states`` is
            given, ``state2index`` is passed to ``extract_rules`` as is.
        literals, literal2index: literal table and its index, extended in
            place while compiling.  Fresh ones are created when ``literals``
            is None; when only ``literals`` is given the index is derived
            from it.
        level: forwarded unchanged to ``extract_rules``.

    Returns:
        ``(rule_list, states, literals)`` where each ``rule_list`` item is
        ``((rule, eqs), (compiled_rule, compiled_eqs))``.
    """
    logger = logging.getLogger('extract_compiled_rules')
    rules = {}
    rule_list = []
    # 'EMPTY' stands for "ignore"
    if states is None:
        states, state2index = init_states()
    if literals is None:
        literals = []
        literal2index = {}
    elif literal2index is None:
        # A literal table without its index used to make every compile call
        # fail on ``None.get``; rebuild the index (first occurrence wins,
        # matching how compile_equations assigns indices).
        literal2index = {}
        for index, literal in enumerate(literals):
            literal2index.setdefault(literal, index)

    for fn in data:
        uttr, cfg_tree, cfg_leaves, graph = data[fn]
        nodes = graph['nodes']
        stream = streams.get(fn)

        if not stream or len(stream) + 1 != len(nodes):
            logger.warning('%s:invalid-stream', fn)
            continue

        rule_iter = extract_rules(fn, uttr, graph, stream,
                                  cfg_tree, cfg_leaves,
                                  states, state2index,
                                  level)
        try:
            for rule, eqs, (_, node) in rule_iter:
                eqs2index = rules.setdefault(rule, {})
                rule_index = eqs2index.get(eqs)
                if rule_index is None:
                    # Compile only unseen (rule, eqs) pairs: a repeated pair
                    # would compile to the same result and register no new
                    # literals, so recompiling it is wasted work.
                    node_label, in_states, out_states = rule
                    new_rule = (node_label,
                                compile_states(in_states),
                                compile_states(out_states))
                    new_eqs = compile_equations(eqs, literals, literal2index)

                    rule_index = len(rule_list)
                    eqs2index[eqs] = rule_index
                    rule_list.append(((rule, eqs),
                                      (new_rule, new_eqs)))

                node['rule'] = rule_index
        except Exception as err:
            # Best effort: a failure in one graph must not abort the run.
            logger.error('%s:%s', fn, err)

    logger.info('%d rules in total', len(rule_list))

    return rule_list, states, literals


def write_graphs(filename, data,
                 with_count=True, select=None, single=False, sort=False):
    """Write the graphs of ``data`` to the text file ``filename``.

    Parameters:
        filename: path of the file to create or overwrite.
        data: mapping from name to a sequence whose item 3 is the graph.
        with_count: when true, the first line holds the number of graphs
            written; otherwise the first line is empty.
        select: optional predicate on the graph; only graphs for which it
            returns a true value are written.
        single: when true, at most one graph (the first after selection and
            sorting) is written.  An empty selection produces a file with
            no graphs instead of raising ``IndexError``.
        sort: when true, graphs are written in sorted name order.

    After the first line, each graph is written as its name on one line
    followed by the output of ``write_graph_with_rules``.
    """
    if select:
        fns = [fn for fn in data if select(data[fn][3])]
    else:
        fns = list(data)
    if sort:
        fns.sort()
    if single:
        # Slice rather than index so that nothing selected means "write
        # zero graphs" rather than a crash.
        fns = fns[:1]
    with open(filename, 'w') as out:
        if with_count:
            out.write(str(len(fns)))
        out.write('\n')
        for fn in fns:
            out.write(fn)
            out.write('\n')
            write_graph_with_rules(data[fn][3], out)


def write_rules(rule_list, states, rules_file, compiled_rules_file):
    """Write the rule list in readable and in compiled form.

    Each item of ``rule_list`` is a pair
    ``((rule, eqs), (compiled_rule, compiled_eqs))``.

    ``rules_file`` receives, per rule: its index on a line of its own, the
    output of ``print_rule`` and ``print_equations`` for the uncompiled
    pair, and a blank line.

    ``compiled_rules_file`` starts with the number of rules, followed by
    the output of ``print_compiled_rules`` and a blank line per rule.
    """
    with open(rules_file, 'w') as out:
        for index, entry in enumerate(rule_list):
            source, _ = entry
            rule, eqs = source
            print(index, file=out)
            print_rule(rule, states, out)
            print_equations(eqs, out)
            print(file=out)

    with open(compiled_rules_file, 'w') as out:
        print(len(rule_list), file=out)
        for _, compiled in rule_list:
            compiled_rule, compiled_eqs = compiled
            print_compiled_rules(compiled_rule, compiled_eqs, out)
            print(file=out)


if __name__ == '__main__':
    # Script entry point.
    #
    # Usage: <script> [level] [anything]
    #   level    - integer forwarded to extract_compiled_rules (default 3)
    #   anything - the mere presence of a second argument makes the script
    #              load the state and literal tables through rules_reader
    #              instead of starting from fresh ones
    #
    # Rules are extracted from the training data only; graphs are written
    # for the train, dev and test sets.
    import sys
    from data_stream import load_stream
    from reader import load_data

    logging.basicConfig(level=logging.INFO)

    level = 3  # default level
    if len(sys.argv) > 1:
        try:
            level = int(sys.argv[1].strip())
        except ValueError:
            # Keep the default, but report the bad argument instead of
            # swallowing it silently.
            logging.warning('invalid level %r, using default %d',
                            sys.argv[1], level)
    print('INFO:level =', level)
    preload_states = len(sys.argv) > 2
    print('INFO:preload states =', preload_states)

    train_data, dev_data, test_data = load_data()
    streams = load_stream(train_data)

    process_sentences(train_data)
    process_sentences(dev_data)
    process_sentences(test_data)

    # With no preloaded tables the keyword arguments stay at their None
    # defaults, so extract_compiled_rules creates fresh tables itself.
    preloaded = {}
    if preload_states:
        from rules_reader import load_tokens, STATES_FILE, LITERALS_FILE
        states, state2index = load_tokens(STATES_FILE)
        literals, literal2index = load_tokens(LITERALS_FILE)
        preloaded = {'states': states,
                     'state2index': state2index,
                     'literals': literals,
                     'literal2index': literal2index}

    rule_list, states, literals = \
        extract_compiled_rules(train_data, streams, level=level, **preloaded)

    write_rules(rule_list, states,
                'data/rules.txt', 'data/compiled-rules.txt')

    write_graphs('data/train-graphs.txt', train_data)
    write_graphs('data/dev-graphs.txt', dev_data, sort=True)
    write_graphs('data/test-graphs.txt', test_data,
                 with_count=False, sort=True)
    # Single-graph and small-graph files, filtered by node count.
    write_graphs('data/test-graph.txt', test_data, with_count=False,
                 single=True,
                 select=lambda x: len(x['nodes']) >= 30)
    write_graphs('data/train-small-graphs.txt', train_data,
                 select=lambda x: len(x['nodes']) <= 15)
    write_graphs('data/train-graph.txt', train_data,
                 single=True,
                 select=lambda x: len(x['nodes']) >= 30)

    # Save the (possibly extended) tables, one entry per line.
    with open('data/states.txt', 'w') as out:
        for state in states:
            out.write(state)
            out.write('\n')

    with open('data/literals.txt', 'w') as out:
        for literal in literals:
            out.write(literal)
            out.write('\n')
