#!/usr/bin/env python3

import os
import pickle

import pyshrg
from framework.common.logger import open_file
from pyshrg_utils.derivation import recover_subgraph_derivation
from pyshrg_utils.parser import PySHRGParserOptions, pyshrg_parse_args


class MainOptions(PySHRGParserOptions):
    # Command-line options for this script; extends PySHRGParserOptions.
    # Identifier of the sentence to inspect. main() picks the first graph whose
    # ``sentence_id`` ends with this value, and looks up the gold derivation by
    # the last '/'-separated component of it.
    sentence_id: str


def is_matched(subgraph, edge_set, eps):
    """Check whether a chart sub-item agrees with a gold child partition.

    Args:
        subgraph: chart item for the child, or ``None`` when there is none.
        edge_set: gold edge indices for the child, or ``None`` when the gold
            rule has no such child.
        eps: gold node indices (0-based) that must prefix the item's
            ``node_mapping``; presumably the child's external nodes — verify
            against ``recover_subgraph_derivation``. ``None`` means "none".

    Returns:
        True when both sides are absent, or when ``subgraph`` covers exactly
        ``edge_set``, its ``node_mapping`` starts with ``eps`` (stored there
        1-based), and the entry right after that prefix, if any, is 0.
    """
    if subgraph is None:
        return edge_set is None
    if subgraph.edge_set.to_edge_indices() != edge_set:
        return False
    # Normalise to a tuple: comparing a list (the old ``[]`` default, or a
    # list-typed eps) with the tuple built below is always False, which made
    # every child without eps fail to match.
    eps = tuple(eps) if eps is not None else ()
    mapping = subgraph.node_mapping
    count = len(eps)
    # The entry following the expected prefix must be 0.
    if len(mapping) > count and mapping[count] != 0:
        return False
    return eps == tuple(index - 1 for index in mapping[:count])


def split_item(context, subgraph, graph):
    """Split a chart item into its centre part and its child items.

    Returns:
        ``(partition, center_edges, left_ptr, right_ptr)`` where ``partition``
        is ``(center_part, left_nodes, right_nodes)``. A missing child
        contributes ``None`` in that tuple. ``center_edges`` is the set of the
        item's edge indices minus the edges covered by its children.
    """
    center_part, left_ptr, right_ptr = context.split_item(subgraph, graph)
    remaining_edges = set(subgraph.edge_set.to_edge_indices())
    child_nodes = []
    for child in (left_ptr, right_ptr):
        if child is None:
            child_nodes.append(None)
            continue
        # Edges owned by a child are not part of the centre.
        remaining_edges.difference_update(child.edge_set.to_edge_indices())
        child_nodes.append(child.edge_set.to_node_indices(graph))
    return (center_part, *child_nodes), remaining_edges, left_ptr, right_ptr


def dfs(manager, context, graph, ptrs, partitions, derivation, results, step,
        *, verbose=False):
    """Backtracking search that aligns parser chart items with a gold derivation.

    Steps are processed from ``step`` down to 0. ``ptrs[i]`` is the chart item
    assigned to derivation step ``i``: the head of a circular ``.next``-linked
    list of alternatives, or ``None`` when nothing was assigned (the step is
    then skipped). An alternative is accepted when its grammar is the gold
    rule and its centre edges equal the gold centre. Its children are then
    assigned to the gold child steps when they match in either orientation,
    and the search continues with the next step.

    Args:
        manager: provides ``get_shrg(rule_index)``.
        context: parser context, passed to ``split_item``.
        graph: the graph being parsed.
        ptrs: per-step chart items; updated in place.
        partitions: per-step gold partitions, each unpacked as
            ``(center, left, right, eps)``; the last element is used as the
            child's eps.
        derivation: per-step ``(rule_index, left_index, right_index)``; a
            negative child index means the rule has no such child.
        results: per-step slots, filled with ``(rule_index, pred_partition)``.
        step: index of the next step to process.
        verbose: print ``grammar.cfg_at(0)`` of every candidate examined.

    Returns:
        True if every remaining step was matched. On False, ``ptrs`` and
        ``results`` hold the same contents they had on entry.
    """
    # Skip unassigned steps iteratively, not recursively, so long runs of
    # them cannot exhaust the recursion limit.
    while step >= 0 and ptrs[step] is None:
        step -= 1
    if step < 0:
        return True

    head = ptrs[step]
    rule_index, left_index, right_index = derivation[step]
    gold_center, gold_left, gold_right, _ = partitions[step]
    # These depend only on the step, not on the candidate: compute them once.
    gold_center = set(gold_center)
    gold_rule = manager.get_shrg(rule_index)

    left_part = partitions[left_index] if left_index >= 0 else None
    left_eps = left_part[-1] if left_part else None
    right_part = partitions[right_index] if right_index >= 0 else None
    right_eps = right_part[-1] if right_part else None

    # Snapshot for backtracking: a failed branch must not leave stale child
    # pointers or results behind for the next candidate or for the caller.
    saved_ptrs = list(ptrs)
    saved_results = list(results)

    candidate = head
    while True:
        if verbose:
            print(candidate.grammar.cfg_at(0))
        if candidate.grammar is gold_rule:
            pred_partition, center_edges, left_ptr, right_ptr = \
                split_item(context, candidate, graph)

            if center_edges == gold_center:
                if is_matched(left_ptr, gold_right, right_eps) \
                        and is_matched(right_ptr, gold_left, left_eps):
                    # The chart item's children are in swapped order
                    # relative to the gold rule.
                    if left_index >= 0:
                        ptrs[left_index] = right_ptr
                    if right_index >= 0:
                        ptrs[right_index] = left_ptr
                elif is_matched(left_ptr, gold_left, left_eps) \
                        and is_matched(right_ptr, gold_right, right_eps):
                    if left_index >= 0:
                        ptrs[left_index] = left_ptr
                    if right_index >= 0:
                        ptrs[right_index] = right_ptr
                # NOTE(review): when the centre matches but neither child
                # orientation does, the search still proceeds without
                # assigning children (kept from the original behaviour).

                results[step] = rule_index, pred_partition
                if dfs(manager, context, graph, ptrs, partitions, derivation,
                       results, step - 1, verbose=verbose):
                    return True
                # Undo everything this branch wrote before trying the next one.
                ptrs[:] = saved_ptrs
                results[:] = saved_results

        candidate = candidate.next
        if candidate is head:
            break

    return False


def main(argv=None):
    """Compare the Python and C++ reconstructions of one sentence's derivation.

    Parses the first graph whose ``sentence_id`` ends with the
    ``sentence_id`` option, loads the gold derivation from
    ``<grammar_dir>/train.derivations.p``, aligns it with the parser chart via
    ``dfs`` and prints the Python and C++ partitions of each step.

    Args:
        argv: passed through to ``pyshrg_parse_args``.

    Raises:
        SystemExit: no graph matches the requested sentence id.
        AssertionError: the chart could not be aligned with the derivation.
    """
    default_instance = MainOptions()
    default_instance.pyshrg.num_contexts = 1

    manager, options = pyshrg_parse_args(argv,
                                         default_instance=default_instance,
                                         abbrevs={'sentence_id': 'sentence_id'})
    context = manager.get_context(0)

    sentence_id = options.sentence_id
    grammar_dir = options.grammar_dir

    # A default avoids a bare StopIteration when the sentence is missing.
    graph = next((candidate for candidate in manager.iter_graphs()
                  if candidate.sentence_id.endswith(sentence_id)), None)
    if graph is None:
        raise SystemExit('no graph whose sentence_id ends with {!r}'.format(sentence_id))

    code = context.parse(graph)
    if code != pyshrg.ParserError.kNone:
        print(code)
        return

    # NOTE: pickle can execute arbitrary code on load; only point grammar_dir
    # at derivation files produced by a trusted extraction run.
    derivations_file = open_file(os.path.join(grammar_dir, 'train.derivations.p'), 'rb')
    try:
        derivations = pickle.load(derivations_file)
    finally:
        derivations_file.close()
    derivation = derivations[sentence_id.rsplit('/', 1)[-1]]

    partitions, derivation = recover_subgraph_derivation(derivation, graph)
    # NOTE(review): the meaning of the third argument (10) is not visible here.
    cpp_results = context.export_derivation(derivation, partitions, 10)

    py_results = [None] * len(partitions)
    ptrs = [None] * len(partitions)
    ptrs[-1] = context.result_item
    # An explicit check rather than `assert dfs(...)`: assert statements are
    # stripped under `python -O`, which would skip the search entirely.
    if not dfs(manager, context, graph, ptrs, partitions, derivation,
               py_results, len(ptrs) - 1):
        raise AssertionError('failed to align the parser chart with the gold derivation')

    for step, (py_result, cpp_result) in enumerate(zip(py_results, cpp_results)):
        print(step, 'Python:', py_result[1][:-1] if py_result is not None else None)
        print(step, 'Cpp   :', cpp_result[1][0][1:][:-1] if cpp_result is not None else None)


if __name__ == '__main__':
    # Script entry point; argv=None is the same as calling main() bare.
    main(argv=None)
