import torch
from sklearn.cluster import KMeans
from datasets import load_dataset, concatenate_datasets

import transformers
import my_utils
from my_utils import *
from transformers import AutoTokenizer
import pickle
import os, json
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import numpy as np
import time



# Random seed: used as KMeans `random_state` inside generate_semantic_ids and
# embedded in the output file names. The trailing list looks like the set of
# seeds the author has tried -- TODO confirm.
seed = 3407 # 117, 3407, 42
# Passed as `c` (the leaf-size threshold) to generate_semantic_ids and embedded
# in the output directory name. NOTE: this is not the `cluster_size` parameter
# of generate_semantic_ids, which is left at its default by this script.
cluster_size = 100

def generate_semantic_ids(document_embeddings, c=100, cluster_size=10):
    """Assign hierarchical semantic IDs to documents via recursive k-means.

    The embeddings are split with k-means into at most ``cluster_size``
    groups; every group that still holds more than ``c`` documents is split
    again. A group with at most ``c`` documents becomes a leaf, and each of
    its documents receives the ID ``"<path>-<position in leaf>"``, where
    ``<path>`` is the concatenation of the k-means labels taken on the way
    from the root to that leaf.

    Args:
        document_embeddings: 2-D array-like or ``torch.Tensor`` with one row
            per document.
        c: Maximum number of documents allowed in a leaf cluster.
        cluster_size: Number of k-means clusters (branching factor) used at
            each level.

    Returns:
        A tuple ``(cluster_ids, cluster_indices_dict)``. ``cluster_ids[i]`` is
        the ID string of document ``i``. ``cluster_indices_dict`` maps each
        leaf path to a NumPy array holding the row indices of its documents.

    Notes:
        * K-means is seeded with the module-level ``seed``.
        * If k-means cannot separate a group (all documents get the same
          label, e.g. identical embeddings), that group is kept as a leaf even
          though it holds more than ``c`` documents. Recursing on it would
          never terminate.
        * Labels are zero-padded to a common width so that concatenated paths
          stay unambiguous when ``cluster_size > 10``. For
          ``cluster_size <= 10`` the width is 1 and IDs are unaffected.
    """
    # Widest label that k-means can emit is cluster_size - 1.
    label_width = len(str(max(cluster_size - 1, 0)))

    def make_leaf(docs, original_indices, prefix, cluster_indices_dict, progress_bar):
        # Register a leaf cluster and number its documents.
        # The bar advances only here, so every document is counted exactly once.
        if progress_bar is not None:
            progress_bar.update(len(docs))
        cluster_indices_dict[prefix] = original_indices
        leaf_ids = [prefix + '-' + str(i) for i in range(len(docs))]
        return leaf_ids, cluster_indices_dict

    def cluster(docs, original_indices, prefix='', cluster_indices_dict=None, progress_bar=None):
        if cluster_indices_dict is None:
            cluster_indices_dict = {}

        # Base case: the group is small enough to be a leaf.
        if len(docs) <= c:
            return make_leaf(docs, original_indices, prefix, cluster_indices_dict, progress_bar)

        # Apply k-means clustering.
        n_clusters = min(len(docs), cluster_size)
        kmeans = KMeans(n_clusters=n_clusters, random_state=seed)
        clusters = kmeans.fit_predict(docs)

        # No split happened: recursing would see the very same group again.
        if np.all(clusters == clusters[0]):
            return make_leaf(docs, original_indices, prefix, cluster_indices_dict, progress_bar)

        identifiers = [''] * len(docs)
        for cluster_id in range(n_clusters):
            subcluster_indices = np.where(clusters == cluster_id)[0]
            if subcluster_indices.size == 0:
                # k-means may return fewer distinct labels than requested;
                # do not register empty leaves for the missing ones.
                continue

            # Recursive call for the next level.
            child_ids, cluster_indices_dict = cluster(
                docs[subcluster_indices],
                original_indices[subcluster_indices],
                prefix=prefix + str(cluster_id).zfill(label_width),
                cluster_indices_dict=cluster_indices_dict,
                progress_bar=progress_bar,
            )
            for idx, child_id in zip(subcluster_indices, child_ids):
                identifiers[idx] = child_id

        return identifiers, cluster_indices_dict

    # Convert the input to a NumPy array (no normalization is applied).
    if isinstance(document_embeddings, torch.Tensor):
        # detach()/cpu() make the conversion work for GPU or grad-tracking tensors.
        document_embeddings = document_embeddings.detach().cpu().numpy()
    else:
        document_embeddings = np.asarray(document_embeddings)

    total_docs = len(document_embeddings)
    with tqdm(total=total_docs, desc="Clustering Documents") as progress_bar:
        # Start clustering from the root with the original row indices.
        original_indices = np.arange(total_docs)
        cluster_ids, cluster_indices_dict = cluster(
            document_embeddings, original_indices, progress_bar=progress_bar
        )

    return cluster_ids, cluster_indices_dict


# Output directory for this run; created up front if it does not exist yet.
save_dir = f'./vitaminc_100k/cluster_subset_q100_newid_clustersize{cluster_size}/'
os.makedirs(save_dir, exist_ok=True)

# Pick the torch device string and load the sentence encoder.
# NOTE(review): `device` is not passed to SentenceTransformer here, so the
# library presumably picks its own device -- confirm this is intended.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2')


### Get document features ####
def formatting_prompts_func(example):
    """Encode the ``text`` field of ``example`` with the module-level ``model``.

    Args:
        example: Mapping with a ``'text'`` entry that is handed to
            ``model.encode`` unchanged.

    Returns:
        The encoder output reshaped to 2-D, one embedding per row.
    """
    sentence = example['text']

    emb = model.encode(sentence)
    # Take the row width from the encoder output instead of hard-coding 384,
    # so an encoder with a different embedding size is not silently reshaped
    # into the wrong number of rows.
    return emb.reshape(-1, emb.shape[-1])



# Load the corpus from CSV.
abstracts_combined = load_dataset('csv', data_files='./data/vitaminc_corpus_100k.csv', split='train')
print(f"Corpus of size {len(abstracts_combined)} loaded.")


# NOTE(review): batch_size is currently unused; the map() call below is not
# batched and invokes formatting_prompts_func once per example.
batch_size = 10  # You can adjust the batch size depending on your memory capacity
print("Start to extract embedding from the sentence-transformer model!")
abs_embedding = abstracts_combined.map(lambda example: {"feature": formatting_prompts_func(example)})
print("Concatenating into np array...")
# Stack the per-document embedding rows into a single 2-D array.
abs_embedding_tensor = np.concatenate(abs_embedding["feature"], axis=0)

### kmeans clustering ###
print("Start to cluster the embeddings! cluster size: ", cluster_size)
start_time = time.time()
semantic_ids, semantic_dict = generate_semantic_ids(abs_embedding_tensor, c=cluster_size)
print("Clustering time: ", time.time() - start_time)

# Human-readable summary: one line per leaf cluster with its size.
summary_path = os.path.join(save_dir, f'semantic_dict_seed{seed}.txt')
with open(summary_path, 'w') as file:
    file.write("\n".join([f"Key: {key}, Length of Value: {len(value)}" for key, value in semantic_dict.items()]))

# Persist the IDs and the leaf -> indices mapping; the context manager makes
# sure the file is flushed and closed.
pickle_path = os.path.join(save_dir, f'semantic_id_dict_seed{seed}.list')
with open(pickle_path, 'wb') as file:
    pickle.dump([semantic_ids, semantic_dict], file)
print(semantic_dict.keys())

# import pdb; pdb.set_trace()
