
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def create_groups_with_kmeans(X, min_clusters=40, max_clusters=200, *,
                              n_candidates=10, random_state=None):
    """Cluster ``X`` with K-means, choosing the cluster count by silhouette score.

    Up to ``n_candidates`` evenly spaced cluster counts between ``min_clusters``
    and ``max_clusters`` are tried, and the labelling with the highest
    silhouette score is returned.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features) or (n_samples,)
        Data to cluster. A 1-D array is treated as a single feature.
    min_clusters : int, default 40
        Smallest cluster count to try. Raised to 2 if lower, and lowered to
        the upper bound if it exceeds it (e.g. when there are few samples).
    max_clusters : int or None, default 200
        Largest cluster count to try; ``None`` means no explicit cap. Always
        capped at ``n_samples - 1``, the largest count for which a silhouette
        score is defined.
    n_candidates : int, default 10
        Number of evenly spaced cluster counts to evaluate. Counts that
        coincide after integer rounding are evaluated once.
    random_state : int, RandomState instance or None, default None
        Forwarded to ``KMeans`` for reproducible results.

    Returns
    -------
    numpy.ndarray, shape (n_samples,)
        Cluster label of each sample, taken from the best-scoring fit.

    Raises
    ------
    ValueError
        If fewer than 3 samples are given (or ``max_clusters < 2``), or if no
        candidate produced a labelling that can be scored.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    n_samples = X.shape[0]

    # silhouette_score requires 2 <= n_labels <= n_samples - 1.
    if max_clusters is None:
        upper = n_samples - 1
    else:
        upper = min(int(max_clusters), n_samples - 1)
    if upper < 2:
        raise ValueError(
            "need at least 3 samples and max_clusters >= 2 to score "
            f"clusterings; got n_samples={n_samples}, max_clusters={max_clusters}"
        )
    lower = min(max(int(min_clusters), 2), upper)

    # Integer rounding of linspace can repeat values; fit each k only once.
    cluster_options = np.unique(
        np.linspace(lower, upper, num=n_candidates, dtype=int)
    )

    best_score = -np.inf
    best_labels = None
    for n_clusters in cluster_options:
        kmeans = KMeans(
            n_clusters=int(n_clusters), n_init="auto", random_state=random_state
        )
        labels = kmeans.fit_predict(X)
        # Duplicate points can leave clusters empty; skip labellings that
        # silhouette_score would reject.
        n_found = len(np.unique(labels))
        if not 2 <= n_found <= n_samples - 1:
            continue
        score = silhouette_score(X, labels)
        if score > best_score:
            best_score = score
            best_labels = labels

    if best_labels is None:
        raise ValueError("no candidate cluster count produced a scorable clustering")
    # Return the labels that were actually scored: refitting an unseeded
    # KMeans could yield a different partition than the winning one.
    return best_labels




def expand_data(X, z, groups, pz, loco):
    """Expand per-item data into one row per entry of each item's ``pz``.

    Item ``i`` (row ``i`` of ``X``, ``z[i]``, ``groups[i]``, ``pz[i]``) becomes
    ``len(pz[i])`` consecutive rows. The design matrix for each expanded row can
    contain the item's features, a one-hot "model" dummy for the entry's
    position ``j`` within ``pz[i]``, and a one-hot dummy for the item's group.

    Parameters
    ----------
    X : array-like, shape (n_items, n_features) or (n_items,)
        Per-item features; a 1-D array is treated as a single column.
    z : sequence of arrays
        Per-item values, concatenated in item order. If they cannot be
        concatenated as-is (``np.concatenate`` raises ``ValueError``), column 0
        of each (2-D) entry is used instead.
    groups : sequence of hashable
        Group label per item. Group-dummy columns follow ``np.unique(groups)``
        (sorted) order.
    pz : sequence of arrays
        Per-item values; ``len(pz[i])`` is how many rows item ``i`` expands to.
        Lengths may differ between items, in which case the model-dummy block
        has ``max(len(p) for p in pz)`` columns.
    loco : str
        Which blocks form the design matrix:
        ``'all-in'`` (features, model dummies, group dummies),
        ``'exclude-embeddings'`` (model + group dummies),
        ``'exclude-modeldummy'`` (features + group dummies),
        ``'exclude-taskdummy'`` (features + model dummies).

    Returns
    -------
    Xall : numpy.ndarray, shape (n_rows, n_columns)
        Design matrix, one row per expanded entry.
    zall : numpy.ndarray
        Concatenation of ``z`` (see above).
    groupsall : numpy.ndarray of str
        ``f"{group}_{j}"`` for each expanded row.
    pzall : numpy.ndarray
        Concatenation of ``pz``.

    Raises
    ------
    ValueError
        If ``loco`` is not one of the values above, or the per-item inputs
        have inconsistent lengths.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        # Same convention as create_groups_with_kmeans: 1-D input is one column.
        X = X.reshape(-1, 1)

    lengths = [len(sub_pz) for sub_pz in pz]
    n_positions = max(lengths, default=0)

    Xall = np.repeat(X, lengths, axis=0)

    # np.eye(length, n_positions) equals np.eye(length) when every item has the
    # same length; otherwise it zero-pads so blocks of different sizes stack.
    model_dummy_vars = np.concatenate(
        [np.eye(length, n_positions) for length in lengths], axis=0
    )

    unique_groups, group_codes = np.unique(groups, return_inverse=True)
    # Truncate to len(pz) (as zip(groups, pz) would), then give every expanded
    # row the group code of the item it came from.
    row_group_codes = np.repeat(np.ravel(group_codes)[: len(lengths)], lengths)
    group_dummy_vars = np.eye(len(unique_groups))[row_group_codes]

    if loco == 'all-in':
        Xall = np.hstack((Xall, model_dummy_vars, group_dummy_vars))
    elif loco == 'exclude-embeddings':
        Xall = np.hstack((model_dummy_vars, group_dummy_vars))
    elif loco == 'exclude-modeldummy':
        Xall = np.hstack((Xall, group_dummy_vars))
    elif loco == 'exclude-taskdummy':
        Xall = np.hstack((Xall, model_dummy_vars))
    else:
        raise ValueError('Invalid feature type specified!')

    try:
        zall = np.concatenate(z)
    except ValueError:
        # Entries whose shapes don't line up (e.g. 2-D with differing widths):
        # fall back to the first column of each.
        zall = np.concatenate([x[:, 0] for x in z])

    pzall = np.concatenate(pz)
    groupsall = np.array(
        [f"{group}_{j}" for group, length in zip(groups, lengths) for j in range(length)]
    )

    return Xall, zall, groupsall, pzall
