"""
Get height and depth from edges of DAG(s).

python getHeightsAndDepths.py VOCAB_FILE EDGE_FILE

VOCAB_FILE has a node per line as below.

```vocab.txt
0
1
2
3
4
```

Zero-based line numbers are used as node ids,
so arbitrary node names are fine.

EDGE_FILE has a directed edge per line as below.

```edge.txt
0 1
1 2
3 4
```

This defines three directed edges: 0 -> 1, 1 -> 2, and 3 -> 4.
The numbers correspond to the node ids defined in VOCAB_FILE.

The script traverses the DAG(s) and writes a tsv file named EDGE_FILE + ".height",
with one row per node and the following columns.


- Node id
- Maximum height
- Number of descendant nodes
- Maximum depth
- Number of ancestor nodes


"""

import sys
import tqdm
import numpy as np

import multiprocessing as mp


# Validate the command line with an explicit exit instead of `assert`, which
# is stripped under `python -O`. The usage text lists the arguments in the
# order they are actually read below (vocabulary first, then edges).
if len(sys.argv) != 3:
    sys.exit('Usage: python this.py VOCAB_FILE EDGE_FILE')

VOCAB = sys.argv[1]
EDGE = sys.argv[2]

# Switch between the multiprocessing driver and the sequential one.
USE_MP = True

# Every non-blank line of EDGE is "<source id> <target id>". Blank lines
# (e.g. a trailing newline at the end of the file) are skipped rather than
# crashing on an empty split.
with open(EDGE) as f:
    edges = [list(map(int, l.split())) for l in f if l.strip()]


# One node per line; the zero-based line number is the node id.
with open(VOCAB) as f:
    vocab = [l.strip() for l in f]


n = len(vocab)


# Dense adjacency matrix: graph[src, dst] is True when an edge src -> dst
# exists. The builtin `bool` is used as dtype because the `np.bool` alias was
# removed from NumPy.
graph = np.zeros((n, n), dtype=bool)


for e in edges:
    graph[e[0], e[1]] = True


def traverse(index, height_or_depth, visited_nodes=None, adjacency=None):
    """Return the longest path length from a node and how many nodes it reaches.

    Args:
        index: Id of the node to start from.
        height_or_depth: ``'d'`` follows edges backwards (towards ancestors);
            any other value follows edges forwards (towards descendants).
        visited_nodes: Optional boolean array that is updated in place with
            every node reached. A fresh all-False array is created when
            omitted.
        adjacency: Optional square boolean matrix where
            ``adjacency[src, dst]`` marks an edge ``src -> dst``. Defaults to
            the module-level ``graph``.

    Returns:
        A ``(length, count)`` tuple: ``length`` is the number of edges on the
        longest path starting at ``index`` in the chosen direction, and
        ``count`` is the number of marked entries in ``visited_nodes`` minus
        one (the start node itself).

    Raises:
        RecursionError: If the walk never terminates, e.g. because the graph
            contains a cycle.
    """
    if adjacency is None:
        adjacency = graph
    if visited_nodes is None:
        # Builtin `bool`: the `np.bool` alias no longer exists in NumPy.
        visited_nodes = np.zeros(adjacency.shape[0], dtype=bool)

    follow_incoming = height_or_depth == 'd'

    # Longest path length per node, filled in once a node's neighbours are
    # fully explored. Without it, a node reachable through many different
    # paths is re-explored once per path, which is exponential on graphs
    # with many diamonds. A node on a cycle is never finished, so a cycle
    # still ends in RecursionError exactly as before.
    longest = {}

    def longest_path_from(node):
        if node in longest:
            return longest[node]

        visited_nodes[node] = True
        neighbours = adjacency[:, node] if follow_incoming else adjacency[node, :]

        best = 0
        for nxt in np.where(neighbours)[0]:
            candidate = longest_path_from(nxt) + 1
            if candidate > best:
                best = candidate

        longest[node] = best
        return best

    length = longest_path_from(index)
    return length, np.sum(visited_nodes) - 1


def traverse_height(begin_node):
    """Run ``traverse`` from ``begin_node`` in height mode (``'h'``).

    Args:
        begin_node: Id of the node to start from.

    Returns:
        Whatever ``traverse(begin_node, 'h')`` returns.

    Raises:
        AssertionError: If the traversal hits Python's recursion limit,
            which most likely means the graph contains a loop.
    """
    try:
        return traverse(begin_node, 'h')
    except RecursionError as e:
        # Raised explicitly instead of `assert False`: an assert statement is
        # removed under `python -O`, which would make this function silently
        # return None. The exception type is kept for existing callers.
        raise AssertionError(
            f'Infinite recursion occurred. A loop might exists. '
            f'Check the {begin_node}-th node.'
        ) from e


def traverse_height_batch(indices):
    """Apply ``traverse_height`` to each id in ``indices``.

    A tqdm progress bar labelled 'Traversing height' is shown while
    iterating. Returns a dict mapping each id to its ``traverse_height``
    result.
    """
    progress = tqdm.tqdm(indices, desc='Traversing height')
    result_by_node = {}
    for node in progress:
        result_by_node[node] = traverse_height(node)
    return result_by_node


def traverse_depth(begin_node):
    """Run ``traverse`` from ``begin_node`` in depth mode (``'d'``).

    Args:
        begin_node: Id of the node to start from.

    Returns:
        Whatever ``traverse(begin_node, 'd')`` returns.

    Raises:
        AssertionError: If the traversal hits Python's recursion limit,
            which most likely means the graph contains a loop.
    """
    try:
        return traverse(begin_node, 'd')
    except RecursionError as e:
        # Raised explicitly instead of `assert False`: an assert statement is
        # removed under `python -O`, which would make this function silently
        # return None. The exception type is kept for existing callers.
        raise AssertionError(
            f'Infinite recursion occurred. A loop might exists. '
            f'Check the {begin_node}-th node.'
        ) from e


def traverse_depth_batch(indices):
    """Apply ``traverse_depth`` to each id in ``indices``.

    A tqdm progress bar labelled 'Traversing depth' is shown while
    iterating. Returns a dict mapping each id to its ``traverse_depth``
    result.
    """
    progress = tqdm.tqdm(indices, desc='Traversing depth')
    result_by_node = {}
    for node in progress:
        result_by_node[node] = traverse_depth(node)
    return result_by_node


def _gather_in_id_order(batch_fn, id_chunks, workers):
    """Run ``batch_fn`` over ``id_chunks`` in a process pool.

    Each call to ``batch_fn`` returns a dict keyed by node id; the dicts are
    merged and their values returned as a list sorted by node id.
    """
    with mp.Pool(workers) as pool:
        partial_results = pool.map(batch_fn, id_chunks)

    merged = {}
    for partial in partial_results:
        merged.update(partial)

    return [merged[node] for node in sorted(merged)]


# The guard keeps worker processes from re-running the driver when the
# "spawn" start method re-imports this module; without it, creating the pool
# fails on platforms where spawn is the default.
if __name__ == '__main__':
    ids = list(range(n))

    if USE_MP:
        workers = mp.cpu_count()
        ids_spl = np.array_split(ids, workers)

        heights = _gather_in_id_order(traverse_height_batch, ids_spl, workers)
        depths = _gather_in_id_order(traverse_depth_batch, ids_spl, workers)
    else:
        heights = [traverse_height(i) for i in tqdm.tqdm(ids, desc='Traversing height')]
        depths = [traverse_depth(i) for i in tqdm.tqdm(ids, desc='Traversing depth')]

    # One tab-separated row per node: id, height, number of nodes reached
    # going forwards, depth, number of nodes reached going backwards.
    with open(EDGE + ".height", "w") as f:
        for i, (h, d) in enumerate(zip(heights, depths)):
            h, h_visited_nodes = h
            d, d_visited_nodes = d

            f.write(f'{i}\t{h}\t{h_visited_nodes}\t{d}\t{d_visited_nodes}\n')

