# -*- encoding: utf-8 -*-
# Code adapted from: https://github.com/github-pengge/hierarchical_kmeans
import numpy as np
from scipy.cluster.vq import *
import pickle


class tree(object):
    '''A simple n-ary tree node.

    Every node has a ``name``, an optional ``data`` payload, an optional
    ``additional_data`` payload and an ordered list of ``children``.
    Iterating over a node yields its direct children.
    '''

    def __init__(self, name, data=None, additional_data=None, children=None):
        self.name = name
        self.data = data
        self.additional_data = additional_data
        # A fresh list per node; never share one default list between nodes.
        self.children = [] if children is None else children

    def set_data(self, data):
        self.data = data

    def get_data(self):
        return self.data

    def set_additional_data(self, additional_data):
        self.additional_data = additional_data

    def get_additional_data(self):
        return self.additional_data

    def get_name(self):
        return self.name

    def set_children(self, children):
        self.children = children

    def add_child(self, child):
        self.children.append(child)

    def add_children(self, children):
        self.children.extend(children)

    def get_children_number(self):
        return len(self.children)

    def gather_data(self, gather_additional_data=False):
        '''Gather data from its children, depth=1.

        Returns a list with each direct child's ``additional_data`` when
        ``gather_additional_data`` is True, otherwise each child's ``data``.
        '''
        if gather_additional_data:
            return [child.get_additional_data() for child in self.children]
        return [child.get_data() for child in self.children]

    @staticmethod
    def _matches(value, target):
        '''Return True if two payloads are equal, comparing arrays as a whole.'''
        if isinstance(value, np.ndarray) or isinstance(target, np.ndarray):
            # ``==`` on numpy arrays is element-wise and its truth value is
            # ambiguous; compare shape and contents instead.
            return np.array_equal(value, target)
        return value == target

    def find(self, data, find_additional_data=False):
        '''Depth-first search for the first node whose payload equals ``data``.

        Compares against ``additional_data`` when ``find_additional_data`` is
        True, otherwise against ``data``. numpy arrays are compared by shape
        and contents.

        Returns:
            The matching node, or -1 if no node in this subtree matches.
        '''
        own = self.additional_data if find_additional_data else self.data
        if self._matches(own, data):
            return self
        for child in self.children:
            res = child.find(data, find_additional_data)
            if isinstance(res, tree):
                return res
        return -1

    def is_leaf_node(self):
        return len(self.children) == 0

    def gather_leaves(self, output):
        '''Append every leaf node of this subtree to ``output`` (depth-first).'''
        assert isinstance(output, list), 'output must be a list.'
        if self.is_leaf_node():
            output.append(self)
            return
        for child in self.children:
            child.gather_leaves(output)

    def gather_data_from_leaves(self, output, gather_additional_data=False):
        '''Append each leaf's payload to ``output`` (depth-first).

        Appends ``additional_data`` when ``gather_additional_data`` is True,
        otherwise ``data``.
        '''
        assert isinstance(output, list), 'output must be a list.'
        if self.is_leaf_node():
            output.append(self.additional_data if gather_additional_data
                          else self.data)
            return
        for child in self.children:
            child.gather_data_from_leaves(output, gather_additional_data)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(self.children)

    def __hash__(self):
        return hash(str(self.name))


class hierarchical_kmeans(object):
    '''Hierarchical k-means clustering driven by a user-supplied structure.

    ``clusters`` is a ``tree`` describing the wanted branching. At each
    non-leaf structure node, k-means runs with ``k`` equal to that node's
    number of children, capped by the number of samples available. The
    structure node names are reused for the result nodes.

    After :meth:`cluster`, ``self.root`` holds the result tree. Each node
    created for a branch stores its centroid in ``data``. Each result leaf
    stores, in ``additional_data``, the indices of its samples (or the
    samples themselves, when ``only_store_id`` is False).
    ``self.total_k`` counts the result leaves.
    '''

    def __init__(self, clusters):
        self.clusters = clusters  # tree type variable (wanted structure)
        self.root = tree('root')
        self.total_k = 0

    def cluster(self, data, only_store_id=True, iteration=1):
        '''Run hierarchical k-means on ``data`` and (re)build ``self.root``.

        Args:
            data: array-like of observations, one per row (the layout
                scipy's ``kmeans``/``vq`` expect).
            only_store_id: if True, leaves store the row indices of their
                samples; otherwise they store the sample rows themselves.
            iteration: forwarded to ``scipy.cluster.vq.kmeans`` as ``iter``.
        '''
        data = np.asarray(data)
        self.only_store_id = only_store_id
        # Start from a fresh result tree. Otherwise a second call (or a
        # call after ``load``) would append a duplicate set of branches to
        # the old root and keep growing ``total_k``.
        self.root = tree('root')
        self.total_k = 0
        idx = np.arange(data.shape[0])
        print('Doing hierarchical kmeans...')
        self._cluster(data, idx, self.root, self.clusters,
                      iteration=iteration, only_store_id=only_store_id)
        print('Done! Actual total number of clusters is %d.' % self.total_k)

    def _store_leaf(self, node, data, idx, only_store_id):
        '''Attach the samples selected by ``idx`` to ``node`` and count it.'''
        self.total_k += 1
        node.set_additional_data(idx if only_store_id else data[idx])

    def _cluster(self, data, idx, current_node, cluster_node,
                 iteration=1, only_store_id=True):
        '''Recursively split ``data[idx]`` following ``cluster_node``.'''
        # Stop at a leaf of the wanted structure, or when no samples are
        # left: kmeans cannot be run on an empty set (k would be 0).
        if cluster_node.is_leaf_node() or len(idx) == 0:
            self._store_leaf(current_node, data, idx, only_store_id)
            return

        subset = data[idx]  # fancy indexing copies, so do it only once
        children_number = min(len(subset),
                              cluster_node.get_children_number())
        codebook, _ = kmeans(subset, children_number, iter=iteration)
        ids, _ = vq(subset, codebook)

        if codebook.shape[0] == 1:  # kmeans only find one cluster
            # This node becomes a leaf of the result tree, so count it too.
            current_node.set_data(codebook[0])
            self._store_leaf(current_node, data, idx, only_store_id)
            return

        # kmeans may return fewer centroids than requested. In that case
        # zip simply drops the surplus children of the structure node.
        for i, (c, centroid) in enumerate(zip(cluster_node, codebook)):
            idx_i = idx[ids == i]
            branch = tree(c.get_name(), centroid)
            current_node.add_child(branch)
            self._cluster(data, idx_i, branch, c, iteration=iteration,
                          only_store_id=only_store_id)

    def save(self, file_name):
        '''Pickle the structure, the result tree and ``total_k``.'''
        with open(file_name, 'wb') as f:
            pickle.dump([self.clusters, self.root, self.total_k], f)

    @staticmethod
    def load(file_name):
        '''Load a model written by :meth:`save`.

        NOTE: this uses ``pickle``. Only load files from trusted sources,
        because unpickling can execute arbitrary code.
        '''
        hk = hierarchical_kmeans(None)
        with open(file_name, 'rb') as f:
            hk.clusters, hk.root, hk.total_k = pickle.load(f)
        return hk

    def find_cluster(self, data, max_depth=-1):
        '''Return the result node closest to each observation.

        Args:
            data: a single observation (1-D) or one observation per row.
            max_depth: when positive, stop descending at this depth (the
                root is depth 0). Otherwise descend all the way to a leaf.

        Returns:
            A node for 1-D input, otherwise a list with one node per row.
        '''
        data = np.asarray(data)
        if data.ndim == 1:
            return self._find_cluster([data], self.root, 0, max_depth)
        return [self._find_cluster([d], self.root, 0, max_depth) for d in data]

    def _find_cluster(self, data, current_node, current_depth, max_depth=-1):
        '''Descend from ``current_node`` towards the nearest centroid.

        Stops at a leaf or, when ``max_depth`` is positive, as soon as
        ``current_depth`` reaches ``max_depth``.
        '''
        while not current_node.is_leaf_node():
            if 0 < max_depth <= current_depth:
                break
            codebook = current_node.gather_data()
            code, _ = vq(data, codebook)
            current_node = current_node.children[code[0]]
            current_depth += 1
        return current_node
        

def test_tree():
    '''Smoke-test the ``tree`` helpers on a small hand-built tree.

    Prints intermediate results and asserts that the leaf payloads are
    collected in depth-first order.
    '''
    print('-' * 20)
    root = tree('root', 0)
    first = tree('level1,node@1', 1)
    second = tree('level1,node@2', 2)
    third = tree('level1,node@3', 3)
    for node in (first, second, third):
        root.add_child(node)
    first.add_child(tree('level2,child@1,node@1', 4))
    first.add_child(tree('level2,child@1,node@2', 5))
    second.add_child(tree('level2,child@2,node@1', 6))
    # Depth-first leaf payloads: the two children of node@1, the child of
    # node@2, then node@3 itself (it has no children).
    expected = [4, 5, 6, third.get_data()]

    print(root.children)
    print(root.children[0].children)
    print(root.find(6))
    print(root.is_leaf_node())  # False
    print(root.gather_data())

    collected = []
    root.gather_data_from_leaves(collected)
    assert collected == expected

    for position, child in enumerate(root):
        print(position, child)

    leaf_nodes = []
    root.gather_leaves(leaf_nodes)
    print(leaf_nodes)
    print('Passed.')


def test_hk():
    '''End-to-end demo: cluster random data, save, reload and query it.

    Writes ``hk.pkl`` to the current working directory.
    '''
    print('-' * 20)
    data = np.random.randn(2000, 10)

    clusters = tree('root', 0)
    branch_x = tree('l1-c1@1', 1)
    branch_y = tree('l1-c2@2', 2)
    branch_z = tree('l1-c3@3', 3)
    clusters.add_children([branch_x, branch_y, branch_z])
    branch_x.add_children([tree('l2-x@1', 3), tree('l2-x@2', 2)])
    branch_y.add_children([tree('l2-y@1', 2),
                           tree('l2-y@2', 2),
                           tree('l2-y@3', 2)])

    model = hierarchical_kmeans(clusters)
    model.cluster(data)
    model.save('hk.pkl')
    print(clusters.children[0].children)
    print(model.root.children[0].children)

    model = hierarchical_kmeans.load('hk.pkl')
    first_branch = model.root.children[0]
    print(first_branch.children,
          len(first_branch.children[0].additional_data))
    print('data[0] clustered to node %s' % model.find_cluster(data[0]))
    # When find_cluster(data[0]) is 'l2-x@1', this must print True.
    print(0 in first_branch.children[0].get_additional_data())
    print("Passed? Run several times and check last output."
          " When last but two output is \"l2-x@1\", it must be True.")


def test_add_tree():
    '''Build a 2x2x2 structure, cluster random data and inspect one query.

    Returns the fitted ``hierarchical_kmeans`` instance.
    '''
    cluster_structure = tree('my_cluster_structure')
    for a in range(1, 3):
        level1 = tree('l1-c%d' % a)
        cluster_structure.add_child(level1)
        for b in range(1, 3):
            level2 = tree('l2-c%d-c%d' % (a, b))
            level1.add_child(level2)
            for c in range(1, 3):
                level2.add_child(tree('l3-c%d-c%d-c%d' % (a, b, c)))

    # H-K-Means.
    hk = hierarchical_kmeans(cluster_structure)

    training_data = np.random.randn(600, 10)
    hk.cluster(training_data, only_store_id=True, iteration=1)

    testing_data = np.random.randn(2, 10)
    codes = hk.find_cluster(testing_data, max_depth=-1)
    nearest = codes[0]
    print(type(nearest))
    print(nearest, nearest.name, nearest.data)
    print(nearest.get_additional_data())
    print(np.argwhere(nearest.data == training_data))
    print(np.argmin(nearest.data - training_data))

    return hk


def recursively_adding_cluster(node, num_clusters, curr_lvl=0, curr_id=""):
    '''Attach ``num_clusters[curr_lvl]`` children to ``node``, recursively.

    A child is named ``l<level>`` followed by its path of 1-based child
    indices, e.g. ``l2-c1-c3``. Recursion continues until the last entry of
    ``num_clusters`` has been used. ``node`` is modified in place; nothing
    is returned.
    '''
    is_last_level = curr_lvl == len(num_clusters) - 1
    for child_no in range(1, num_clusters[curr_lvl] + 1):
        path = curr_id + "-c{}".format(child_no)
        child = tree("l{}".format(curr_lvl + 1) + path)
        node.add_child(child)
        if not is_last_level:
            recursively_adding_cluster(child, num_clusters,
                                       curr_lvl=curr_lvl + 1,
                                       curr_id=path)


def get_hkmeans_clusters(data, num_clusters=(), iteration=1):
    '''Build a regular hierarchical k-means structure and fit it to ``data``.

    Args:
        data: array-like of observations, one per row (lists are accepted).
        num_clusters: branching factor per level, e.g. ``[12, 12]`` for 12
            top-level clusters, each split into 12 sub-clusters.
        iteration: forwarded to ``hierarchical_kmeans.cluster``.

    Returns:
        The fitted ``hierarchical_kmeans`` instance.

    Raises:
        AssertionError: if ``num_clusters`` is empty or ``data`` has fewer
            rows than the product of ``num_clusters``.
    '''
    assert len(num_clusters) > 0, "num_clusters is empty!"
    data = np.asarray(data)
    num_clusters = np.asarray(num_clusters)
    total_clusters = np.prod(num_clusters)
    assert data.shape[0] >= total_clusters, (
        "Number of data ({}) should be >= total possible clusters ({})"
        .format(data.shape[0], total_clusters))

    cluster_structure = tree('my_cluster_structure')
    recursively_adding_cluster(cluster_structure, num_clusters, 0, "")

    # H-K-Means.
    hk = hierarchical_kmeans(cluster_structure)
    hk.cluster(data, only_store_id=True, iteration=iteration)

    return hk


def init_hkmeans_clusters(data=None, num_clusters=()):
    '''Build an unfitted hierarchical k-means model for ``num_clusters``.

    Args:
        data: optional array-like of observations, one per row (lists are
            accepted). It is only used to check that there are enough rows.
        num_clusters: branching factor per level, e.g. ``[12, 12, 12]``.

    Returns:
        A ``hierarchical_kmeans`` instance whose ``cluster`` has not been run.

    Raises:
        AssertionError: if ``num_clusters`` is empty, or ``data`` is given
            and has fewer rows than the product of ``num_clusters``.
    '''
    assert len(num_clusters) > 0, "num_clusters is empty!"
    num_clusters = np.asarray(num_clusters)
    total_clusters = np.prod(num_clusters)
    if data is not None:
        n_data = np.shape(data)[0]
        assert n_data >= total_clusters, (
            "Number of data ({}) should be >= total possible clusters ({})"
            .format(n_data, total_clusters))

    cluster_structure = tree('my_cluster_structure')
    recursively_adding_cluster(cluster_structure, num_clusters, 0, "")

    # H-K-Means.
    hk = hierarchical_kmeans(cluster_structure)

    print("Constructed hkmeans cluster structure with {}".format(num_clusters))

    return hk


if __name__ == '__main__':
    # Further smoke tests, disabled by default:
    # test_tree()
    # test_hk()
    demo_model = test_add_tree()

    samples = np.random.randn(60000, 100)
    demo_model = init_hkmeans_clusters(samples, [12, 12, 12])
    demo_model = get_hkmeans_clusters(samples, [12, 12, 12, 12], iteration=1)
