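"""
Clustering of template embeddings with scikit-learn: assign each template to a
cluster and record its distance from that cluster's centroid.
"""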
import warnings

warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import io
from datetime import datetime
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
from sklearn import cluster
from sys import exit

DELIMITER = '\t'


def _get_distance_from_centroid(temp_emb, centroid):
    """
    :param temp_emb: embedding of the template
    :param centroid: centroid of the cluster
    :return: Get distance from template embedding (a list or coordinates) to centroid (an array of coordinates)
    """
    if len(temp_emb) != len(centroid):
        print('ERROR: {} embedding dimensions and {} centroid dimensions'.format(len(temp_emb), len(centroid)))
        exit()
    return np.sqrt(sum([(centroid[i] - temp_emb[i]) ** 2 for i in range(len(centroid))]))


def _get_centroid(res, all_templates, num_clusters, emb_dim):
    """
    :param res: res is a k-len list of the cluster id corresponding to each template
    :param all_templates: each row corresponding to the embedding of a template, including meta information
    :param num_clusters: number of clusters
    :param emb_dim: embedding dimension
    :return: Generate a list of clusters, each of which is a list of templates belonging to the
    cluster sorted in ascending order by their distance from the cluster centroid. Here, centroids
    is an n-len array of 100-len arrays, each of which is the centroid for that cluster
    """
    clusters = [[] for _ in range(num_clusters)]
    start = len(all_templates[0]) - (emb_dim)
    for i in range(len(res) - 1):
        cur_cluster_id = int(res[i + 1])
        cur_template = all_templates[i + 1]
        # In the list corresponding with the template, only keep the embedding
        new_template = cur_template[start:]
        clusters[cur_cluster_id].append(new_template)
    centroid = [[np.mean([clusters[id][templateId][indx] for templateId in range(len(clusters[id]))]) \
                 for indx in range(emb_dim)] for id in range(len(clusters))]

    return centroid


def _make_estimator(alg_name, X, num_clusters):
    """
    Build the (unfitted) scikit-learn estimator for ``alg_name``.

    Only the requested estimator is constructed, so the k-NN connectivity graph is
    computed only for 'Ward' and 'AgglomerativeClustering'.

    :param alg_name: one of the names accepted by cluster_embedding
    :param X: standardized feature matrix (needed to build the connectivity graph)
    :param num_clusters: requested number of clusters
    :return: an unfitted scikit-learn clustering estimator
    :raises ValueError: if ``alg_name`` is not supported
    """
    if alg_name == 'kmeans':
        return cluster.KMeans(n_clusters=num_clusters, random_state=0)
    if alg_name == 'AffinityPropagation':
        return cluster.AffinityPropagation(damping=.9, preference=-200)
    if alg_name == 'SpectralClustering':
        return cluster.SpectralClustering(n_clusters=num_clusters, eigen_solver='arpack',
                                          affinity="nearest_neighbors")
    if alg_name == 'DBSCAN':
        return cluster.DBSCAN(eps=.2)
    if alg_name == 'Birch':
        return cluster.Birch(n_clusters=num_clusters)
    if alg_name not in ('Ward', 'AgglomerativeClustering'):
        raise ValueError('Unsupported clustering algorithm: {}'.format(alg_name))

    # Connectivity matrix for structured agglomerative clustering.
    connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
    print("finish construct connectivity")
    # Make connectivity symmetric.
    connectivity = 0.5 * (connectivity + connectivity.T)
    if alg_name == 'Ward':
        return cluster.AgglomerativeClustering(n_clusters=num_clusters, linkage='ward',
                                               connectivity=connectivity)
    # scikit-learn 1.2 renamed `affinity` to `metric` and 1.4 removed `affinity`;
    # try the new keyword first and fall back for older releases.
    try:
        return cluster.AgglomerativeClustering(linkage="average", metric="cityblock",
                                               n_clusters=num_clusters, connectivity=connectivity)
    except TypeError:
        return cluster.AgglomerativeClustering(linkage="average", affinity="cityblock",
                                               n_clusters=num_clusters, connectivity=connectivity)


def cluster_embedding(templates_data, templates_emb, fields, core_field_names,
                        emb_dim, out_fname, num_clusters, emb_mode, alg_name):
    """
    Cluster template embeddings, write one output row per template, and group the rows by cluster.

    :param templates_data: per-row metadata; row 0 is a header row. Each row's cells are
        prepended to the corresponding template row.
    :param templates_emb: rows with a header at index 0; the last ``emb_dim`` cells of each
        template row are its embedding
    :param fields: column names; written as the output-file header and used as the prefix
        of the returned header row
    :param core_field_names: columns of ``templates_emb`` used as clustering features
    :param emb_dim: embedding dimension
    :param out_fname: path of the delimited output file
    :param num_clusters: requested number of clusters (for algorithms that take one)
    :param emb_mode: unused here
    :param alg_name: one of 'kmeans', 'AffinityPropagation', 'SpectralClustering', 'Ward',
        'AgglomerativeClustering', 'DBSCAN', 'Birch'
    :return: ``(clusters, emb_dim, all_templates_no_emb)``, or None when ``alg_name`` is not
        supported.
        - ``clusters[c]`` holds the output rows assigned to cluster c, sorted by ascending
          distance from the centroid. Noise rows (negative label) are left out.
        - ``all_templates_no_emb`` is a header row (fields + ['ClusterId', 'Distance'])
          followed by one output row per template.
    :raises ValueError: if templates_data and templates_emb differ in length
    """
    clustering_names = ['kmeans', 'AffinityPropagation', 'SpectralClustering', 'Ward', 'AgglomerativeClustering',
                        'DBSCAN', 'Birch']

    if alg_name not in clustering_names:
        print('The algorithm {} is not supported. Skipping clustering'.format(alg_name))
        cluster_name_str = "\t".join(clustering_names)
        print("choose one of the following algorithms: {}".format(cluster_name_str))
        return None
    print("Using algorithm {} for clustering".format(alg_name))

    if len(templates_data) != len(templates_emb):
        raise ValueError('{} metadata rows and {} embedding rows'.format(
            len(templates_data), len(templates_emb)))

    # templates_emb carries its header at index 0; use it as the column names.
    df = pd.DataFrame(templates_emb)
    df.columns = df.iloc[0]
    df = df[1:]
    scaler = StandardScaler()
    X = scaler.fit_transform(df[core_field_names].values)
    print("finish standard scaling")

    algorithm = _make_estimator(alg_name, X, num_clusters)
    print("finish initialization")

    algorithm.fit(X)
    if hasattr(algorithm, 'labels_'):
        labels = [int(label) for label in algorithm.labels_]
    else:
        labels = [int(label) for label in algorithm.predict(X)]
    print("finish training")

    # res mirrors the row layout of templates_emb: a header cell, then one label per template.
    res = ['ClusterId'] + labels
    assert len(templates_emb) == len(res)

    all_templates = [list(meta) + [label] + list(templ)
                     for meta, templ, label in zip(templates_data, templates_emb, res)]

    # DBSCAN / AffinityPropagation choose their own cluster count, which may exceed num_clusters.
    n_found = max(num_clusters, max(labels, default=-1) + 1)

    if hasattr(algorithm, 'cluster_centers_'):
        # cluster_centers_ lives in the standardized space of X, while distances below are
        # measured on the raw embedding cells, so map the centers back to the raw space.
        # NOTE(review): assumes core_field_names are the embedding columns, in the same order
        # as the trailing emb_dim cells of each row — confirm.
        centers = algorithm.cluster_centers_
        centroids = scaler.inverse_transform(centers) if len(centers) else []
    else:
        centroids = _get_centroid(res, all_templates, n_found, emb_dim)
    print("finish collecting centroid")

    res_fields = list(fields) + ['ClusterId', 'Distance']
    # The embedding occupies the last emb_dim cells of each all_templates row.
    start = len(all_templates[0]) - emb_dim

    clusters = [[] for _ in range(n_found)]
    all_templates_no_emb = [res_fields]
    for cur_cluster_id, cur_template in zip(labels, all_templates[1:]):
        if cur_cluster_id < 0:
            # Noise points (label -1, e.g. from DBSCAN) have no centroid.
            distance = float('nan')
        else:
            distance = _get_distance_from_centroid(cur_template[start:], centroids[cur_cluster_id])
        # Keep the cells before index start - 1 (dropping that cell and the embedding),
        # then append the cluster id and the distance.
        new_template = cur_template[:start - 1] + [cur_cluster_id, distance]
        all_templates_no_emb.append(new_template)
        if cur_cluster_id >= 0:
            clusters[cur_cluster_id].append(new_template)

    # Sort each cluster by ascending distance, which is the last cell of every row.
    for members in clusters:
        members.sort(key=lambda t: t[-1])

    with open(out_fname, 'w') as out_file:
        # NOTE(review): the header written here is `fields`, while each data row also ends with
        # the ClusterId/Distance cells named in res_fields — confirm this header is intended.
        out_file.write(DELIMITER.join(fields) + '\n')
        for new_template in all_templates_no_emb[1:]:
            template_row = DELIMITER.join(str(x) for x in new_template)
            out_file.write(clean_row_text_for_file(template_row) + '\n')

    return clusters, emb_dim, all_templates_no_emb
