# -*- coding: utf-8 -*-

import hashlib
import itertools
from collections import defaultdict

from framework.common.utils import MethodFactory

from .hyper_graph import GraphNode

# Registry of strategies that compute an "EP permutation" (an ordering of the
# external nodes). The functions below add themselves via
# ``EP_PERMUTATIONS.register(...)``; how a strategy is later selected depends on
# MethodFactory, which is defined elsewhere.
EP_PERMUTATIONS = MethodFactory()


def _compute_edge_hashes(node_hashes, edge, index):
    """Return the md5 digest of ``edge`` as seen from one of its member nodes.

    hash(edge) = md5(edge.label + '#' + str(index)
                     + hash(node_1) + '#' + hash(node_2) + '#' + ...)

    ``index`` is the position, inside ``edge.nodes``, of the node on whose
    behalf the edge is being hashed. The same edge therefore generally yields
    a different digest for each of its member nodes.

    Args:
        node_hashes: mapping from every node in ``edge.nodes`` to its current
            digest (bytes).
        edge: object with a string ``label`` and an ordered ``nodes`` sequence.
        index: position of the viewing node within ``edge.nodes``.

    Returns:
        The 16-byte md5 digest.
    """
    md5_obj = hashlib.md5((edge.label + '#' + str(index)).encode())
    # Node hashes are fed in edge order, so the digest is order-sensitive.
    for adj_node in edge.nodes:
        md5_obj.update(node_hashes[adj_node] + b'#')
    return md5_obj.digest()


def _compute_sibling_hashes(node_rename_map, node_hashes, edges_by_node, node):
    """Digest the multiset of hashes of the edges incident to ``node``.

    Each (edge, position) pair recorded for ``node`` in ``edges_by_node`` is
    hashed with ``_compute_edge_hashes``. The digests are sorted so that the
    result does not depend on the order in which incident edges were recorded.
    If ``node`` already has a non-None entry in ``node_rename_map``, the string
    ``'#' + name`` of that entry is appended to the digest input as well.
    """
    incident = [_compute_edge_hashes(node_hashes, edge, position)
                for edge, position in edges_by_node[node]]
    incident.sort()
    digest = hashlib.md5(b''.join(incident))

    assigned = node_rename_map.get(node)
    if assigned is not None:
        digest.update(('#' + assigned.name).encode('utf-8'))
    return digest.digest()


def _recompute_hashes(nodes, node_rename_map, node_hashes, edges_by_node, external_nodes):
    """Run a single refinement round and return a new ``{node: digest}`` dict.

    For every node, the new digest is md5 over the node's sibling digest (from
    ``_compute_sibling_hashes``) followed by one marker byte. The marker is
    0x01 when the node is in ``external_nodes`` and 0x00 otherwise, so external
    and internal nodes with the same neighbourhood still get different digests.
    """
    def _round_hash(current):
        sibling_digest = _compute_sibling_hashes(node_rename_map, node_hashes,
                                                 edges_by_node, current)
        marker = b'\x01' if current in external_nodes else b'\x00'
        return hashlib.md5(sibling_digest + marker).digest()

    return {current: _round_hash(current) for current in nodes}


def compute_names_of_nodes(edges, nodes, external_nodes):
    """Give every node in ``nodes`` a ``GraphNode`` name via hash refinement.

    1. Starting from a shared seed digest, node hashes are refined for
       ``len(nodes) + 10`` rounds (each round via ``_recompute_hashes``).
    2. While two nodes still share a hash, the first node of the first tied
       pair (in ascending hash order) gets the next name (``"0"``, ``"1"``,
       ...). Refinement is then run again, starting from the current hashes.
       ``node_rename_map`` is passed to every round, so names assigned so far
       take part in it.
    3. Once all hashes differ, the remaining unnamed nodes are named in
       ascending hash order.

    Returns:
        ``(node_hashes_original, node_rename_map)``. The first is a copy of the
        hashes after step 1, before any name was assigned. The second maps
        each node to its ``GraphNode``.
    """
    # node -> list of (edge, position of that node inside edge.nodes)
    edges_by_node = defaultdict(list)
    for edge in edges:
        for position, member in enumerate(edge.nodes):
            edges_by_node[member].append((edge, position))

    node_rename_map = {}
    rounds = len(nodes) + 10

    def _refine(hashes):
        # A fixed number of refinement rounds, starting from ``hashes``.
        for _ in range(rounds):
            hashes = _recompute_hashes(nodes, node_rename_map, hashes,
                                       edges_by_node, external_nodes)
        return hashes

    seed = hashlib.md5(b'13').digest()
    node_hashes = _refine({node: seed for node in nodes})
    node_hashes_original = dict(node_hashes)

    assert len(nodes) == len(node_hashes)
    total = len(node_hashes)

    while len(node_rename_map) < total:
        # Stable sort on the digest only; ties keep the iteration order of ``nodes``.
        ordered = sorted(node_hashes.items(), key=lambda pair: pair[1])
        tie_at = next((pos for pos, (here, after) in enumerate(zip(ordered, ordered[1:]))
                       if here[1] == after[1]), None)

        if tie_at is None:
            # Every hash is distinct: name the rest in ascending hash order.
            for node, _ in ordered:
                if node not in node_rename_map:
                    node_rename_map[node] = GraphNode(str(len(node_rename_map)))
            break

        # Break the first remaining symmetry by naming one node, then refine again.
        tied_node = ordered[tie_at][0]
        assert tied_node not in node_rename_map
        node_rename_map[tied_node] = GraphNode(str(len(node_rename_map)))
        node_hashes = _refine(node_hashes)

    return node_hashes_original, node_rename_map


@EP_PERMUTATIONS.register('stick+span')
def compute_ep_permutation_by_hash(cfg_node, edges, external_nodes,
                                   node_hashes_original, node_rename_map, **_kwargs):
    """Derive the EP permutation (an order of the external nodes) from the graph.

    Two strategies are tried in turn:

    * Stick: if exactly one distinct ordering of the external nodes occurs as
      the complete node tuple of some edge in ``edges``, that ordering is used.
    * Span: otherwise, if there are exactly two external nodes and ``cfg_node``
      has exactly two children, look for unary edges whose ``span`` equals the
      left / right child's span. The left node is placed first when all of the
      following hold:
        - both such edges are found;
        - their nodes are exactly the external nodes;
        - the two nodes' original hashes differ.

    Args:
        cfg_node: object whose ``children`` (each with a ``span``) are used by
            the span strategy.
        edges: edges exposing a ``nodes`` sequence and, for unary edges, a
            ``span``.
        external_nodes: collection of external nodes (compared as a set).
        node_hashes_original: node -> hash, used to reject symmetric pairs.
        node_rename_map: node -> renamed node; the result is built from it.

    Returns:
        ``(ep_permutation, comment)``. ``ep_permutation`` is a list of
        ``node_rename_map`` values and ``comment`` is a dict naming the
        strategy used. Returns ``(None, None)`` if neither strategy applies.
    """
    left_and_right_span = None
    if len(cfg_node.children) == 2:
        left_and_right_span = (cfg_node.children[0].span, cfg_node.children[1].span)

    external_set = set(external_nodes)
    # Distinct orderings of the external nodes that appear as an edge's node
    # tuple. An edge qualifies iff it has as many nodes as there are external
    # nodes and covers exactly that set (so it has no repeats). This is the same
    # as testing every permutation of the external nodes against each edge, but
    # it is linear in len(edges) instead of factorial in len(external_nodes).
    stick_orders = {
        tuple(edge.nodes) for edge in edges
        if len(edge.nodes) == len(external_set) and set(edge.nodes) == external_set
    }

    if len(stick_orders) == 1:
        (order,) = stick_orders
        return ([node_rename_map[node] for node in order],
                {'EP permutation': 'Stick hyperedge to one edge'})

    if len(external_set) == 2 and left_and_right_span is not None:
        left_span, right_span = left_and_right_span
        left_nodes = [edge.nodes[0] for edge in edges
                      if len(edge.nodes) == 1 and edge.span == left_span]
        right_nodes = [edge.nodes[0] for edge in edges
                       if len(edge.nodes) == 1 and edge.span == right_span]
        if left_nodes and right_nodes:
            left_node, right_node = left_nodes[0], right_nodes[0]
            # Only used when the original hashes differ (presumably because
            # equal hashes make the two nodes structurally indistinguishable).
            if node_hashes_original[left_node] != node_hashes_original[right_node] \
               and {left_node, right_node} == external_set:
                return ([node_rename_map[left_node], node_rename_map[right_node]],
                        {'EP permutation':
                         'judge #EP2 edge direction by spans of left and right node'})

    return None, None


@EP_PERMUTATIONS.register('combine', suffix_args='key_method')
def compute_ep_permutation_by_keys(external_nodes, node_rename_map, key_method, extra_infos,
                                   **_kwargs):
    """Order the external nodes by a sort key taken from ``extra_infos``.

    ``key_method`` is split on ``'+'``. The key for a node comes from the first
    of the following names present in it, checked in this order:

    * ``'start'``    -> ``extra_infos[node][0]``
    * ``'all'``      -> ``tuple(extra_infos[node])``
    * ``'complete'`` -> ``tuple(extra_infos[node][::-1])``

    Nodes are sorted by key. Ties fall back to the node's position in
    ``list(external_nodes)``.

    Returns:
        ``(ep_permutation, comment)``. ``ep_permutation`` lists the renamed
        nodes in sorted order. ``comment['EP permutation']`` holds
        ``key_method`` plus the per-node keys, prefixed with ``'partial '``
        when two nodes share a key.

    Raises:
        AssertionError: (via ``assert``) if ``key_method`` names none of the
            supported keys.
    """
    key_names = key_method.split('+')

    def _get_key(node):
        key = []
        # NOTE(review): these branches are mutually exclusive (elif), so e.g.
        # 'start+all' uses only 'start'. Confirm this priority is intended.
        if 'start' in key_names:
            key.append(extra_infos[node][0])
        elif 'all' in key_names:
            key.append(tuple(extra_infos[node]))
        elif 'complete' in key_names:
            key.append(tuple(extra_infos[node][::-1]))

        assert key, f'empty key for {node}'
        return tuple(key)

    external_nodes = list(external_nodes)
    # Compute each node's key once and reuse it for sorting and duplicate detection.
    keys = [_get_key(node) for node in external_nodes]
    indexed_keys = sorted((key, index) for index, key in enumerate(keys))
    ep_permutation = [node_rename_map[external_nodes[index]] for _, index in indexed_keys]

    comment_string = ' *** keys: ' + \
        str({external_nodes[index].name: key for key, index in indexed_keys})

    # 'partial' marks that some nodes share a key, so their order fell back to position.
    comment = {'EP permutation':
               ('' if len(set(keys)) == len(keys) else 'partial ') + key_method + comment_string}

    return ep_permutation, comment
