# -*- coding: utf-8 -*-

from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

from dataclasses import dataclass

from framework.common.logger import LOGGER

from .const_tree import LABEL_SEP, ConstTree, Lexicon
from .hyper_graph import GraphNode, HyperEdge, HyperGraph
from .shrg_alignment import find_aligned_edge
from .shrg_compound_split import NodeDistributor
from .shrg_detect import DETECT_FUNCTIONS
from .shrg_permute import EP_PERMUTATIONS, compute_names_of_nodes

# Right-hand side of a CFGRule:
#   Internal node: rhs = Sequence[(cfg_rhs_1, corresponding hrg edge1), ...]
#   Leaf node: rhs = Sequence[(lexicon, None)]
#   Invalid rule: rhs = None
# NOTE(review): the annotation below also admits (Lexicon, Sequence[HyperEdge])
# items; leaf right-hand sides come from find_aligned_edge, so confirm which of
# the two forms it actually returns.
CFGRuleRHS = Optional[Sequence[Union[Tuple[str, Optional[HyperEdge]],
                                     Tuple[Lexicon, Sequence[HyperEdge]]]]]

# Option names understood by remove_null_semantic_rules (it tests membership of
# each of these in the `option` argument).
REMOVE_NULL_SEMANTIC_OPTIONS = ['merge-single', 'merge-both', 'delete', 'ignore_punct']


class CFGRule:
    """A context-free production ``lhs => rhs``.

    ``rhs`` is stored as a tuple so that the rule is hashable.  ``None`` marks
    an invalid rule (see ``CFGRuleRHS``) and is stored as ``None`` rather than
    being passed to ``tuple()``, which would raise ``TypeError``.
    """
    __slots__ = ('lhs', 'rhs')

    def __init__(self, lhs: str, rhs: 'CFGRuleRHS'):
        self.lhs = lhs
        # tuple(None) raises TypeError, so invalid rules keep rhs=None.
        self.rhs = None if rhs is None else tuple(rhs)

    def __hash__(self):
        return hash(self.lhs) ^ hash(self.rhs)

    def __eq__(self, other):
        # Let Python fall back to identity for unrelated types instead of
        # raising AttributeError on `other.lhs`.
        if not isinstance(other, CFGRule):
            return NotImplemented
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __str__(self):
        return f'{self.lhs} => {self.rhs}'

    def __repr__(self):
        return str(self)


class HRGRule:
    """A hyperedge-replacement production: nonterminal edge ``lhs`` => graph ``rhs``.

    ``comment`` is a free-form dict of annotations attached during extraction.
    It takes no part in equality or hashing.
    """
    __slots__ = ('lhs', 'rhs', 'comment')

    def __init__(self, lhs: 'HyperEdge', rhs: 'HyperGraph',
                 comment: Optional[Dict[str, str]] = None):
        self.lhs = lhs
        self.rhs = rhs
        self.comment = comment or {}

    def __hash__(self):
        return hash(self.lhs) ^ hash(self.rhs)

    def __eq__(self, other):
        # SHRGRule.hrg is frequently None.  Comparing against a non-HRGRule
        # must not raise AttributeError on `other.lhs`.
        if not isinstance(other, HRGRule):
            return NotImplemented
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __str__(self):
        return f'{self.lhs} => {self.rhs.edges} ### {self.comment}'

    def __repr__(self):
        return str(self)


class SHRGRule:
    """A synchronous rule pairing a CFG production with an HRG production.

    ``hrg`` is ``None`` for constituents without semantics and for rules that
    were rejected as inconsistent during extraction.
    """
    __slots__ = ('cfg', 'hrg')

    def __init__(self, cfg: 'CFGRule', hrg: 'HRGRule'):
        self.cfg = cfg
        self.hrg = hrg

    def __hash__(self):
        return hash(self.cfg) ^ hash(self.hrg)

    def __eq__(self, other):
        # Unrelated types compare unequal instead of raising AttributeError.
        if not isinstance(other, SHRGRule):
            return NotImplemented
        return self.cfg == other.cfg and self.hrg == other.hrg

    def __str__(self):
        return f'SHRGRule(\n  cfg:  {self.cfg}\n  hrg:  {self.hrg}\n)'

    def __repr__(self):
        return str(self)


@dataclass
class DerivationInfo:
    """Snapshot of one step of ``extract_shrg_rule_internal``.

    Snapshots are recorded only when ``return_derivation_infos`` is set.  There
    is one instance per tree node that has semantics (``None`` is stored for
    the others), followed by one final instance for the fully reduced graph.
    """
    # Working graph before this step's fragment is contracted (for the final
    # entry: the graph left after all steps).
    hyper_graph: HyperGraph
    # Most recent nonterminal edge created before this snapshot; None if no
    # earlier step had semantics.
    last_new_edge: Optional[HyperEdge]
    # Edges covered by this step (empty for the final entry).
    all_edges: Set[HyperEdge]
    # Nodes removed from the working graph by this step.
    internal_nodes: Set[GraphNode]
    # Nodes kept in the graph and attached to the step's new nonterminal edge.
    external_nodes: Set[GraphNode]
    # Rule-local (renamed) node -> original graph node (for the final entry:
    # the map of the last semantic step).
    reversed_rename_map: Dict[GraphNode, GraphNode]


def extract_hrg_rule(cfg_node, edges, internal_nodes, external_nodes,
                     ep_permutation_methods,
                     extra_infos=None):
    """Build the HRG rule for one constituent from the edges it covers.

    The nodes touched by the rule are renamed with ``compute_names_of_nodes``
    (new names are numeric strings).  The covered edges, rewritten over those
    names, form the right-hand side.  The left-hand side is a nonterminal edge
    labelled ``cfg_node.tag`` over the renamed external nodes.  Their order is
    chosen as follows:

    * with one external node, that node alone;
    * otherwise, the first method in ``ep_permutation_methods`` (looked up in
      ``EP_PERMUTATIONS``) that returns a non-None permutation;
    * if none does, ascending numeric name order, noted in the rule comment.

    Args:
        cfg_node: constituent-tree node the rule is extracted for.
        edges: hyperedges covered by the rule.
        internal_nodes: set of nodes removed together with the rule.
        external_nodes: set of nodes that stay attached to the new edge.
        ep_permutation_methods: names of permutation methods, tried in order.
        extra_infos: per-step data consumed only by the permutation methods
            (reaches them through ``locals()``); may be None.

    Returns:
        ``(node_rename_map, HRGRule)``, where ``node_rename_map`` maps original
        nodes to their renamed counterparts.
    """
    nodes = internal_nodes | external_nodes

    node_hashes_original, node_rename_map = compute_names_of_nodes(edges, nodes, external_nodes)

    # Right-hand side: the covered edges rewritten over the renamed nodes.
    rhs = HyperGraph(node_rename_map.values(),
                     (edge.new((node_rename_map[node] for node in edge.nodes)) for edge in edges))

    if len(external_nodes) == 1:
        # A single external node admits only one order.
        ep_permutation = tuple(node_rename_map[node] for node in external_nodes)
        comment = {}
    else:
        ep_permutation = None
        for method in ep_permutation_methods:
            # NOTE: **locals() forwards every local variable bound so far
            # (including node_hashes_original, unused here otherwise) as keyword
            # arguments; renaming locals above changes what the method receives.
            ep_permutation, comment = EP_PERMUTATIONS.invoke(method, **locals())
            if ep_permutation is not None:
                break

        if ep_permutation is None:
            # No method produced an order: fall back to numeric name order.
            comment = {'EP permutation': 'arbitrary order'}
            ep_permutation = sorted((node_rename_map[node] for node in external_nodes),
                                    key=lambda x: int(x.name))

    lhs = HyperEdge(ep_permutation, label=cfg_node.tag, is_terminal=False)

    return node_rename_map, HRGRule(lhs, rhs, comment)


def _copy_edge(edge, is_terminal):
    """Return a new HyperEdge like ``edge`` but with ``is_terminal`` replaced.

    Callers in this module pass ``'left'`` / ``'right'`` instead of a bool to tag
    which child a nonterminal edge belongs to.
    """
    nodes, label, span = edge.nodes, edge.label, edge.span
    return HyperEdge(nodes, label=label, is_terminal=is_terminal, span=span)


def _rewrite_nonterminal_edges(all_edges, cfg_node):
    nonterminal_edges = [edge for edge in all_edges if not edge.is_terminal]

    if len(nonterminal_edges) != 2 or \
       len(set(edge.to_tuple() for edge in nonterminal_edges)) == len(nonterminal_edges):
        return all_edges

    children = cfg_node.children
    assert len(children) == 2

    c1, c2 = children
    assert isinstance(c1, ConstTree) and isinstance(c2, ConstTree)
    assert c1.span != c2.span
    label_to_order = {(c1.tag, c1.span): 'left', (c2.tag, c2.span): 'right'}

    return {
        edge if edge.is_terminal else _copy_edge(edge, label_to_order[edge.label, edge.span])
        for edge in all_edges
    }


def extract_shrg_rule(hyper_graph: HyperGraph, const_tree: ConstTree,
                      detect_function=None,
                      fully_lexicalized=False,
                      return_derivation_infos=False,
                      remove_null_semantic=None,
                      graph_type='eds',
                      sentence_id='',
                      ep_permutation_methods=('stick+span',),
                      node_distributior_fn=NodeDistributor):
    """Extract SHRG rules from ``hyper_graph`` aligned with ``const_tree``.

    With exactly the ``'stick+span'`` permutation method, rules are extracted
    in a single pass.  For any other method list, a preliminary pass (with no
    permutation methods) first records which step consumed each terminal edge.
    From that, "combine infos" are computed: for every step and every external
    node, the post-order indices of the ancestors whose newly covered edges
    touch that node.  They are forwarded as ``extra_infos`` to the final pass.

    Args:
        hyper_graph: semantic graph of the sentence.
        const_tree: constituent tree of the sentence; annotated in place
            (post-order indices, spans, ``has_semantics``, and parents on the
            multi-pass path).
        detect_function: detector name/callable, normalized by ``DETECT_FUNCTIONS``.
        fully_lexicalized: forwarded to ``node_distributior_fn``.
        return_derivation_infos: also return per-step ``DerivationInfo``.
        remove_null_semantic: options for ``remove_null_semantic_rules`` or None.
        graph_type: forwarded to ``node_distributior_fn``.
        sentence_id: identifier used in log messages and edge alignment.
        ep_permutation_methods: names of external-node permutation methods,
            tried in order.
        node_distributior_fn: factory whose ``solve()`` maps tree nodes to edges.

    Returns:
        The result of ``extract_shrg_rule_internal`` for the final pass.
    """
    kwargs = {'hyper_graph': hyper_graph,
              'const_tree': const_tree,
              'sentence_id': sentence_id,
              'detect_function': detect_function,
              'graph_type': graph_type,
              'fully_lexicalized': fully_lexicalized,
              'node_distributior_fn': node_distributior_fn}

    ep_permutation_methods = list(ep_permutation_methods)
    if ep_permutation_methods == ['stick+span']:
        return extract_shrg_rule_internal(**kwargs,
                                          ep_permutation_methods=ep_permutation_methods,
                                          return_derivation_infos=return_derivation_infos,
                                          remove_null_semantic=remove_null_semantic,
                                          extra_infos=None)

    const_tree.add_parents()
    # Preliminary pass: only the edge->step blame map and the derivation
    # infos are needed.
    _, (_, edge_blame_dict), derivation_infos = \
        extract_shrg_rule_internal(**kwargs, ep_permutation_methods=[],
                                   return_derivation_infos=True,
                                   remove_null_semantic=None,
                                   extra_infos=None)

    cfg_nodes = list(const_tree.traverse_postorder())
    num_steps = len(cfg_nodes)

    # Terminal edges consumed at each step.
    edges_by_step = [[] for _ in range(num_steps)]
    for edge, step in edge_blame_dict.items():
        edges_by_step[step].append(edge)

    # all_matched_edges[i]: terminal edges covered by the subtree of the i-th
    # post-order node (children come first, so their entries already exist).
    all_matched_edges = []
    for step, cfg_node in enumerate(cfg_nodes):
        matched_edges = set(edges_by_step[step])
        if isinstance(cfg_node.children[0], ConstTree):
            for child in cfg_node.children:
                matched_edges.update(all_matched_edges[child.index])
        all_matched_edges.append(matched_edges)

    assert len(all_matched_edges[-1]) == len(hyper_graph.edges)

    combine_infos = [{} for _ in range(num_steps)]
    # Nodes that became internal at an earlier step.  They were removed from
    # the working graph then, so they must never reappear as external nodes.
    # This set accumulates across steps; resetting it per step made the
    # check below vacuous.
    frozen_nodes = set()
    for step, cfg_node in enumerate(cfg_nodes):
        derivation_info = derivation_infos[step]
        if derivation_info is None:
            continue

        combine_info = combine_infos[step]

        # For every ancestor: nodes touched by edges the ancestor's subtree
        # covers beyond the subtree we came from.
        parent_infos = []
        current_node = cfg_node
        while current_node.parent is not None:
            parent = current_node.parent
            matched_edges = all_matched_edges[parent.index] - all_matched_edges[current_node.index]
            parent_infos.append((parent.index,
                                 {node for edge in matched_edges for node in edge.nodes}))
            current_node = parent

        for node in derivation_info.external_nodes:
            assert node not in frozen_nodes

            for parent_step, related_nodes in parent_infos:
                if node in related_nodes:  # new edges are merged with current node
                    combine_info.setdefault(node, []).append(parent_step)

        frozen_nodes.update(derivation_info.internal_nodes)

    return extract_shrg_rule_internal(**kwargs,
                                      ep_permutation_methods=ep_permutation_methods,
                                      return_derivation_infos=return_derivation_infos,
                                      remove_null_semantic=remove_null_semantic,
                                      extra_infos=combine_infos)


def extract_shrg_rule_internal(hyper_graph: HyperGraph, const_tree: ConstTree,
                               detect_function,
                               graph_type,
                               sentence_id,
                               ep_permutation_methods,
                               extra_infos,
                               fully_lexicalized,
                               return_derivation_infos,
                               remove_null_semantic,
                               node_distributior_fn):
    """Extract rules from the given hyper_graph and constituent tree.

    The tree is visited in post-order.  At each step, ``detect_function``
    receives two sets of edges: those assigned to the constituent by
    ``node_distributior_fn``, and the nonterminal edges produced for its
    semantic children.

    If it finds a fragment, an HRG rule is built for it (``extract_hrg_rule``).
    The fragment is then contracted into one nonterminal edge labelled with the
    constituent's tag: its internal nodes and covered edges leave the working
    graph.  Otherwise a CFG-only rule (``hrg=None``) is recorded.  If
    ``remove_null_semantic`` is given, the rules are afterwards post-processed
    by ``remove_null_semantic_rules``.

    ``const_tree`` is annotated in place: post-order indices, spans and a
    ``has_semantics`` flag on every node.

    Returns:
        ``(shrg_rules, (node_blame_dict, edge_blame_dict), X)``, where:

        * ``shrg_rules`` has one SHRGRule per tree node, in post-order;
        * ``node_blame_dict`` maps each internal node to the step that
          removed it, and ``edge_blame_dict`` maps each terminal edge to the
          step that consumed it;
        * ``X`` is the list of DerivationInfo (``None`` for steps without
          semantics, plus a final entry) when ``return_derivation_infos`` is
          set, otherwise ``boundary_node_dict`` (step -> nodes of the
          nonterminal edge created at that step).
    """
    detect_function = DETECT_FUNCTIONS.normalize(detect_function)

    edge_blame_dict: Dict[HyperEdge, int] = {}  # terminal edge -> consuming step
    node_blame_dict: Dict[GraphNode, int] = {}  # internal node -> removing step
    boundary_node_dict: Dict[int, Tuple[GraphNode, ...]] = {}  # step -> new edge nodes

    last_new_edge: Optional[HyperEdge] = None
    # Bound before the loop so the final DerivationInfo can be built even when
    # no step has semantics.
    reversed_node_rename_map: Dict[GraphNode, GraphNode] = {}
    shrg_rules: List[Tuple[SHRGRule, Optional[HyperEdge]]] = []
    derivation_infos: List[Optional[DerivationInfo]] = []

    node_distribution = node_distributior_fn(hyper_graph, const_tree, graph_type,
                                             fully_lexicalized=fully_lexicalized,
                                             logger=LOGGER).solve()

    const_tree.add_postorder_index()  # give every tree node an index
    for step, cfg_node in enumerate(const_tree.traverse_postorder()):
        cfg_node.calculate_span()

        # Edges assigned to this constituent plus the nonterminal edges
        # already produced for its semantic children.
        collected_pred_edges = node_distribution[cfg_node]
        for child_node in cfg_node.children:
            if isinstance(child_node, ConstTree):
                _, child_new_edge = shrg_rules[child_node.index]
                if child_new_edge is not None:
                    collected_pred_edges.add(child_new_edge)

        result = detect_function(hyper_graph, cfg_node, collected_pred_edges)

        if result is None:
            # No semantic fragment: record a CFG-only rule, graph unchanged.
            cfg_node.has_semantics = False
            cfg_rhs = tuple((child if isinstance(child, Lexicon) else child.tag, None)
                            for child in cfg_node.children)

            shrg_rules.append((SHRGRule(CFGRule(cfg_node.tag, cfg_rhs), None), None))
            if return_derivation_infos:
                derivation_infos.append(None)
            continue

        cfg_node.has_semantics = True
        all_edges, internal_nodes, external_nodes, detect_comment = result

        # Build the HRG rule.  Two identical nonterminal child edges are first
        # tagged 'left'/'right' so that both survive in the rule's graph.
        node_rename_map, hrg_rule = \
            extract_hrg_rule(cfg_node,
                             _rewrite_nonterminal_edges(all_edges, cfg_node),
                             internal_nodes, external_nodes,
                             ep_permutation_methods=ep_permutation_methods,
                             extra_infos=extra_infos[step] if extra_infos else None)
        assert len(hrg_rule.rhs.edges) == len(all_edges)

        hrg_rule_nodes = hrg_rule.rhs.nodes
        assert set(int(n.name) for n in hrg_rule_nodes) == set(range(len(hrg_rule_nodes)))

        if detect_comment is not None:
            hrg_rule.comment['DetectInner'] = detect_comment

        if not external_nodes and cfg_node is not const_tree:
            # A non-root rule needs an external node to attach to its parent:
            # promote the internal node that was renamed to '0'.
            node = None
            for internal_node in internal_nodes:
                if node_rename_map[internal_node].name == '0':
                    node = internal_node
                    break
            assert node is not None

            internal_nodes.remove(node)
            external_nodes = {node}
            hrg_rule.lhs.nodes = (node_rename_map[node], )
            hrg_rule.comment['Detect'] = 'Use first node as external node'

        for edge in all_edges:
            if edge.is_terminal:
                assert edge not in edge_blame_dict
                edge_blame_dict[edge] = step

        for node in internal_nodes:
            assert node not in node_blame_dict
            node_blame_dict[node] = step

        # Renamed (rule-local) node -> original graph node.
        reversed_node_rename_map = {node: original_node
                                    for original_node, node in node_rename_map.items()}
        if return_derivation_infos:
            derivation_infos.append(DerivationInfo(hyper_graph=hyper_graph,
                                                   last_new_edge=last_new_edge,
                                                   all_edges=all_edges,
                                                   internal_nodes=internal_nodes,
                                                   external_nodes=external_nodes,
                                                   reversed_rename_map=reversed_node_rename_map))

        # Contract the fragment into one nonterminal edge over the original
        # external nodes, in the order chosen for the rule's LHS.
        new_edge = HyperEdge((reversed_node_rename_map[node] for node in hrg_rule.lhs.nodes),
                             label=cfg_node.tag,
                             is_terminal=False,
                             span=cfg_node.span)
        boundary_node_dict[step] = tuple(new_edge.nodes)

        new_hyper_graph = HyperGraph(hyper_graph.nodes - internal_nodes,
                                     (hyper_graph.edges - all_edges) | {new_edge})

        hyper_graph = new_hyper_graph
        last_new_edge = new_edge

        if isinstance(cfg_node.children[0], Lexicon):
            assert len(cfg_node.children) == 1, 'Stange condition'
            cfg_rhs = find_aligned_edge(sentence_id, cfg_node.children[0], hrg_rule.rhs)
        else:
            assert len(cfg_node.children) == 2
            assert all(isinstance(child_node, ConstTree) for child_node in cfg_node.children)

            # Pair each child tag with its nonterminal edge in the rule graph.
            cfg_rhs = []
            for index, child_node in enumerate(cfg_node.children):
                if not child_node.has_semantics:
                    cfg_rhs.append((child_node.tag, None))
                    continue

                _, target_edge = shrg_rules[child_node.index]
                if target_edge.label != child_node.tag and target_edge.carg != child_node.tag:
                    LOGGER.warning('Non-consistent CFG and HRG: %s', sentence_id)
                    cfg_rhs = None
                    break

                target_edge = HyperEdge([node_rename_map[node] for node in target_edge.nodes],
                                        label=target_edge.label,
                                        is_terminal=target_edge.is_terminal)
                if target_edge not in hrg_rule.rhs.edges:
                    # Presumably tagged 'left'/'right' by _rewrite_nonterminal_edges.
                    target_edge = _copy_edge(target_edge, ('left' if index == 0 else 'right'))
                    LOGGER.debug("%s %s %s", sentence_id, target_edge, hrg_rule.rhs.edges)
                    assert target_edge in hrg_rule.rhs.edges

                cfg_rhs.append((child_node.tag, target_edge))

        cfg_lhs = cfg_node.tag
        if cfg_lhs.startswith('ROOT'):  # merge all ROOT labels
            assert hrg_rule.lhs.label == cfg_lhs
            cfg_lhs = hrg_rule.lhs.label = 'ROOT'

        if cfg_rhs is not None:
            rule = SHRGRule(CFGRule(cfg_lhs, tuple(cfg_rhs)), hrg_rule)
        else:
            # Inconsistent CFG/HRG: record the rule with rhs None and no HRG side.
            rule = SHRGRule(CFGRule(cfg_lhs, None), None)

        shrg_rules.append((rule, new_edge))

    shrg_rules = [rule for rule, _ in shrg_rules]

    if return_derivation_infos:
        derivation_infos.append(DerivationInfo(hyper_graph=hyper_graph,
                                               last_new_edge=last_new_edge,
                                               all_edges=set(),
                                               internal_nodes=set(),
                                               external_nodes=set(),
                                               reversed_rename_map=reversed_node_rename_map))

    if remove_null_semantic is not None:
        remove_null_semantic_rules(const_tree, shrg_rules, remove_null_semantic)

    if return_derivation_infos:
        return shrg_rules, (node_blame_dict, edge_blame_dict), derivation_infos
    return shrg_rules, (node_blame_dict, edge_blame_dict), boundary_node_dict


def remove_null_semantic_rules(const_tree: ConstTree, shrg_rules: List[SHRGRule], option):
    """Fold null-semantic children into their parent rules, in place.

    ``shrg_rules`` must be aligned with ``const_tree.traverse_postorder()``.
    The parent rule is rewritten for every binary constituent where exactly
    one child has no semantics (its rule has ``hrg is None``):

    * CFG side: the null-semantic subtree's lexicons are placed next to the
      semantic child's right-hand side.  They are dropped with ``'delete'``,
      and ``'ignore_punct'`` is forwarded to ``generate_lexicons``.
    * HRG side: if the parent rule has edges besides the semantic child's
      nonterminal edge, the child's graph is substituted for that edge.
      Otherwise the child's graph is reused directly.
    * Label: unchanged, unless ``'merge-single'`` or ``'merge-both'`` is given
      (mutually exclusive).  These fold the child label(s) into the new label.

    Rules whose CFG side is None (rejected during extraction) and their
    parents are left untouched.

    Args:
        const_tree: the constituent tree the rules were extracted from.
        shrg_rules: one SHRGRule per tree node in post-order; modified in place.
        option: container of option names (see REMOVE_NULL_SEMANTIC_OPTIONS).
    """
    use_delete = 'delete' in option
    use_merge_both = 'merge-both' in option
    use_merge_single = 'merge-single' in option
    ignore_punct = 'ignore_punct' in option
    assert not use_merge_both or not use_merge_single

    def _get_label(label, left_label, right_label, merge_left=True):
        """Return the new LHS label after folding in the semantic child's label."""
        if not use_merge_both and not use_merge_single:
            return label
        child_label = left_label if merge_left else right_label
        real_label = label.split(LABEL_SEP)[-1]  # label may be condensed
        real_child_label, *rest_labels = child_label.split('@', 1)
        if real_label != real_child_label:
            if use_merge_single:
                label = f'{label}!{child_label}'
            else:
                label = f'{label}!{left_label}!{right_label}'
        elif rest_labels:
            label = f'{label}@{rest_labels[0]}'
        return label

    cfg_nodes = list(const_tree.traverse_postorder())
    cfg_node2step = {cfg_node: step for step, cfg_node in enumerate(cfg_nodes)}
    for cfg_node, rule in zip(cfg_nodes, shrg_rules):
        if len(cfg_node.children) == 1:
            if isinstance(cfg_node.children[0], ConstTree):
                assert rule.hrg is not None, 'Strange condition'
            continue
        if rule.cfg.rhs is None:
            # Rule rejected during extraction (inconsistent CFG/HRG).
            LOGGER.debug('null-semantic: skip invalid rule of [%s]', cfg_node)
            continue
        assert len(rule.cfg.rhs) == 2, 'Strange CFGRule '
        left, right = cfg_node.children
        left_rule = shrg_rules[cfg_node2step[left]]
        right_rule = shrg_rules[cfg_node2step[right]]
        if left_rule.cfg.rhs is None or right_rule.cfg.rhs is None:
            # A rejected child also has hrg None but is not semantically empty;
            # folding it would treat a semantic subtree as plain lexicons.
            LOGGER.debug('null-semantic: skip [%s] with an invalid child rule', cfg_node)
            continue
        if left_rule.hrg is None and right_rule.hrg is None:
            assert rule.hrg is None
            LOGGER.debug('null-semantic: both children of [%s]', cfg_node)
            continue
        if left_rule.hrg is not None and right_rule.hrg is not None:
            continue

        cfg = rule.cfg
        left_label = left_rule.cfg.lhs  # left child label
        right_label = right_rule.cfg.lhs  # right child label
        if left_rule.hrg is None:
            LOGGER.debug('null-semantic: left child of [%s]', cfg_node)
            label = _get_label(cfg.lhs, left_label, right_label, merge_left=False)

            # the semantic (right) child's nonterminal edge, to be replaced by
            # the child's rule
            target_edge = cfg.rhs[1][1]
            assert right_rule.hrg.lhs.label.startswith(target_edge.label), 'Strange condition'

            # the left child has no semantics, we prepend all the lexicons in
            # left subtree to cfg items directly
            cfg_rhs = [] if use_delete else [(lexicon, None)
                                             for lexicon in left.generate_lexicons(ignore_punct)]
            cfg_rhs.extend(right_rule.cfg.rhs)

            child_external_nodes = right_rule.hrg.lhs.nodes
            child_nodes = right_rule.hrg.rhs.nodes
            child_edges = right_rule.hrg.rhs.edges
        else:  # right_rule.hrg is None
            LOGGER.debug('null-semantic: right child of [%s]', cfg_node)
            label = _get_label(cfg.lhs, left_label, right_label, merge_left=True)

            # the semantic (left) child's nonterminal edge
            target_edge = cfg.rhs[0][1]
            assert left_rule.hrg.lhs.label.startswith(target_edge.label), 'Strange condition'

            # the right child has no semantics, we append all the lexicons in
            # right subtree to cfg items directly
            cfg_rhs = list(left_rule.cfg.rhs)
            if not use_delete:
                cfg_rhs.extend((lexicon, None) for lexicon in right.generate_lexicons(ignore_punct))

            child_external_nodes = left_rule.hrg.lhs.nodes
            child_nodes = left_rule.hrg.rhs.nodes
            child_edges = left_rule.hrg.rhs.edges

        assert len(target_edge.nodes) == len(child_external_nodes), 'Mismatch !!!'

        current_nodes = set(rule.hrg.rhs.nodes)
        current_edges = set(rule.hrg.rhs.edges)
        # first, remove the semantic child's nonterminal edge; the child's
        # graph takes its place below
        assert target_edge in current_edges, 'Target edge not in HRG graph !!!'
        current_edges.remove(target_edge)

        if current_edges:  # the rule introduces new edges
            # node in child rule => node in current rule
            node_mapping = dict(zip(child_external_nodes, target_edge.nodes))
            extra_nodes = sorted([child_node
                                  for child_node in child_nodes
                                  if child_node not in child_external_nodes],
                                 key=lambda x: int(x.name))
            # TODO: the order of rules
            # assign new name to node in child rule
            for index, child_node in enumerate(extra_nodes, len(current_nodes)):
                new_node = GraphNode(str(index))
                node_mapping[child_node] = new_node
                assert new_node not in current_nodes, 'Index of node in HRG graph is broken !!!'
                current_nodes.add(new_node)  # second, add nodes from child rule to current rule

            edge_mapping = {}
            for edge in child_edges:
                new_edge = edge.new((node_mapping[node] for node in edge.nodes), span=edge.span)
                edge_mapping[edge] = new_edge
                current_edges.add(new_edge)  # third, add edges from child rule to current rule

            # re-point CFG items at the renamed child edges (items without an
            # edge stay None)
            cfg_rhs = [(item, edge_mapping.get(edge)) for item, edge in cfg_rhs]

            hrg_lhs_nodes = rule.hrg.lhs.nodes
        else:  # the rule introduces no edges
            current_nodes = child_nodes.copy()
            current_edges = child_edges.copy()
            hrg_lhs_nodes = rule.hrg.lhs.nodes
            if len(hrg_lhs_nodes) != len(child_external_nodes):
                assert rule.hrg.lhs.label == 'ROOT'
                hrg_lhs_nodes = ()
            else:
                hrg_lhs_nodes = tuple(child_external_nodes)

        LOGGER.debug('rule-change: %s ### %s %s', rule.cfg, label, cfg_rhs)
        cfg.lhs, cfg.rhs = label, tuple(cfg_rhs)
        rule.hrg = HRGRule(lhs=HyperEdge(nodes=hrg_lhs_nodes, label=label, is_terminal=False),
                           rhs=HyperGraph(nodes=current_nodes, edges=current_edges),
                           comment={'From': 'replace by its child'})
