import traceback
from collections import defaultdict

from coli.hrgguru.hyper_graph import HyperGraph
from coli.hrgguru.mrsguru.solve_mrs import MRSResolver, to_string, MRSResolverError, Timeout, MRSCheckerError
from delphin.mrs import simplemrs
from local_scripts.read_mrs import read_simple_graphs_as_hg, dmrs_hg_to_lfrg


def read_mrs_file(file_path, delete_empty_holes=False):
    """Lazily parse a file of blank-line-separated SimpleMRS strings.

    :param file_path: path of a UTF-8 text file holding one SimpleMRS per
        chunk, chunks separated by an empty line.
    :param delete_empty_holes: if true, prune dangling handle arguments from
        each MRS before yielding it (see ``_delete_empty_holes``).
    :yields: one object per non-blank chunk, as returned by
        ``simplemrs.loads_one``.
    """
    with open(file_path, encoding="utf-8") as f:
        content = f.read()
    for chunk in content.strip().split("\n\n"):
        # An empty file, or four or more consecutive newlines, produces blank
        # chunks; loads_one cannot parse those, so skip them.
        if not chunk.strip():
            continue
        mrs = simplemrs.loads_one(chunk)
        if delete_empty_holes:
            _delete_empty_holes(mrs)
        yield mrs


def _delete_empty_holes(mrs):
    """Pop, in place, handle-valued arguments that point at nothing.

    An argument is popped when its role is neither ``BODY`` nor ``CARG``, its
    value is a string starting with ``"h"``, the value is not the label of any
    EP, and it is not the ``hi`` of a handle constraint whose ``lo`` is such a
    label.
    """
    labels = {ep.label for ep in mrs.eps()}
    qeq_his = {hc.hi: hc.lo for hc in mrs.hcons()}
    for ep in mrs.eps():
        # Iterate over a snapshot because entries are popped during the loop.
        for key, value in list(ep.args.items()):
            # BODY holes are left alone. CARG carries a constant rather than a
            # variable, so a constant that merely starts with "h" is not a handle.
            if key in ("BODY", "CARG") or not isinstance(value, str):
                continue
            if not value.startswith("h") or value in labels:
                continue
            if value in qeq_his and qeq_his[value] in labels:
                continue
            ep.args.pop(key)


def delete_duplicate_q(hg, sent_id):
    """Return a copy of *hg* keeping at most one quantifier per target node.

    Edge roles:
    - Unary edges are predicate edges; the predicate belongs to the edge's
      single node.
    - Binary edges are arguments from ``nodes[0]`` to ``nodes[1]``.
    - A quantifier is a predicate edge whose label ends in ``"_q"``.

    When several distinct quantifiers have argument edges into the same target
    node, only the first one met while iterating ``hg.edges`` is kept. For each
    discarded quantifier, its predicate edge is dropped, together with every
    non-unary edge touching its node.

    NOTE(review): if ``hg.edges`` is an unordered collection, which duplicate
    survives is not guaranteed to be stable across runs.

    :param hg: input ``HyperGraph``; it is not modified.
    :param sent_id: unused; kept for interface compatibility.
    :returns: a new ``HyperGraph``. Its node set is rebuilt from the surviving
        edges, so nodes that no surviving edge touches are dropped.
    """
    pred_node_to_edge = {edge.nodes[0]: edge
                         for edge in hg.edges if len(edge.nodes) == 1}

    # target node -> the quantifier predicate edge kept for that target
    kept_quantifier = {}
    pred_edges_remove = set()
    for edge in hg.edges:
        if len(edge.nodes) != 2:
            continue
        source, target = edge.nodes
        quantifier = pred_node_to_edge.get(source)
        if quantifier is None or not quantifier.label.endswith("_q"):
            continue
        kept = kept_quantifier.setdefault(target, quantifier)
        # Compare by identity: a quantifier with several edges into the same
        # target must not be removed as a duplicate of itself.
        if kept is not quantifier:
            pred_edges_remove.add(quantifier)

    def is_kept(edge):
        if len(edge.nodes) == 1:
            return edge not in pred_edges_remove
        # .get: nodes lacking a predicate edge are never removal candidates.
        return not any(pred_node_to_edge.get(node) in pred_edges_remove
                       for node in edge.nodes)

    edges = frozenset(edge for edge in hg.edges if is_kept(edge))
    return HyperGraph(frozenset(node for edge in edges
                                for node in edge.nodes),
                      edges)


def solve_parsed_dmrs(file_name):
    """Scope-resolve every DMRS graph in *file_name* and print the success rate.

    For each ``(sent_id, graph)`` pair yielded by ``read_simple_graphs_as_hg``:
    1. Drop duplicate quantifiers with ``delete_duplicate_q``.
    2. Convert the graph with ``dmrs_hg_to_lfrg``.
    3. Give every EP that has an ``RSTR`` argument a fresh ``BODY`` handle.
    4. Solve with ``MRSResolver`` and run ``check_correctness``.

    A sentence counts as a success only when solving and checking both
    succeed, matching ``solve_parsed_mrs``. Diagnostics for failures are
    printed, and the ratio of successes to sentences is printed at the end.
    """
    hgs = read_simple_graphs_as_hg(file_name)
    total_count = success_count = 0
    for sent_id, dmrs_hg in hgs:
        dmrs_hg = delete_duplicate_q(dmrs_hg, sent_id)
        total_count += 1
        try:
            mrs_obj = dmrs_hg_to_lfrg(dmrs_hg)
            # Fresh BODY handles numbered from h10000.
            # NOTE(review): assumes converted graphs never use handles this
            # high; confirm against dmrs_hg_to_lfrg.
            placeholder_id = 10000
            for ep in mrs_obj.eps():
                if "RSTR" in ep.args:
                    ep.args["BODY"] = f"h{placeholder_id}"
                    placeholder_id += 1
            resolver = MRSResolver(mrs_obj, greedy_simple_quantifier=True)
            # Presumably wrapping solve() in `with Timeout(seconds=30):`
            # activates the TimeoutError branch below; currently disabled.
            resolved_mrs = resolver.solve()
            try:
                resolver.check_correctness(resolved_mrs)
            except MRSCheckerError:
                print(f"{sent_id} Error!")
                print(to_string(resolved_mrs))
                traceback.print_exc()
            else:
                success_count += 1
        except MRSResolverError as e:
            print(f"Cannot solve {sent_id}!!")
            print(to_string(e.value))
        except TimeoutError:
            print(f"{sent_id} Timeout!")
        except Exception:
            traceback.print_exc()
            print(f"{sent_id} Error!")
    if total_count:
        print(success_count / total_count)
    else:
        print(f"No sentences read from {file_name}")


def solve_parsed_mrs(file_name, delete_empty_holes=False):
    """Scope-resolve every MRS in a SimpleMRS file and print the success rate.

    Each MRS read by ``read_mrs_file`` is solved with ``MRSResolver`` and
    checked with ``check_correctness``. Any exception during solving or
    checking marks the MRS as a failure and prints a diagnostic.

    :param file_name: path passed to ``read_mrs_file``.
    :param delete_empty_holes: forwarded to ``read_mrs_file``.
    """
    mrs_iter = read_mrs_file(file_name, delete_empty_holes=delete_empty_holes)
    total_count = success_count = 0
    for mrs_obj in mrs_iter:
        total_count += 1
        try:
            resolver = MRSResolver(mrs_obj, greedy_simple_quantifier=True)
            # Presumably wrapping solve() in `with Timeout(seconds=30):`
            # activates the TimeoutError branch below; currently disabled.
            resolved_mrs = resolver.solve()
            resolver.check_correctness(resolved_mrs)
            success_count += 1
        except MRSResolverError as e:
            print("Cannot solve this!!")
            print(to_string(e.value))
        except TimeoutError:
            print("Timeout!")
        except Exception:
            traceback.print_exc()
            print("Error!")
    # Guard against ZeroDivisionError when the file holds no MRS.
    if total_count:
        print(success_count / total_count)
    else:
        print(f"No MRS read from {file_name}")


if __name__ == '__main__':
    import sys

    # Usage: <script> MRS_FILE. Resolves every MRS in the file after pruning
    # dangling handle arguments.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MRS_FILE")
    solve_parsed_mrs(sys.argv[1], delete_empty_holes=True)


