import pickle
import pathlib

import skemb.skipgram


class Node:
    """Trie node whose outgoing edges are labelled with skip-gram items.

    A skip-gram item is a hashable iterable of hashable attributes (for
    example a ``frozenset`` or a ``tuple``).  It *matches* a sequence item
    when every one of its attributes occurs in that sequence item.
    """

    def __init__(self):
        # attribute -> list of child items containing that attribute,
        # in the order the children were added.
        self.__children_map = {}
        # skip-gram item -> child Node.
        self.__children = {}
        # Set by add_null_child(); SGPatterns.add() uses it to mark the node
        # at which a stored pattern ends.
        self.__has_null_child = False

    def add_child(self, sg_item):
        """Return the child reached through *sg_item*, creating it if needed.

        Args:
            sg_item: Hashable iterable of hashable attributes.

        Returns:
            The (possibly pre-existing) child ``Node`` for *sg_item*.
        """
        child = self.__children.get(sg_item)
        if child is not None:
            return child

        child = self.__children[sg_item] = Node()

        # Index the item once per *distinct* attribute: an item such as
        # ('x', 'x') must not be listed twice under 'x', otherwise it would
        # be reported twice by find_matching_children().
        for attr in set(sg_item):
            self.__children_map.setdefault(attr, []).append(sg_item)

        return child

    def find_matching_children(self, seq_item):
        """Yield ``(sg_item, child)`` for every child item matching *seq_item*.

        A child item matches when all of its attributes appear in
        *seq_item*.  Each matching child is yielded exactly once, at the
        point where its last missing attribute is encountered while
        iterating *seq_item*.

        Args:
            seq_item: Iterable of hashable attributes.

        Yields:
            Tuples ``(sg_item, child_node)``.
        """
        seen_attrs = set()   # attributes of seq_item already processed
        seeded = set()       # child items that already have a candidate
        candidates = []      # (sg_item, attributes still missing)

        for attr in seq_item:
            # A repeated attribute carries no new information; processing it
            # again would let stale candidates complete and yield duplicates.
            if attr in seen_attrs:
                continue
            seen_attrs.add(attr)

            new_candidates = []
            for sg_item, remaining in candidates:
                remaining.discard(attr)
                if remaining:
                    new_candidates.append((sg_item, remaining))
                else:
                    yield sg_item, self.__children[sg_item]

            for sg_item in self.__children_map.get(attr, ()):
                # Only the candidate created at the first matching attribute
                # can ever complete; later ones would wait for attributes
                # that have already gone by.
                if sg_item in seeded:
                    continue
                seeded.add(sg_item)

                remaining = set(sg_item)
                remaining.discard(attr)
                if remaining:
                    new_candidates.append((sg_item, remaining))
                else:
                    yield sg_item, self.__children[sg_item]

            candidates = new_candidates

    def add_null_child(self):
        """Flag this node as having a null child."""
        self.__has_null_child = True

    def has_null_child(self):
        """Return True if add_null_child() has been called on this node."""
        return self.__has_null_child

    def is_leaf(self):
        """Return True if this node has no (non-null) children."""
        return not self.__children


class SGPatterns(skemb.skipgram.SkipgramFinder):
    """A set of skip-gram patterns indexed by a trie of ``Node`` objects.

    The trie-navigation methods (``find_matching_children``,
    ``node_has_null_child``, ``node_is_terminal``) are presumably the hooks
    used by ``SkipgramFinder`` -- confirm against the base class.
    """

    def __init__(self, k, patterns):
        """Create the pattern set.

        Args:
            k: Forwarded unchanged to ``SkipgramFinder.__init__``.
            patterns: Forwarded unchanged to ``SkipgramFinder.__init__``.
        """
        # Both containers are created before the base initialiser runs
        # (presumably because it feeds ``patterns`` through add() -- verify).
        self.patterns = set()
        self.__trie = Node()
        super().__init__(k, patterns)

    def __iter__(self):
        """Iterate over the stored patterns (as tuples)."""
        return iter(self.patterns)

    def __len__(self):
        """Return the number of distinct stored patterns."""
        return len(self.patterns)

    def find_matching_children(self, node, seq, item_idx):
        """Yield the children of *node* that match ``seq[item_idx]``.

        Args:
            node: Trie node to search from; ``None`` means the trie root.
            seq: Indexable sequence of items.
            item_idx: Index into *seq* of the item to match.

        Yields:
            ``(sg_item, child_node)`` pairs, as produced by
            ``Node.find_matching_children``.
        """
        if node is None:
            node = self.__trie
        current_item = seq[item_idx]
        yield from node.find_matching_children(current_item)

    def node_has_null_child(self, node):
        """Return whether *node* is flagged as having a null child."""
        return node.has_null_child()

    def node_is_terminal(self, node):
        """Return whether *node* has a null child and no other children."""
        ends_here = node.has_null_child()
        return ends_here and node.is_leaf()

    def add(self, sg):
        """Store the pattern *sg* and insert it into the trie.

        Args:
            sg: Iterable of skip-gram items; it is converted to a tuple
                before being stored.
        """
        pattern = tuple(sg)
        self.patterns.add(pattern)

        # Walk/extend the trie one item at a time, then mark where the
        # pattern ends.
        cursor = self.__trie
        for item in pattern:
            cursor = cursor.add_child(item)
        cursor.add_null_child()

    def save(self, path):
        """Pickle ``k`` and the stored patterns to the file at *path*.

        Args:
            path: Destination file; it is opened in binary write mode.
        """
        # ``self.k`` is not assigned in this class -- presumably set by
        # the base class initialiser.
        payload = {'k': self.k, 'patterns': self.patterns}
        with open(path, 'wb') as out_file:
            pickle.dump(payload, out_file)


def load(path):
    """Rebuild an ``SGPatterns`` from a file written by ``SGPatterns.save``.

    Args:
        path: File to read; it is opened in binary read mode.

    Returns:
        A new ``SGPatterns`` built from the stored ``'k'`` and ``'patterns'``
        entries.

    Warning:
        The file is read with ``pickle.load``, which can execute arbitrary
        code while unpickling.  Only load files from trusted sources.
    """
    with open(path, 'rb') as in_file:
        payload = pickle.load(in_file)
    k = payload['k']
    patterns = payload['patterns']
    return SGPatterns(k, patterns)
