"""
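Run inference with the NaiveFT model.

Loads the checkpoint saved under ``out/models/<run_name>/`` with vLLM,
generates one output per row of the dataset's train and test splits, stores
the outputs in a dataset column, and pushes the dataset back to the Hub.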
"""

import sys
import os

# Put the directory two levels above this file on sys.path so that the
# ``src`` package imported below can be resolved. This has to run before
# that import. (Presumably this directory is the project root -- confirm
# against the repository layout.)
DIRPATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
print(DIRPATH)
sys.path.append(DIRPATH)

from datasets import load_dataset
from vllm import LLM, SamplingParams
from src.util import (
    get_config,
    print_config,
    read_text_file,
    run_inference,
)
from transformers import AutoTokenizer
from torch.cuda import device_count

# Load the project configuration and echo it so it shows up in the run output.
config = get_config()
print_config(config)

# === CONFIGURABILITY STARTS ===

# Decoding settings passed to vLLM's SamplingParams further down.
TEMPERATURE = 1
TOP_P = 0.95
MAX_TOKENS = 2048

# Number of CUDA devices reported by torch; used as the tensor-parallel size.
NUM_GPUS = device_count()

# Per-run settings. Only "dataset" is carried over from the loaded
# configuration; every other entry is fixed here.
#   training_data_format   -> selects the system-prompt / data-format files
#   generated_column_name  -> dataset column that receives the generations
#   lora_alpha, lora_rank, epochs -> components of the run (checkpoint) name
#   gpu_memory_utilization -> forwarded to the vLLM engine
# NOTE(review): the column name is spelled "startegy". It is left as-is
# because it is the name written to the dataset -- confirm with downstream
# consumers before renaming.
config = dict(
    dataset=config["dataset"],
    training_data_format="generate_rules_from_intent",
    generated_column_name="rule_startegy_2_learned_256",
    lora_alpha=256,
    lora_rank=256,
    epochs=3,
    gpu_memory_utilization=0.9,
)

# === CONFIGURABILITY ENDS ===

# Name of the fine-tuning run whose checkpoint is loaded below. It encodes
# the LoRA rank, LoRA alpha and epoch count taken from ``config``.
run_name = (
    f"rule_ft_strategy_2_r_{config['lora_rank']}"
    f"_a_{config['lora_alpha']}"
    f"_e_{config['epochs']}"
)

# The system prompt and the data-format template share one file name, derived
# from the training data format; they live in sibling folders under
# data/prompts.
prompts_folder = os.path.join(DIRPATH, "data", "prompts")
prompt_file = f"{config['training_data_format']}.txt"

system_prompt_path = os.path.join(prompts_folder, "system_prompts", prompt_file)
data_format_path = os.path.join(prompts_folder, "data_formats", prompt_file)

system_prompt = read_text_file(system_prompt_path)
data_format = read_text_file(data_format_path)

# Load this run's checkpoint with vLLM, using every CUDA device torch reports
# as the tensor-parallel size.
model_name = f"{DIRPATH}/out/models/{run_name}/"
llm = LLM(
    model_name,
    tensor_parallel_size=NUM_GPUS,
    gpu_memory_utilization=config["gpu_memory_utilization"],
    enable_prefix_caching=True,
)

# Decoding settings applied to every prompt.
sampling_params = SamplingParams(
    top_p=TOP_P,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
)

# The engine's own tokenizer is reused to render chat templates.
tokenizer = llm.get_tokenizer()


def make_prompts_for_data(data):
    """Build one chat-formatted prompt string per row of ``data``.

    For every row, the "metadata" and "input" fields are substituted (in that
    order) into the module-level ``data_format`` template to form the user
    message. That message is paired with the module-level ``system_prompt``
    and rendered to text through the tokenizer's chat template, with the
    generation prompt appended.

    Args:
        data: Iterable of rows supporting ``row["metadata"]`` and
            ``row["input"]``.

    Returns:
        A list of prompt strings in the same order as ``data``.
    """

    def render(row):
        # Read the fields in the order the template's positional slots expect.
        meta = row["metadata"]
        user_input = row["input"]
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data_format.format(meta, user_input)},
        ]
        # tokenize=False: return the rendered prompt text rather than token ids.
        return tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

    return [render(row) for row in data]


def process_for_dataset_split(df, column_name=config["generated_column_name"]):
    """Generate model outputs for one dataset split and attach them as a column.

    Prompts are built for every row of ``df``, run through the module-level
    ``llm`` with ``sampling_params`` via ``run_inference``, and the results
    are stored under ``column_name``. An existing column of that name is
    dropped first, so the new outputs replace it.

    Args:
        df: Dataset split exposing ``column_names``, ``remove_columns`` and
            ``add_column``, and iterable row by row.
        column_name: Name of the column that receives the outputs. Defaults
            to the configured ``generated_column_name``.

    Returns:
        The split with ``column_name`` holding the newly generated outputs.
    """
    print("Generating Prompts...")
    split_prompts = make_prompts_for_data(df)

    print("Running Inference")
    generations = run_inference(llm, sampling_params, split_prompts)
    print(len(generations))

    # Drop a previous copy of the column, if any, before adding the new one.
    already_present = column_name in df.column_names
    base = df.remove_columns([column_name]) if already_present else df
    return base.add_column(column_name, generations)


# Generate outputs for the train and test splits (in that order), then push
# the updated dataset to the Hub under the same identifier it was loaded from.
dataset = load_dataset(config["dataset"])
for split_name in ("train", "test"):
    dataset[split_name] = process_for_dataset_split(dataset[split_name])
dataset.push_to_hub(config["dataset"])
