import torch
import numpy as np
# Distance that compute_dist assigns to tokens whose head chain reaches the
# root without passing through any path node. Presumably chosen to exceed any
# realistic `prune` threshold -- TODO confirm against callers.
INF = 1e5

def find_lca(common_ancestors, children):
    """Return the lowest common ancestor contained in ``common_ancestors``.

    The LCA is the first member of ``common_ancestors`` (in its own iteration
    order) none of whose children is itself a common ancestor.

    Args:
        common_ancestors: Container of token indices shared by every entity's
            ancestor chain. It is iterated and also used for membership tests.
        children: Indexable mapping ``node -> iterable of child indices``.

    Returns:
        The index of the LCA.

    Raises:
        AssertionError: If no member qualifies (e.g. an empty container).
    """
    def _is_lowest(node):
        # A common ancestor that has a common-ancestor child is not the lowest.
        return not any(child in common_ancestors for child in children[node])

    lca = next((node for node in common_ancestors if _is_lowest(node)), -1)
    assert lca >= 0, "lca computes in a wrong way."
    return lca

def compute_dist(path_nodes, head, sent_len):
    """Return, per token, the number of head-chain hops up to a path node.

    Tokens in ``path_nodes`` get 0. Every other token walks up its head chain
    until it meets a path node; each token visited on that walk receives its
    hop count to that node. If the walk instead reaches the root (a token that
    is its own head) without meeting a path node, every visited token gets the
    module-level ``INF``.

    Args:
        path_nodes: Container of token indices supporting ``in``.
        head: Indexable where ``head[i]`` is the parent of token ``i`` and the
            root satisfies ``head[root] == root``.
        sent_len: Number of tokens to compute distances for.

    Returns:
        A list of length ``sent_len`` with the distances described above.
    """
    dist = [-1] * sent_len
    for tok in range(sent_len):
        if dist[tok] >= 0:  # already filled in by an earlier upward walk
            continue
        if tok in path_nodes:
            dist[tok] = 0
            continue

        # Climb toward the root, remembering every non-path token visited.
        visited = []
        cur = tok
        while cur not in path_nodes:
            visited.append(cur)
            parent = head[cur]
            if parent == cur:  # cur is the root
                break
            cur = parent

        if cur in path_nodes:
            # visited[0] is farthest from the path node; visited[-1] is adjacent.
            n_hops = len(visited)
            for offset, node in enumerate(visited):
                dist[node] = n_hops - offset
        else:
            # Reached the root without touching the path.
            for node in visited:
                dist[node] = INF

    return dist

def tree_to_adj(root, nodes, adj, children, sent_len, self_loop=True):
    """Write the directed parent->child edges of the kept subtree into ``adj``.

    Performs a breadth-first traversal from ``root`` along ``children`` and
    follows only children contained in ``nodes``. A child outside ``nodes`` is
    dropped together with its whole subtree, because it is never enqueued.

    Args:
        root: Index of the node where the traversal starts.
        nodes: Collection of (hashable) token indices to keep.
        adj: 2-D array-like indexed as ``adj[i][j]``; modified in place.
        children: Indexable mapping ``node -> iterable of child indices``.
        sent_len: Number of leading diagonal entries set when ``self_loop``.
        self_loop: If true, set ``adj[i][i] = 1`` for every ``i < sent_len``,
            including tokens that were not kept.

    Returns:
        The same ``adj`` object.
    """
    # Membership is tested once per traversed edge. A set makes each test O(1)
    # instead of scanning the ``nodes`` list every time.
    keep = set(nodes)

    queue = [root]
    # Appending to ``queue`` while iterating it is intentional: the for-loop
    # picks up newly appended nodes, yielding a FIFO (breadth-first) order.
    for parent in queue:
        for child in children[parent]:
            if child in keep:
                adj[parent][child] = 1
                queue.append(child)

    if self_loop:
        for i in range(sent_len):
            adj[i][i] = 1

    return adj

def pruning_with_entities(entities, adj_fw, head, children, sent_len, max_len, prune=1):
    """Build an adjacency matrix of the dependency tree pruned around entities.

    The path nodes are the tokens on the head chains from each entity up to
    the entities' lowest common ancestor (LCA), inclusive. A token is kept
    when its head chain reaches a path node within ``prune`` steps (path
    nodes themselves are at distance 0). Edges are then written parent->child
    starting from the LCA, and self-loops are added for all ``sent_len``
    tokens.

    Args:
        entities: Iterable of entity token indices (at least one).
        adj_fw: Not used by this function; kept for interface compatibility.
        head: Indexable where ``head[i]`` is the parent of token ``i`` and the
            root satisfies ``head[root] == root``.
        children: Indexable mapping ``node -> iterable of child indices``.
        sent_len: Number of real tokens.
        max_len: Side length of the returned square matrix.
        prune: Maximum allowed distance from the path for a token to be kept.

    Returns:
        A ``(max_len, max_len)`` ``np.int8`` adjacency matrix.

    Raises:
        AssertionError: If ``entities`` is empty or the entities share no
            common ancestor.
    """
    common_ancestors = None
    entity_ancestors = set()
    for i in entities:
        # Chain from the entity token up to and including the root.
        ancestors = []
        h = i
        while head[h] != h:
            ancestors.append(h)
            h = head[h]
        ancestors.append(h)  # the root

        entity_ancestors.update(ancestors)
        if common_ancestors is None:
            common_ancestors = set(ancestors)
        else:
            common_ancestors.intersection_update(ancestors)
    # Without this check, empty ``entities`` would fail on len(None).
    assert common_ancestors is not None, "At least one entity index is required."
    assert len(common_ancestors) > 0, "Entities must have common ancestors."

    # Shortest dependency path: every entity-side ancestor below the LCA,
    # plus the LCA itself.
    lca = find_lca(common_ancestors, children)
    path_nodes = entity_ancestors.difference(common_ancestors)
    path_nodes.add(lca)

    # Prune tokens that are too far from the path.
    dist = compute_dist(path_nodes, head, sent_len)
    nodes = [i for i in range(sent_len) if dist[i] <= prune]

    adj = np.zeros((max_len, max_len), dtype=np.int8)
    # The traversal starts at the LCA, so nothing above it gets an edge.
    return tree_to_adj(lca, nodes, adj, children, sent_len)