# Run with: uvicorn api_server:app --host 0.0.0.0 --reload

import asyncio
from typing import Dict, List, Optional

from config_api_server import CONFIG_API_SERVER, NORM_TYPE_API_SERVER, THRESHOLD_API_SERVER
from fastapi import FastAPI
from pydantic import BaseModel, Field
from utils.TSP_gen_call import *
from utils.TSP_gen_utils import *

app = FastAPI()


class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /api/generate/``.

    Only ``messages_list`` is required; every other field has a default.
    """

    # Batch input: a list of message lists, each message being a dict.
    messages_list: List[List[Dict]]
    # Optional maximum length. Defaults to None; when it is set, the
    # generate handler forwards it instead of max_new_tokens.
    max_length: Optional[int] = None
    # Maximum number of new tokens (default 50); forwarded by the generate
    # handler only when max_length is None.
    max_new_tokens: Optional[int] = 50
    # Forwarded as-is to each model actor's prepare_inputs_for_model call.
    apply_chat_template: Optional[bool] = False
    # For early stopping; forwarded as `until` to the generation call.
    until: Optional[List[str]] = None


@app.on_event("startup")
async def load_model():
    """Populate the module-level ensemble state when the server starts.

    Calls ``setup_model_actors_and_data`` with the three server config
    values and binds the eleven values it returns to module globals, which
    the request handlers in this module read afterwards.
    """
    global model_actors_list, tokenizers, vocab_union
    global mapping_matrices, index_to_vocab, special_prefix_tokens_dict
    global byte_mappings_list, min_max_position_embeddings
    global model_name_list, primary_index, threshold

    loaded_state = setup_model_actors_and_data(
        CONFIG_API_SERVER,
        NORM_TYPE_API_SERVER,
        THRESHOLD_API_SERVER,
    )

    # Unpacking fails loudly if the helper ever returns a different number
    # of values than the globals declared above.
    (
        model_actors_list,
        tokenizers,
        vocab_union,
        mapping_matrices,
        index_to_vocab,
        special_prefix_tokens_dict,
        byte_mappings_list,
        min_max_position_embeddings,
        model_name_list,
        primary_index,
        threshold,
    ) = loaded_state

@app.get("/status")
async def get_status():
    """Return a fixed ``{"status": "ready"}`` payload."""
    payload = {"status": "ready"}
    return payload

# --- for cache ---
@app.get("/save_mmlu_cache")
async def save_mmlu_cache():
    """Call ``save_mmlu_cache`` on every model actor and wait for all calls.

    Returns ``{"status": "ready"}`` once every remote call has completed.
    """
    # NOTE(review): ray.get is a synchronous wait inside an async handler,
    # so the event loop appears to be stalled until every actor finishes --
    # confirm this is intended.
    ray.get(
        [actor.save_mmlu_cache.remote() for actor in model_actors_list]
    )
    return {"status": "ready"}
# -----------------

@app.post("/api/generate/")
async def api_generate(request: GenerateRequest):
    """Generate ensemble responses for a batch of conversations.

    Every model actor prepares its inputs from ``request.messages_list``,
    then ``generate_ensemnble_response`` runs the ensemble generation using
    the module-level state set up at startup. The generated texts are
    extracted with the first tokenizer and the first actor's prepared inputs.

    Returns:
        ``{"response": generated_texts}``.
    """
    # An explicit max_length takes precedence; otherwise fall back to
    # max_new_tokens.
    if request.max_length is None:
        length_kwargs = {"max_new_tokens": request.max_new_tokens}
    else:
        length_kwargs = {"max_length": request.max_length}

    pending_inputs = [
        actor.prepare_inputs_for_model.remote(
            request.messages_list,
            min_max_position_embeddings,
            request.apply_chat_template,
        )
        for actor in model_actors_list
    ]
    # Wait for every actor, but keep only the first actor's result: it is
    # the only one needed below for text extraction.
    # NOTE(review): the prepared inputs are not passed to the generation
    # call; presumably each actor keeps them internally -- confirm.
    first_model_inputs = ray.get(pending_inputs)[0]

    ensemble_output = generate_ensemnble_response(
        model_actors_list=model_actors_list,
        model_name_list=model_name_list,
        tokenizers=tokenizers,
        vocab_union=vocab_union,
        mapping_matrices=mapping_matrices,
        index_to_vocab=index_to_vocab,
        special_prefix_tokens_dict=special_prefix_tokens_dict,
        byte_mappings_list=byte_mappings_list,
        primary_index=primary_index,
        threshold=threshold,
        until=request.until,
        **length_kwargs,
    )

    generated_texts = extract_generated_texts(
        tokenizers[0], first_model_inputs, ensemble_output
    )
    logger.info(f"Generated text:{generated_texts}")

    return {"response": generated_texts}
