# -*- coding: utf-8 -*-

import math

import pyshrg


def recover_subgraph_derivation(derivation, graph: pyshrg.EdsGraph):
    """Split the edges covered by a derivation into per-step partitions.

    Every entry of ``derivation`` is unpacked as
    ``(rule_index, matched_edges, external_nodes, *child_indices)`` and the
    last entry is taken as the root of the tree. ``derivation`` is rewritten
    IN PLACE: each reachable entry becomes ``(rule_index, left, right)``
    (``-1`` stands for a missing child) and entries that were absorbed into
    their parent become ``None``.

    Args:
        derivation: mutable sequence of derivation steps, indexed by step.
        graph: object providing ``edges_map`` (edge -> value lookup) and
            ``nodes`` (items carrying ``id`` and ``index`` attributes).

    Returns:
        ``(partitions, derivation)``. ``partitions[step]`` is ``None`` for an
        absorbed step, otherwise a 4-tuple of: the sorted mapped edges owned
        by the step itself, the sorted mapped edges of the left child (or
        ``None``), the same for the right child, and the tuple of external
        node indexes (or ``None`` when the step has no external nodes).
    """
    covered = [set(entry[1]) for entry in derivation]
    externals = [None] * len(derivation)

    def _visit(node):
        rule, _own_edges, ext_nodes, *children = derivation[node]
        externals[node] = ext_nodes

        # Post-order: a node ends up covering everything below it.
        for child in children:
            _visit(child)
            covered[node] |= covered[child]

        if len(children) == 2:
            first, second = children
            first_has_edges = bool(covered[first])
            second_has_edges = bool(covered[second])
            # Exactly one child carries edges (the other presumably has no
            # semantic part): adopt the children of the edge-carrying child
            # and drop both original children from the derivation.
            if first_has_edges != second_has_edges:
                survivor = first if first_has_edges else second
                children = derivation[survivor][1:]
                derivation[second] = derivation[first] = None

        # Pad with -1 so leaf and unary steps still expose two child slots.
        padding = (-1,) * max(0, 2 - len(children))
        derivation[node] = (rule, *children, *padding)

    _visit(-1)

    edge_lookup = graph.edges_map
    node_lookup = {node.id: node.index for node in graph.nodes}

    def _mapped_or_none(edge_set):
        # Child slots are reported as None when absent or empty.
        if not edge_set:
            return None
        return sorted({edge_lookup[edge] for edge in edge_set})

    partitions = [None] * len(derivation)
    # Walk from the last step down to the first; each step strips its
    # children's edges from its own set in place, and the children's sets
    # are still complete when that happens.
    for step in reversed(range(len(derivation))):
        entry = derivation[step]
        if entry is None:
            continue

        own_edges = covered[step]
        ext_nodes = externals[step]
        _, left, right = entry
        left_edges = covered[left] if left != -1 else None
        right_edges = covered[right] if right != -1 else None
        for child_edges in (left_edges, right_edges):
            if child_edges:
                own_edges -= child_edges

        own_part = sorted({edge_lookup[edge] for edge in own_edges})
        left_part = _mapped_or_none(left_edges)
        right_part = _mapped_or_none(right_edges)
        if ext_nodes is None:
            ext_part = None
        else:
            ext_part = tuple(node_lookup[node] for node in ext_nodes)
        partitions[step] = (own_part, left_part, right_part, ext_part)

    return partitions, derivation


def find_best_derivation(context, root):
    """Select the best-scoring alternative of ``root`` and return its score.

    The scores of all alternatives returned by ``root.all()`` are treated as
    log-domain values and normalised by their log-sum-exp. The normalised
    score of each alternative is then increased by the recursively computed
    best scores of the two parts returned by
    ``context.split_subgraph(alternative)``. The first alternative with the
    highest total wins (ties keep the earlier one).

    Side effects on ``root``: if the winner is not ``root`` itself,
    ``root.swap(winner)`` is called; ``root.cfg_index`` is set to ``-2000``
    to mark the item as resolved and ``root.score`` is overwritten with the
    winning total, so later calls on the same item return immediately.

    Args:
        context: object providing ``split_subgraph(subgraph)`` which returns
            a ``(left, right)`` pair; falsy members are skipped.
        root: item exposing ``score``, ``cfg_index``, ``all()`` and
            ``swap(other)``.

    Returns:
        The best total score, also stored in ``root.score``.

    Raises:
        ValueError: if ``root.all()`` yields no alternatives.
    """
    resolved_marker = -2000  # cfg_index value used to flag finished items
    if root.cfg_index == resolved_marker:
        return root.score

    scores = [subgraph.score for subgraph in root.all()]

    # Log-sum-exp with the maximum factored out. Exponentiating the raw
    # scores underflows to 0 for very negative log-scores (making log()
    # fail) and overflows for large positive ones; shifting by the maximum
    # keeps every exponent <= 0 and the sum >= 1.
    peak = max(scores)
    log_sum_scores = peak + math.log(
        sum(math.exp(score - peak) for score in scores))

    max_score = float('-inf')
    max_subgraph = root
    for subgraph, current_score in zip(root.all(), scores):
        current_score = current_score - log_sum_scores
        left, right = context.split_subgraph(subgraph)
        if left:
            current_score += find_best_derivation(context, left)
        if right:
            current_score += find_best_derivation(context, right)

        # Strict comparison: on ties the earlier alternative is kept.
        if max_score < current_score:
            max_score = current_score
            max_subgraph = subgraph

    if root is not max_subgraph:
        root.swap(max_subgraph)

    root.cfg_index = resolved_marker
    root.score = max_score

    return root.score
