import requests
import os
from typing import List, Tuple, Union, Optional, Dict
import numpy as np

from concurrent.futures import ThreadPoolExecutor, as_completed

def get_text_embeddings(texts: Union[str, List[str]],
                        max_workers: int = 5,
                        api_key: str = "",
                        embedding_cache: Optional[Dict[str, np.ndarray]] = None,
                        timeout: Optional[float] = 60.0) -> Tuple[np.ndarray, int]:
    """
    Get text embeddings using OpenAI's text-embedding-3-small model.

    One HTTP request is issued per text that is not already present in
    ``embedding_cache``; for a list of texts the requests run concurrently
    in a thread pool.

    Args:
        texts (Union[str, List[str]]): A single text string or a list of text strings to embed.
        max_workers (int): The maximum number of workers for parallel processing.
        api_key (str): The API key for accessing OpenAI's services. When empty,
            the ``OPENAI_API_KEY`` environment variable is used instead.
        embedding_cache (Optional[Dict[str, np.ndarray]]): A cache mapping text to its
            embedding. Cached texts are served without an API call (and count as
            0 tokens); newly fetched embeddings are written back into it.
        timeout (Optional[float]): Seconds to wait for each HTTP request before
            giving up. ``None`` waits indefinitely.

    Returns:
        Tuple[np.ndarray, int]: The embeddings and the total token usage.
            For a single string the array is the 1D embedding vector. For a
            list it holds one row per successfully embedded text, in input
            order (duplicates included); texts whose request failed are
            reported on stdout and omitted from the result.

    Raises:
        requests.RequestException: Only when ``texts`` is a single string and
            the request fails. Failures in list mode are printed, not raised.
    """
    if not api_key:
        api_key = os.getenv("OPENAI_API_KEY")

    model = "text-embedding-3-small"
    url = "https://api.openai.com/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    def fetch_embedding(text: str) -> Tuple[np.ndarray, int]:
        """Return (embedding, tokens used) for one text, consulting the cache first."""
        if embedding_cache is not None and text in embedding_cache:
            return embedding_cache[text], 0  # 0 tokens since no API call is made

        response = requests.post(
            url,
            headers=headers,
            json={"input": text, "model": model},
            timeout=timeout,
        )
        if response.status_code != 200:
            # raise_for_status() only raises for 4xx/5xx; any other non-200
            # status must still be treated as a failure instead of silently
            # returning None to the caller.
            response.raise_for_status()
            raise requests.HTTPError(
                f"Unexpected status code {response.status_code} from embeddings API",
                response=response,
            )

        payload = response.json()
        embedding = np.array(payload["data"][0]["embedding"])
        return embedding, payload["usage"]["prompt_tokens"]

    if isinstance(texts, str):
        embedding, tokens = fetch_embedding(texts)
        if embedding_cache is not None:
            embedding_cache[texts] = embedding
        return np.array(embedding), tokens

    fetched: List[Tuple[str, np.ndarray]] = []
    total_tokens = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted = [(text, executor.submit(fetch_embedding, text)) for text in texts]

        # Collect in submission order so the output rows line up with the
        # input texts, including repeated texts.
        for text, future in submitted:
            try:
                embedding, tokens = future.result()
            except Exception as e:
                # Best effort: report the failure and keep the other results.
                print(f"An error occurred for text: {text}")
                print(e)
                continue
            fetched.append((text, embedding))
            total_tokens += tokens

    # The cache is only written after the pool has shut down, so worker
    # threads never read it while it is being modified here.
    if embedding_cache is not None:
        for text, embedding in fetched:
            embedding_cache[text] = embedding

    print(f"text-embedding-3-small token used: {total_tokens}")
    return np.array([embedding for _, embedding in fetched]), total_tokens
