import chromadb
from chromadb.utils import embedding_functions
from chromadb.config import Settings

# Embedding model selector read by index_chroma and query_chroma: the special
# value 'default' selects embedding_functions.DefaultEmbeddingFunction();
# any other value is passed as model_name to SentenceTransformerEmbeddingFunction.
embed_model = "all-MiniLM-L6-v2"
# Directory handed to the persistent client below as its `path`.
persist_directory = "chroma_dir"

# NOTE(review): earlier client setup kept for reference; it looks like the
# older Settings-based configuration -- confirm whether it can be removed.
# client = chromadb.Client(Settings(
#     chroma_db_impl="duckdb+parquet",
#     persist_directory=persist_directory # Optional, defaults to .chromadb/ in the current directory
# ))

# Single module-level client shared by index_chroma and query_chroma.
# It is created at import time.
client = chromadb.PersistentClient(path=persist_directory)


def index_chroma(collection_name, documents_list, metadatas_list, ids_list):
    """Rebuild the collection ``collection_name`` from the given documents.

    A collection with the same name is deleted first (best effort), then a
    new one is created with the ``hnsw:space`` metadata set to ``cosine`` and
    the documents are added to it. Progress is reported with ``print``.

    Args:
        collection_name: Name of the collection to delete and re-create.
        documents_list: Documents passed to ``collection.add``.
        metadatas_list: Metadata entries passed alongside the documents.
        ids_list: Ids passed alongside the documents.

    Returns:
        None.
    """
    print("Creating chroma db collection")
    print("collection name : ",collection_name)
    # 'default' picks the library default; anything else is treated as a
    # sentence-transformers model name.
    if embed_model == 'default':
        embed_model_function = embedding_functions.DefaultEmbeddingFunction()
    else:
        embed_model_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embed_model)
    try:
        client.delete_collection(name=collection_name)
    except Exception as exc:
        # Deleting is best effort (typically the collection simply does not
        # exist yet). Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate, and report the
        # reason instead of hiding it.
        print("No existing collection deleted :", exc)
    else:
        print("Deleted the existing collection :",collection_name)
    print("Creating collection:", collection_name)
    collection = client.create_collection(
        name=collection_name,
        embedding_function=embed_model_function,
        metadata={"hnsw:space": "cosine"},
    )
    # Only raw documents are supplied; embeddings are left to the
    # collection's embedding function.
    collection.add(
        documents=documents_list,
        metadatas=metadatas_list,
        ids=ids_list,
    )
    print("documents indexed : ", len(documents_list))


def query_chroma(collection_name, query_texts, recall_limit=15):
    """Query ``collection_name`` and map matched paths to descriptions.

    Args:
        collection_name: Name of an existing collection to query.
        query_texts: Passed to ``collection.query`` as ``query_texts``.
        recall_limit: Passed to ``collection.query`` as ``n_results``.

    Returns:
        A list with one entry per item in the query output's ``'metadatas'``
        (presumably one per query text -- confirm against the Chroma API).
        Each entry is a single-element list holding a dict that maps every
        hit's ``'path'`` metadata value to its ``'description'`` value; hits
        sharing a path collapse to the last one seen.
    """
    embedder = (
        embedding_functions.DefaultEmbeddingFunction()
        if embed_model == 'default'
        else embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embed_model)
    )
    target = client.get_collection(name=collection_name, embedding_function=embedder)
    response = target.query(query_texts=query_texts, n_results=recall_limit)
    # Each group of hit metadata becomes [ {path: description, ...} ]; the
    # one-element list wrapper is part of the returned shape.
    return [
        [{hit.get('path'): hit.get('description') for hit in hits}]
        for hits in response.get('metadatas')
    ]