import gzip
import operator
import os
import signal
import traceback
from collections import defaultdict, OrderedDict, Counter
from functools import reduce
from itertools import chain
from pathlib import Path
from typing import Dict, List, Union, Any

from dataclasses import dataclass
from wrapt import ObjectProxy

from coli.basic_tools.common_utils import AttrDict
from coli.hrgguru.feature_tracker import UFSet
from coli.hrgguru.mrsguru.hungary import maximum_match
from delphin.mrs import simplemrs, ElementaryPredication
from coli.span.const_tree import Lexicon, ConstTree


class MRSResolverError(Exception):
    """Raised when the resolver cannot arrange the EPs into a scope tree.

    Attributes:
        value: payload describing the failure. Within this module it is a
            ``FailureInfo`` or a partially resolved structure.
    """

    def __init__(self, value):
        # Hand the payload to ``Exception`` explicitly so that ``args`` and
        # ``str(exc)`` are populated even when the error is built by keyword.
        super().__init__(value)
        self.value = value


class MRSCheckerError(Exception):
    """Raised when a resolved structure violates an MRS constraint.

    Attributes:
        value: payload describing the violation. Within this module it is a
            human-readable message string.
    """

    def __init__(self, value):
        # Hand the payload to ``Exception`` explicitly so that ``args`` and
        # ``str(exc)`` are populated even when the error is built by keyword.
        super().__init__(value)
        self.value = value


class MRSCycleError(Exception):
    """Raised when two argument positions claim the same group of EPs.

    Attributes:
        value: payload describing the clash. Within this module it is a tuple
            ``(kind, earlier_position, current_position)``.
    """

    def __init__(self, value):
        # Hand the payload to ``Exception`` explicitly so that ``args`` and
        # ``str(exc)`` are populated even when the error is built by keyword.
        super().__init__(value)
        self.value = value


def is_internal_ep(ep):
    """Return True when at least one argument value of ``ep`` starts with "h"."""
    for arg_value in ep.args.values():
        if arg_value.startswith("h"):
            return True
    return False


def is_quantifier(ep):
    """Return True when ``ep`` has both an ``RSTR`` and a ``BODY`` argument."""
    required_roles = ("RSTR", "BODY")
    return all(role in ep.args for role in required_roles)


def encode_ep(ep: ElementaryPredication):
    """Render ``ep`` as ``label:pred[ARG0]`` plus an optional ``[a,b,...]`` suffix.

    The suffix lists every argument value other than ``ARG0`` that does not
    start with "h"; it is omitted when the joined list is empty.
    """
    plain_values = []
    for role, target in ep.args.items():
        if role == "ARG0" or target.startswith("h"):
            continue
        plain_values.append(target)

    joined = ",".join(plain_values)
    suffix = f"[{joined}]" if joined else ""
    own_variable = ep.args.get("ARG0")
    return "{}:{}[{}]{}".format(ep.label, ep.pred, own_variable, suffix)


@dataclass
class FailureInfo(object):
    """Snapshot of a sub-problem the resolver could not solve.

    It can be rendered as a ``ConstTree`` for inspection and can report the
    EPs it holds, mirroring the interface of ``ResolvedMRS``.
    """
    # EPs that had been chosen as the root of the failed sub-problem.
    root_eps: List[ElementaryPredication]
    # Remaining EPs, keyed by the group they were put in.
    rest_ep_groups: Dict[str, List[ElementaryPredication]]
    # Mapping from a handle to the label it is QEQ to.
    qeq_lookup: Any
    # EQ constraint that was active for the sub-problem.
    eq_constraint: Any
    # QEQ constraints that were active for the sub-problem.
    qeq_constraints: Any
    # Free-form description of what went wrong.
    info: Any = ""

    def selected_ep_to_const_tree(self, ep: ElementaryPredication):
        """Build a tree for ``ep`` with one child per handle-valued argument."""
        ep_node = ConstTree(encode_ep(ep))
        for role, handle in ep.args.items():
            if not handle.startswith("h"):
                continue
            role_node = ConstTree(role)
            if handle in self.qeq_lookup:
                # role -> handle -> QEQ -> target label
                handle_node = ConstTree(handle)
                qeq_node = ConstTree("QEQ")
                qeq_node.children.append(Lexicon(self.qeq_lookup[handle]))
                handle_node.children.append(qeq_node)
                role_node.children.append(handle_node)
            else:
                role_node.children.append(Lexicon(handle))
            ep_node.children.append(role_node)
        if not ep_node.children:
            # Nodes without handle arguments get a placeholder leaf.
            ep_node.children.append(Lexicon("T"))
        return ep_node

    def to_const_tree(self):
        """Render the whole failure as a ``__FAILED__`` tree."""
        sanitized_info = self.info
        for forbidden in (" ", "(", ")"):
            sanitized_info = sanitized_info.replace(forbidden, "_")
        qeq_text = ",".join(self.qeq_constraints)
        failure_node = ConstTree(f'__FAILED__[{self.eq_constraint}][{qeq_text}],{sanitized_info}')

        selected_node = ConstTree("SELECTED")
        failure_node.children.append(selected_node)
        root_count = len(self.root_eps)
        if root_count == 0:
            selected_node.children.append(Lexicon("T"))
        elif root_count == 1:
            selected_node.children.append(self.selected_ep_to_const_tree(self.root_eps[0]))
        else:
            conjunction = ConstTree("^")
            selected_node.children.append(conjunction)
            for root_ep in self.root_eps:
                conjunction.children.append(self.selected_ep_to_const_tree(root_ep))

        for group_key, members in self.rest_ep_groups.items():
            group_node = ConstTree(f"GROUP:{group_key}")
            if len(members) == 1:
                only_member = list(members)[0]
                group_node.children.append(self.selected_ep_to_const_tree(only_member))
            else:
                conjunction = ConstTree("^")
                group_node.children.append(conjunction)
                for member in members:
                    conjunction.children.append(self.selected_ep_to_const_tree(member))
            failure_node.children.append(group_node)
        return failure_node

    def collect_eps(self, collector):
        """Append every EP held by this object to ``collector`` and return it."""
        collector.extend(self.root_eps)
        for members in self.rest_ep_groups.values():
            collector.extend(members)
        return collector


class ResolvedMRS(object):
    """Node of a resolved scope tree: one EP plus its resolved arguments.

    ``children`` maps an argument name of ``ep`` to whatever fills that
    position: a single node, a list of nodes, or a ``FailureInfo``.
    """

    def __init__(self, ep, children: Dict[str, Union["ResolvedMRS", List["ResolvedMRS"], FailureInfo]] = None):
        self.ep: ElementaryPredication = ep
        self.children = OrderedDict() if children is None else children

    def to_const_tree(self):
        """Render this node and everything below it as a ``ConstTree``."""
        node = ConstTree(encode_ep(self.ep))
        if not self.children:
            # Leaves carry a placeholder terminal.
            node.children.append(Lexicon("T"))
            return node

        for role, filler in self.children.items():
            role_node = ConstTree(f"{role}:{self.ep.args[role]}")
            node.children.append(role_node)
            if isinstance(filler, list):
                conjunction = ConstTree("^")
                role_node.children.append(conjunction)
                for member in filler:
                    conjunction.children.append(member.to_const_tree())
            else:
                role_node.children.append(filler.to_const_tree())
        return node

    def collect_eps(self, collector=None):
        """Append this node's EP and all EPs below it to ``collector``.

        A fresh list is created when ``collector`` is None; the list is
        returned in either case.
        """
        if collector is None:
            collector = []
        collector.append(self.ep)
        for filler in self.children.values():
            members = filler if isinstance(filler, list) else [filler]
            for member in members:
                member.collect_eps(collector)
        return collector


def to_string(inputs: Union["ResolvedMRS", List["ResolvedMRS"], FailureInfo]):
    """Return the bracketed string form of a resolved structure.

    A list is rendered under a synthetic "^" node; anything else is rendered
    through its own ``to_const_tree()``.
    """
    if not isinstance(inputs, list):
        return inputs.to_const_tree().to_parathesis()

    conjunction = ConstTree("^")
    for member in inputs:
        conjunction.children.append(member.to_const_tree())
    return conjunction.to_parathesis()


class IdentityWrapper(ObjectProxy):
    """Proxy that behaves like the wrapped object but hashes by an explicit id.

    Args:
        obj: the object to wrap.
        id: any hashable value; the proxy's hash is ``hash(id) + 1``.
    """

    def __init__(self, obj, id):
        super(IdentityWrapper, self).__init__(obj)
        # wrapt forwards ordinary attribute writes to the wrapped object;
        # only names prefixed with ``_self_`` are stored on the proxy itself.
        # Keeping the hash on the proxy avoids mutating the wrapped object and
        # keeps the hash stable if the same object is ever wrapped twice.
        self._self_hash = hash(id) + 1

    def __hash__(self):
        return self._self_hash

    def __repr__(self):
        return "(HashWrap): " + repr(self.__wrapped__)


class MRSResolver(object):
    def __init__(self, mrs_obj, greedy_simple_quantifier=True):
        self.mrs_obj = mrs_obj
        self.greedy_simple_quantifier = greedy_simple_quantifier
        self.eps_wrapped = set()
        self.lbl_lookup = defaultdict(list)
        for idx, ep in enumerate(mrs_obj.eps()):
            ep_wrapped = IdentityWrapper(ep, idx)
            self.lbl_lookup[ep.label].append(ep_wrapped)
            self.eps_wrapped.add(ep_wrapped)
        self.qeq_lookup = {i.hi: i.lo for i in mrs_obj.hcons()
                           if i.lo != "u" and i.lo in self.lbl_lookup}

        # child <--ARGX--QEQ-- parent
        self.qeq_parents = defaultdict(list)
        self.eq_parents = {}

        self.internal_eps = set()
        self.terminal_eps = set()
        for ep in self.eps_wrapped:
            if is_internal_ep(ep):
                self.internal_eps.add(ep)
                for key, value in ep.args.items():
                    if value in self.qeq_lookup:
                        self.qeq_parents[self.qeq_lookup[value]].append(ep)
                    if value in self.lbl_lookup:
                        assert value not in self.eq_parents or self.eq_parents[value] is ep
                        self.eq_parents[value] = ep
            else:
                self.terminal_eps.add(ep)

    def solve(self):
        top_label = self.mrs_obj.top if self.mrs_obj.top in self.lbl_lookup else None
        top_qeq_target = self.qeq_lookup.get(self.mrs_obj.top)
        qeq_constraints = {top_qeq_target} if top_qeq_target is not None else set()

        total_hole_count = sum(1 for ep in self.internal_eps for args in ep.args.values()
                               if args.startswith("h"))
        label_count = len(set(i.label for i in (self.internal_eps | self.terminal_eps)))
        if total_hole_count + 1 > label_count:
            print("Too many holes!")

        try:
            ret = self.solve_inner(self.internal_eps, self.terminal_eps, top_label,
                                   qeq_constraints)
            if total_hole_count + 1 > label_count:
                print("Too many holes but we solve it!!")
            return ret
        except:
            if total_hole_count + 1 <= label_count:
                print(f"holes: {total_hole_count + 1}, labels: {label_count} but can't solve")
            raise

    def solve_inner(self,
                    internal_eps,
                    terminal_eps,
                    eq_constraint,
                    qeq_constraints,
                    ) -> Union["ResolvedMRS", List["ResolvedMRS"]]:
        """Resolve one sub-problem: arrange the given EPs into a scope tree.

        The method works in two phases. First, every label of ``internal_eps``
        that may serve as the root is examined: the remaining EPs are
        partitioned with a union-find structure (labels tied by a scope
        variable, a QEQ link or an EQ link end up in one group) and the groups
        are distributed over the handle-valued argument positions of the root
        EPs. Second, the surviving candidates are tried in order, solving each
        argument position recursively; the first candidate whose positions
        all resolve is returned.

        Args:
            internal_eps: set of EPs of this sub-problem that have at least
                one handle-valued argument.
            terminal_eps: set of the other EPs of this sub-problem.
            eq_constraint: label the root must carry, or None.
            qeq_constraints: set of labels that are pending QEQ targets for
                this sub-problem.

        Returns:
            A single ``ResolvedMRS``, or a list of them when several EPs sit
            side by side at the root.

        Raises:
            MRSResolverError: no candidate root could be resolved. ``value``
                is a ``FailureInfo`` or the partially resolved result of the
                last candidate tried.
            AssertionError: both ``internal_eps`` and ``terminal_eps`` are
                empty.
        """
        qeq_parents = self.qeq_parents
        eq_parents = self.eq_parents
        lbl_lookup = self.lbl_lookup  # label -> list of EP
        qeq_lookup = self.qeq_lookup

        # Base case: nothing left to scope, the terminals form the result.
        if len(internal_eps) == 0:
            if len(terminal_eps) == 1:
                ret = ResolvedMRS(list(terminal_eps)[0])
            elif len(terminal_eps) >= 1:
                ret = [ResolvedMRS(i) for i in terminal_eps]
            else:
                raise AssertionError("No terminal eps!!")
            return ret

        # bound variable -> quantifier EP that binds it
        scope_vars = {ep.args["ARG0"]: ep for ep in internal_eps
                      if is_quantifier(ep)
                      }

        # NOTE(review): this comparison is always true for a dict; two
        # quantifiers sharing an ARG0 would silently overwrite each other above.
        assert len(scope_vars.values()) == len(scope_vars.keys())

        # One entry per viable candidate root, filled by the loop below.
        sub_info = []
        # Index in sub_info of the last "simple quantifier" candidate, if any.
        best_idx = -1
        failed_structure = None
        # Stays True when every candidate was skipped before being examined.
        cycled = True

        # select a handle from internal eps. label -> list of EP
        corresponding_internal_eps = defaultdict(list)

        # if total_hole_count == 1 and total_hole_count + 1 < len(label_count):
        #     if eq_constraint:
        #         new_terminal_ep = {i for i in terminal_eps}
        #         new_ter = []
        #     failed_structure = FailureInfo(
        #             [], {"SPECIAL": internal_eps | terminal_eps}, self.qeq_lookup,
        #             eq_constraint, qeq_constraints)
        #     raise MRSResolverError(failed_structure)

        if eq_constraint is not None:
            # The root label is dictated by the EQ constraint.
            corresponding_internal_eps[eq_constraint] = [i for i in lbl_lookup[eq_constraint]
                                                         if is_internal_ep(i)]
            for key, value in corresponding_internal_eps.items():
                if len(value) == 0:
                    raise MRSResolverError(FailureInfo(
                        [], {"UNGROUPED": internal_eps | terminal_eps}, self.qeq_lookup,
                        eq_constraint, qeq_constraints))
        elif qeq_constraints:
            for ep in internal_eps:
                # its qeq target or can propagate QEQ
                if is_quantifier(ep) or (len(qeq_constraints) == 1 and ep.label == list(qeq_constraints)[0]):
                    corresponding_internal_eps[ep.label] = list(internal_eps & set(lbl_lookup[ep.label]))
        else:
            # Unconstrained: every label of an internal EP is a candidate.
            for ep in internal_eps:
                corresponding_internal_eps[ep.label].append(ep)

        if not corresponding_internal_eps:
            # No internal EP can be the root: place a matching terminal EP
            # next to whatever the rest resolves to.
            if eq_constraint is not None:
                target_terminal_eps = set(i for i in terminal_eps
                                          if i.label == eq_constraint)
            else:
                target_terminal_eps = set(i for i in terminal_eps
                                          if i.label in qeq_constraints)

            # NOTE(review): only the first matching terminal EP is kept, yet
            # all of them are removed from the recursive call below; an empty
            # match raises IndexError here. Confirm this is intended.
            ret = [ResolvedMRS(list(target_terminal_eps)[0])]
            conjunction_tree = self.solve_inner(internal_eps, terminal_eps - target_terminal_eps,
                                                None, set())
            if isinstance(conjunction_tree, list):
                ret.extend(conjunction_tree)
            else:
                ret.append(conjunction_tree)
            return ret

        # Phase 1: examine every candidate root label.
        for selected_ep_label, selected_eps in corresponding_internal_eps.items():
            # skip this label if a remaining EP still points at it via QEQ or EQ
            if selected_ep_label in qeq_parents and any(i in internal_eps for i in qeq_parents[selected_ep_label]) or \
                    selected_ep_label in eq_parents and eq_parents[selected_ep_label] in internal_eps:
                continue

            # Skip the label if one of its EPs uses a variable bound by a
            # quantifier of this sub-problem (a quantifier's own ARG0 is fine).
            try:
                for selected_ep in lbl_lookup[selected_ep_label]:
                    for k, v in selected_ep.args.items():
                        if (k != "ARG0" or not is_quantifier(selected_ep)) and v in scope_vars:
                            raise AssertionError
            except AssertionError:
                continue

            cycled = False

            # Variables bound by the candidate root itself.
            selected_scope_vars = {i.args["ARG0"] for i in selected_eps if is_quantifier(i)}
            selected_labels = set(i.label for i in selected_eps)
            new_qeq_constraints = qeq_constraints - selected_labels

            rest_internal_eps = internal_eps - set(selected_eps)
            rest_terminal_eps = set(i for i in terminal_eps if i.label != selected_ep_label)

            # Union-find over the labels of all remaining EPs.
            ufset = UFSet.from_leafs(
                chain((i.label for i in rest_internal_eps),
                      (i.label for i in rest_terminal_eps))
            )

            # scope constraint
            for ep in chain(rest_internal_eps, rest_terminal_eps):
                for arg_key, arg_value in ep.args.items():
                    if arg_value not in selected_scope_vars and arg_value in scope_vars:
                        # this arg must in scope
                        ufset.union(scope_vars[arg_value].label, ep.label)

            # QEQ and its parents must in the same group
            for child_label, parents in qeq_parents.items():
                if child_label != "u":
                    for parent in parents:
                        if (parent in rest_internal_eps or parent in rest_terminal_eps) \
                                and (child_label in ufset.roots):
                            ufset.union(parent.label, child_label)

            # EQ and its parent must in the same group
            for child_label, parent in eq_parents.items():
                if (parent in rest_internal_eps or parent in rest_terminal_eps) \
                        and (child_label in ufset.roots):
                    ufset.union(parent.label, child_label)

            # group representative -> labels in that group
            groups = defaultdict(list)

            def get_group_eps():
                # group representative -> EPs, used only for failure reports
                group_eps = defaultdict(list)
                for ep in internal_eps | terminal_eps:
                    if ep.label != selected_ep_label:
                        group_eps[ufset[ep.label]].append(ep)
                return group_eps

            # groups that have scope constraint
            group_to_scope_var = {}
            # groups that should be QEQ to
            group_qeq_constraint = set()
            # group representative -> (root EP, argument name) that claims it
            eq_reverse_lookup = {}
            qeq_reverse_lookup = {}

            for node in ufset.roots.keys():
                root = ufset[node]
                groups[root].append(node)
                if node in new_qeq_constraints:
                    group_qeq_constraint.add(root)
                for ep in lbl_lookup[node]:
                    for var in ep.args.values():
                        if var in selected_scope_vars:
                            group_to_scope_var[root] = var

            # (root EP, argument name) -> how that position is constrained
            child_info = OrderedDict()
            try:
                for selected_ep in selected_eps:
                    for name, value in sorted(selected_ep.args.items()):
                        if value in qeq_lookup:
                            group_name = ufset[qeq_lookup[value]]
                            child_info[selected_ep, name] = AttrDict(type="QEQ",
                                                                     group=group_name,
                                                                     target=qeq_lookup[value])
                            # A group may be claimed by one position only.
                            if group_name in qeq_reverse_lookup:
                                raise MRSCycleError(("QEQ", qeq_reverse_lookup[group_name], (selected_ep, name)))
                            qeq_reverse_lookup[group_name] = (selected_ep, name)
                        elif value in lbl_lookup:
                            group_name = ufset[value]
                            child_info[selected_ep, name] = AttrDict(type="EQ", group=group_name)
                            if group_name in eq_reverse_lookup:
                                raise MRSCycleError(("EQ", eq_reverse_lookup[group_name], (selected_ep, name)))
                            eq_reverse_lookup[group_name] = (selected_ep, name)
                        else:
                            # is normal variable
                            continue
            except MRSCycleError as e:
                failed_structure = FailureInfo(selected_eps, get_group_eps(), self.qeq_lookup,
                                               eq_constraint, qeq_constraints, info=str(e.value))
                continue

            # A single root EP whose RSTR position is constrained to a group
            # consisting of exactly one label.
            is_simple_quantifier = len(selected_eps) == 1 and (selected_eps[0], "RSTR") in child_info and \
                                   len(groups[child_info[selected_eps[0], "RSTR"].group]) == 1
            # root EP -> argument name -> set of groups placed there
            group_distribution = OrderedDict()

            group_links_limited = defaultdict(set)  # position -> available groups
            group_links = defaultdict(set)  # position -> available groups
            # group -> number of labels minus number of handle-valued arguments
            group_benefits = {}

            for group_name, group_labels in groups.items():
                group_eps = sum((lbl_lookup[i] for i in group_labels), [])
                group_benefits[group_name] = len(group_labels) - sum(
                    1 for ep in group_eps for args in ep.args.values()
                    if args.startswith("h"))

            try:
                for selected_ep in selected_eps:
                    group_distribution[selected_ep] = OrderedDict()
                    for name, value in selected_ep.args.items():
                        if not value.startswith("h"):
                            continue
                        group_distribution[selected_ep][name] = set()
                        for group_name in groups.keys():
                            group_labels = groups[group_name]
                            group_eps = sum((lbl_lookup[i] for i in group_labels), [])
                            eq_info = eq_reverse_lookup.get(group_name)
                            qeq_info = qeq_reverse_lookup.get(group_name)
                            scope_var = group_to_scope_var.get(group_name)
                            satisfy_scope_constraint = scope_var is None or scope_var == selected_ep.args.get("ARG0")
                            satisfy_qeq_constrait = group_name not in group_qeq_constraint or name == "BODY"

                            # position_info = child_info.get((selected_ep, name))
                            # is_not_top_qeq = (qeq_constraints and not any(i in qeq_constraints for i in group_labels))
                            # is_not_position_qeq = position_info and position_info.type == "QEQ" and position_info.group != group_name
                            # if is_not_top_qeq and is_not_position_qeq:
                            #     has_qeq_blocker = not all((i in terminal_eps or is_quantifier(i))
                            #                               for i in group_eps)
                            #     not_has_quantifier = not any(is_quantifier(i) for i in group_eps)
                            #     if has_qeq_blocker and not_has_quantifier:
                            #         continue
                            # constraint 1: bind EQ
                            if eq_info:
                                assert satisfy_scope_constraint, "scope is not correct"
                                assert satisfy_qeq_constrait, "qeq is not correct"
                                assert not qeq_info, "has qeq"
                                if eq_info == (selected_ep, name):
                                    group_links_limited[selected_ep, name].add(group_name)
                                    group_links[selected_ep, name].add(group_name)
                            # constraint 2: bind QEQ
                            elif qeq_info:
                                assert satisfy_scope_constraint, "scope is not correct"
                                assert satisfy_qeq_constrait, "qeq is not correct"
                                assert not eq_info, "has eq"
                                if qeq_info == (selected_ep, name):
                                    group_links_limited[selected_ep, name].add(group_name)
                                    group_links[selected_ep, name].add(group_name)
                            # constraint 3: scope variable
                            elif satisfy_scope_constraint and satisfy_qeq_constrait:
                                group_links[selected_ep, name].add(group_name)
                                if not child_info.get((selected_ep, name)):
                                    group_links_limited[selected_ep, name].add(group_name)
                        assert group_links_limited[selected_ep, name], f"no links for {encode_ep(selected_ep)} {name}"
                # Every group must be reachable from at least one position.
                p_groups = reduce(operator.or_, group_links.values())
                all_groups = set(groups.keys())
                assert p_groups == all_groups, f"unmatched groups:{p_groups}!={all_groups}"

                # distribute groups to positions
                max_match = maximum_match(group_links_limited)
                assert len(max_match) == len(group_links_limited)
            except AssertionError as e:
                # traceback.print_exc()
                failed_structure = FailureInfo(selected_eps, get_group_eps(), self.qeq_lookup,
                                               eq_constraint, qeq_constraints,
                                               info=str(e))
                continue

            # max_match is iterated as group -> position.
            for group, (selected_ep, name) in max_match.items():
                group_distribution[selected_ep][name].add(group)

            # Hand positive-benefit groups left out of the matching to
            # positions whose current benefit is negative.
            used_group = set()
            for (selected_ep, name), avail_groups in reversed(list(group_links.items())):
                position_benefit = sum(group_benefits[i] for i in group_distribution[selected_ep][name]) - 1
                changed = True
                # NOTE(review): position_benefit is not recomputed inside the
                # loop, so the loop ends only once no further group was added.
                while position_benefit < 0 and changed:
                    changed = False
                    for group in avail_groups:
                        if group not in max_match and group not in used_group and group_benefits[group] > 0:
                            group_distribution[selected_ep][name].add(group)
                            used_group.add(group)
                            changed = True

            # Any group still unplaced goes to a position that accepts it.
            for (selected_ep, name), avail_groups in reversed(list(group_links.items())):
                for group in avail_groups:
                    if group not in max_match and group not in used_group:
                        group_distribution[selected_ep][name].add(group)
                        used_group.add(group)

            sub_info.append((ufset, selected_eps, groups, group_distribution, child_info, new_qeq_constraints))
            if is_simple_quantifier:
                best_idx = len(sub_info) - 1

        # for left_count in range(0 if qeq_target else 1, len(group_keys)):
        # for left_group_keys in itertools.permutations(group_keys, left_count):
        if not sub_info:
            if cycled:
                failed_structure = FailureInfo([], {"CYCLED": internal_eps | terminal_eps}, self.qeq_lookup,
                                               eq_constraint, qeq_constraints)
            raise MRSResolverError(failed_structure)

        # Prefer the simple-quantifier candidate: try it alone (greedy) or
        # move it to the front.
        # NOTE(review): with best_idx == 0 the greedy branch is not taken, so
        # the other candidates remain as fallbacks. Confirm this is intended.
        if best_idx > 0:
            if self.greedy_simple_quantifier:
                sub_info = [sub_info[best_idx]]
            else:
                t = sub_info[0]
                sub_info[0] = sub_info[best_idx]
                sub_info[best_idx] = t

        failed_structure = None

        # Phase 2: try the candidates in order and recurse into each position.
        for current_info in sub_info:
            ufset, selected_eps, groups, group_distribution, child_info, new_qeq_constraints = current_info
            results = []

            failed = False
            for selected_ep, name_and_selected_groups in group_distribution.items():
                if not is_quantifier(selected_ep):
                    assert not new_qeq_constraints

                child_dict = OrderedDict()
                for name, selected_groups in name_and_selected_groups.items():
                    # Collect the EPs of every group assigned to this position.
                    child_internal_eps = set()
                    child_terminal_eps = set()
                    for key in selected_groups:
                        for ep_label in groups[key]:
                            for ep in lbl_lookup[ep_label]:
                                if ep.label == selected_ep.label:
                                    continue
                                if ep in internal_eps:
                                    child_internal_eps.add(ep)
                                elif ep in terminal_eps:
                                    child_terminal_eps.add(ep)
                    # Note: this rebinds the eq_constraint parameter.
                    eq_constraint = None
                    child_info_ = child_info.get((selected_ep, name))
                    this_qeq_constraints = set(i for i in new_qeq_constraints
                                               if ufset[i] in selected_groups)
                    if child_info_:
                        if child_info_.type == "EQ":
                            eq_constraint = selected_ep.args[name]
                        if child_info_.type == "QEQ":
                            this_qeq_constraints = this_qeq_constraints | {child_info_.target}
                    try:
                        result = self.solve_inner(child_internal_eps, child_terminal_eps, eq_constraint,
                                                  this_qeq_constraints)
                    except MRSResolverError as e:
                        # Keep the failure in the tree so it can be reported.
                        result = e.value
                        assert result is not None
                        failed = True
                    child_dict[name] = result
                results.append(ResolvedMRS(selected_ep, child_dict))

            # Terminal EPs sharing the root label sit next to the root EPs.
            for terminal_ep in set(lbl_lookup[selected_eps[0].label]) & terminal_eps:
                results.append(ResolvedMRS(terminal_ep))

            if len(results) == 1:
                results = results[0]

            if not failed:
                return results
            else:
                failed_structure = results

        raise MRSResolverError(failed_structure)

    def check_correctness(self, resolved_mrs):
        # 1. check total ep count
        # all element prediction must in resolved MRS
        if isinstance(resolved_mrs, list):
            system_eps = []
            for mrs in resolved_mrs:
                system_eps.extend(mrs.collect_eps())
        else:
            system_eps = resolved_mrs.collect_eps()
        system_eps_counts = Counter(system_eps)
        assert all(i == 1 for i in system_eps_counts.values())
        assert self.eps_wrapped == set(system_eps), self.eps_wrapped ^ set(system_eps)

        # 2. check scope constraint
        self.check_scope(resolved_mrs, set(i.args["ARG0"] for i in system_eps
                                           if is_quantifier(i)))

        # 2. check EQ
        top_label = self.mrs_obj.top if self.mrs_obj.top in self.lbl_lookup else None
        self.check_eq(resolved_mrs, top_label)

        # 3. check QEQ
        top_qeq_target = self.qeq_lookup.get(self.mrs_obj.top)
        pending_qeq_target = {}
        if top_qeq_target:
            pending_qeq_target[top_qeq_target] = ("TOP",)
        solved_qeq = self.check_qeq(resolved_mrs, pending_qeq_target)
        for hi, lo in self.qeq_lookup.items():
            if lo not in solved_qeq:
                raise MRSCheckerError(f"{self.qeq_parents[lo]}.{hi} QEQ {lo} not solved!")
        return True

    def check_eq(self, resolved_mrs, top_label):
        """Verify that label-valued (EQ) arguments are filled by the right EP.

        Args:
            resolved_mrs: a ``ResolvedMRS`` or a list of them.
            top_label: label that must occur at the root of ``resolved_mrs``,
                or None when the root is unconstrained.

        Raises:
            MRSCheckerError: ``top_label`` is not found at the root.
        """
        is_conjunction = isinstance(resolved_mrs, list)

        # check top correctness
        if top_label is not None:
            if is_conjunction:
                top_correct = any(member.ep.label == top_label for member in resolved_mrs)
            else:
                top_correct = (resolved_mrs.ep.label == top_label)
            if not top_correct:
                raise MRSCheckerError(f"EQ Not satisfied for {top_label}")

        if is_conjunction:
            for member in resolved_mrs:
                self.check_eq(member, None)
            return

        assert isinstance(resolved_mrs, ResolvedMRS)
        for role, handle in resolved_mrs.ep.args.items():
            if not handle.startswith("h"):
                continue
            # A handle naming a known label constrains the child's root.
            expected_label = handle if handle in self.lbl_lookup else None
            self.check_eq(resolved_mrs.children[role], expected_label)

    def check_qeq(self, resolved_mrs, pending_target):
        """Collect the QEQ targets that are reached inside ``resolved_mrs``.

        Args:
            resolved_mrs: a ``ResolvedMRS`` or a list of them.
            pending_target: mapping from a target label still to be reached
                to the tuple of handles/labels passed through so far.

        Returns:
            Mapping from each reached target label to its path tuple.
        """
        if isinstance(resolved_mrs, list):
            merged = {}
            for member in resolved_mrs:
                merged.update(self.check_qeq(member, pending_target))
            return merged

        assert isinstance(resolved_mrs, ResolvedMRS)
        own_label = resolved_mrs.ep.label
        solved_qeq = {}
        if own_label in pending_target:
            solved_qeq[own_label] = pending_target[own_label] + (own_label,)
        if len(resolved_mrs.children) == 0:
            # terminal
            return solved_qeq

        # internal: targets not reached here are extended by this label and
        # passed on, but only through the BODY position.
        inherited = {label: path + (own_label,)
                     for label, path in pending_target.items()
                     if label != own_label}
        for role, child in resolved_mrs.children.items():
            target = dict(inherited) if role == "BODY" else {}
            original_handle = resolved_mrs.ep.args[role]
            if original_handle in self.qeq_lookup:
                target_handle = self.qeq_lookup[original_handle]
                assert target_handle not in target
                target[target_handle] = (original_handle,)
            solved_qeq.update(self.check_qeq(child, target))
        return solved_qeq

    def check_scope(self, resolved_mrs, all_scope_vars, current_scope_vars=()):
        """Verify that bound variables are only used below their quantifier.

        Args:
            resolved_mrs: a ``ResolvedMRS`` or a list of them.
            all_scope_vars: collection of every quantifier-bound variable.
            current_scope_vars: tuple of variables bound by the quantifiers
                above ``resolved_mrs``.

        Raises:
            MRSCheckerError: an EP uses a bound variable that is not in scope.
        """
        if isinstance(resolved_mrs, list):
            for member in resolved_mrs:
                self.check_scope(member, all_scope_vars, current_scope_vars)
            return

        node_ep = resolved_mrs.ep
        visible_vars = current_scope_vars
        if is_quantifier(node_ep):
            # A quantifier brings its own variable into scope for itself
            # and for everything below it.
            bound_var = node_ep.args["ARG0"]
            assert bound_var in all_scope_vars
            visible_vars = current_scope_vars + (bound_var,)

        for used_var in node_ep.args.values():
            if used_var in all_scope_vars and used_var not in visible_vars:
                raise MRSCheckerError(f"{resolved_mrs.ep} use out-of-scope var {used_var}")

        for child in resolved_mrs.children.values():
            self.check_scope(child, all_scope_vars, visible_vars)


class Timeout:
    """Context manager that raises ``TimeoutError`` when its body runs too long.

    The limit is enforced with ``signal.alarm`` and a ``SIGALRM`` handler.
    On exit the alarm is cancelled and the handler that was installed before
    entering is put back, so using this class leaves no handler behind.

    Args:
        seconds: time limit passed to ``signal.alarm``.
        error_message: message of the ``TimeoutError`` that is raised.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: abort the guarded body."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        # Cancel the pending alarm first so it cannot fire during cleanup.
        signal.alarm(0)
        # signal.signal() reports None for handlers not installed from
        # Python; those cannot be re-installed, so they are left alone.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None


def main():
    """Resolve a built-in example, then every DeepBank export under wsj20*.

    For each export file the MRS is resolved with a 30 second limit, printed,
    and checked with ``MRSResolver.check_correctness``. At the end the success
    ratio and the number of timeouts are printed. When no export file is
    found the ratio is reported as 0.0 instead of dividing by zero.
    """
    # Alternative hand-written inputs kept for manual experiments:
    #  mrs_literal = """
    # [ RELS: <
    #        [ _dog_n_1<1:2>
    #          LBL: h0
    #          ARG0: x14 ]
    #        [ _every_q<0:1>
    #          LBL: h7
    #          ARG0: x14
    #          RSTR: h4
    #          BODY: h12 ]
    #        [ _chase_v_1<2:3>
    #          LBL: h16
    #          ARG1: x14
    #          ARG0: x15
    #          ARG2: x10 ]
    #        [ _cat_n_1<4:5>
    #          LBL: h6
    #          ARG0: x10 ]
    #        [ _some_q<3:4>
    #          LBL: h9
    #          BODY: h5
    #          RSTR: h1
    #          ARG0: x10 ] >
    #  HCONS: < h4 QEQ h0 h1 QEQ h6 > ]
    #  """
    #  mrs_obj = """
    # [ RELS: <
    #        [ _some_q<0:1>
    #          LBL: h7
    #          BODY: h0
    #          RSTR: h6
    #          ARG0: x1 ]
    #        [ unknown<0:3>
    #          LBL: h3
    #          ARG: x1
    #          ARG0: x9 ]
    #        [ _dog_n_1<1:2>
    #          LBL: h4
    #          ARG0: x1 ] >
    #  HCONS: < h6 QEQ h4 > ]
    #  """

    mrs_obj = """
   [ RELS: <
          [ _cat_n_1<5:6>
            LBL: h7
            ARG0: x4 ]
          [ _some_q<4:5>
            LBL: h19
            RSTR: h0
            ARG0: x4
            BODY: h2 ]
          [ _chase_v_1<3:4>
            LBL: h11
            ARG2: x4
            ARG1: x17
            ARG0: x18 ]
          [ _probable_a_1<2:3>
            LBL: h15
            ARG1: h10
            ARG0: x20 ]
          [ _dog_n_1<1:2>
            LBL: h13
            ARG0: x17 ]
          [ _every_q<0:1>
            LBL: h5
            ARG0: x17
            RSTR: h16
            BODY: h1 ] >
    HCONS: < h16 QEQ h13 h10 QEQ h11 h0 QEQ h7 > ]
    """

    # Built-in example: resolve it once and show the result.
    mrs_obj = simplemrs.loads_one(mrs_obj)
    resolver = MRSResolver(mrs_obj)
    resolved_mrs = resolver.solve()
    print(to_string(resolved_mrs))

    home = os.path.expanduser("~")
    deepbank_export_path = home + "/Development/large-data/deepbank1.1/export/"
    total_count = 0
    success_count = 0
    timeout_count = 0
    for file_path in sorted(Path(deepbank_export_path).glob("wsj20*/*.gz")):
        # Single files that were used for debugging:
        # for file_path in [home + "/Development/large-data/deepbank1.1/export/wsj04b/20416038.gz"]:
        # for file_path in [home + "/Development/large-data/deepbank1.1/export/wsj02d/20296051.gz"]:
        print(file_path)
        total_count += 1
        with gzip.open(file_path, "rb") as f:
            # Fields are separated by blank lines; the MRS literal is read
            # from the third field from the end.
            fields = f.read().decode("utf-8").strip().split("\n\n")
            mrs_literal = fields[-3]
        mrs_obj = simplemrs.loads_one(mrs_literal)
        try:
            resolver = MRSResolver(mrs_obj)
            with Timeout(seconds=30):
                resolved_mrs = resolver.solve()
            print(to_string(resolved_mrs))
            resolver.check_correctness(resolved_mrs)
            success_count += 1
        except TimeoutError:
            timeout_count += 1
            print("Timeout!")
        except MRSResolverError as e:
            print("Cannot solve this!!")
            print(to_string(e.value))
            # check_mrs_correctness(mrs_obj, e.value)

    # Guard against an empty corpus: the glob may match nothing.
    success_ratio = success_count / total_count if total_count else 0.0
    print(success_ratio, timeout_count)


# Run the demo and corpus evaluation only when executed as a script.
if __name__ == '__main__':
    main()
