"""
The module skipgram provides the class SkipgramFinder for finding occurrences
of skipgrams in sequences.

The SkipgramFinder constructor has the following parameters:

  `k`
    The "skip parameter". Default value: 5.

    A skipgram `g = (g[1], g[2], ..., g[m])` is said to
    be found in a sequence `s = (s[1], s[2], ..., s[n])` if there is a sequence
    of indexes of `s` with the same size as `g`, namely
    `z = (z[1], z[2], ..., z[m])`, such that:

    - `∀ i, 1 <= i <= m`, g[i] matches with s[z[i]]
    - `∀ i, 1 <= i < m`, z[i] < z[i+1]
    - `∀ i, 1 <= i < m`, z[i+1] - z[i] <= k + 1

    In other words, a skipgram is in a sequence if there is a subsequence
    allowing elements to be skipped that matches the skipgram. In this sense, k
    is the maximum number of elements skipped between two subsequent matched
    items of the sequence.

  `skipgrams`
    The initial set of skipgrams. Empty by default. One can add skipgrams using
    the `add()` method as well.

  `match_items`
    A function with the signature `match_items(seq_item, sg_item)` that must
    tell whether an item of the sequence (`seq_item`) matches with an item of
    the skipgram. By default the equality operator is used.


Examples
========

>>> seq = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']

Trivial example:
>>> finder = SkipgramFinder(k=3, skipgrams=[('the',)])
>>> list(finder.search(seq))
[((0,), ('the',)), ((6,), ('the',))]


Finding 3-skip-2-grams:
>>> finder = SkipgramFinder(k=3, skipgrams=[('the', 'fox'), ('the', 'dog')])
>>> list(finder.search(seq))
[((0, 3), ('the', 'fox')), ((6, 8), ('the', 'dog'))]


In the following example, only ('the', 'dog') is found because we allow only 1
skip:
>>> finder = SkipgramFinder(k=1, skipgrams=[('the', 'fox'), ('the', 'dog')])
>>> list(finder.search(seq))
[((6, 8), ('the', 'dog'))]


Using a different matcher:
>>> finder = SkipgramFinder(k=2,
...                         skipgrams=[(frozenset({'the'}), frozenset({'fox', 'dog'}))],
...                         match_items=lambda seq_item, sg_item: seq_item in sg_item,
... )
>>> matches = finder.search(seq)
>>> [tuple(seq[idx] for idx in indices) for indices, _ in matches]
[('the', 'fox'), ('the', 'dog')]


The same element of a sequence can be present in multiple matches. Example:
>>> finder = SkipgramFinder(k=7, skipgrams=[('the', 'fox'), ('the', 'dog')])
>>> list(finder.search(seq))
[((0, 3), ('the', 'fox')), ((0, 8), ('the', 'dog')), ((6, 8), ('the', 'dog'))]

"""

import collections

# Sentinel key stored in a trie node to mark "a skipgram ends here". It is an
# internal marker only and must never be handed to a user-supplied matcher.
null_item = object()


class SkipgramFinder:
    """Find occurrences of a set of skipgrams in a sequence.

    The skipgrams are stored in a trie made of nested dicts: each key is a
    skipgram item and each value is the child node. A node that holds the
    `null_item` key marks the end of a registered skipgram. Because items are
    used as dict keys, they must be hashable.

    Parameters:

      `k`
        Maximum number of sequence elements that may be skipped between two
        consecutive matched items. Default value: 5.

      `skipgrams`
        Iterable of skipgrams (each one an iterable of items) to register
        initially. Empty by default; more can be registered with `add()`.

      `match_items`
        Optional callable `match_items(seq_item, sg_item)` telling whether an
        item of the sequence matches an item of a skipgram. When omitted,
        the equality operator is used.
    """

    def __init__(self, k=5, skipgrams=(), match_items=None):
        # Compare against None (not truthiness) so that a callable object that
        # happens to be falsy is still honoured.
        if match_items is not None:
            # Stored on the instance, shadowing the default method below; it
            # is therefore called without `self`.
            self.match_items = match_items

        self.k = k

        self.__trie = {}

        for sg in skipgrams:
            self.add(sg)

    def add(self, sg):
        """Register the skipgram `sg` (an iterable of hashable items)."""
        node = self.__trie
        for sg_item in sg:
            node = node.setdefault(sg_item, {})
        # Mark this node as the end of a complete skipgram.
        node[null_item] = None

    def match_items(self, seq_item, sg_item):
        """Default matcher: a sequence item matches an equal skipgram item."""
        return seq_item == sg_item

    def find_matching_children(self, node, seq, item_idx):
        """Yield `(sg_item, child_node)` for children matching `seq[item_idx]`.

        `node` is the trie node whose children are inspected; `None` stands
        for the root of the trie.
        """
        if node is None:
            node = self.__trie
        seq_item = seq[item_idx]
        for sg_item, child in node.items():
            # The end-of-skipgram marker is not a real skipgram item: passing
            # it to the matcher could raise (e.g. `x in object()`) or produce
            # a bogus candidate whose node is None.
            if sg_item is null_item:
                continue
            if self.match_items(seq_item, sg_item):
                yield sg_item, child

    def node_has_null_child(self, node):
        """Tell whether a registered skipgram ends at `node`."""
        return null_item in node

    def node_is_terminal(self, node):
        """Tell whether `node` ends a skipgram and has no further children."""
        return self.node_has_null_child(node) and len(node) == 1

    def search(self, seq):
        """Yield every match of the registered skipgrams in `seq`.

        Each match is a pair `(indices, sg_items)`: `indices` is the tuple of
        positions of `seq` that were matched and `sg_items` is the tuple of
        skipgram items they matched. Matches are produced in increasing order
        of the position of their last matched element.
        """
        seq = list(seq)
        max_dist = self.k + 1

        # candidates_buffer is a deque of at most max_dist elements. Each
        # element is the list of partial matches whose last matched position
        # is a given index; the oldest list is the farthest from the current
        # item.
        candidates_buffer = collections.deque([], max_dist)

        for item_idx in range(len(seq)):
            # Candidates that start a new match at the current item.
            new_candidates = [
                ((item_idx,), (sg_item,), child)
                for sg_item, child
                in self.find_matching_children(None, seq, item_idx)
            ]

            # Candidates that extend a partial match still within reach.
            for cand_list in candidates_buffer:
                for indices, sg_items, node in cand_list:
                    matching = self.find_matching_children(node, seq, item_idx)
                    for sg_item, child in matching:
                        new_candidates.append((
                            indices + (item_idx,),
                            sg_items + (sg_item,),
                            child,
                        ))

            # Report the candidates that have reached the end of a skipgram
            # and keep only those that can still be extended.
            pruned_candidates = []
            for cand in new_candidates:
                indices, sg_items, node = cand
                if self.node_has_null_child(node):
                    yield indices, sg_items

                if not self.node_is_terminal(node):
                    pruned_candidates.append(cand)

            # Appending to the bounded deque drops the list of candidates
            # that are now at distance max_dist, i.e. out of reach.
            candidates_buffer.append(pruned_candidates)
