import sys
import os

# Repository root: two directory levels above this script. It is appended to
# sys.path so the `src.util` imports further down can be resolved when the
# script is run directly.
_SCRIPT_DIR = os.path.dirname(__file__)
DIRPATH = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
print(DIRPATH)
sys.path.append(DIRPATH)

from datasets import load_dataset
from vllm import LLM, SamplingParams
from src.util import run_inference
from src.util import get_config, print_config
from src.util import read_text_file
from transformers import AutoTokenizer

import torch


# Load the run configuration and echo it, so the log records which settings
# produced this generation pass.
config = get_config()
print_config(config)

# Prompt templates live under <repo>/data/prompts/.
PROMPT_ROOT = DIRPATH + "/data/prompts/"
SYS_PROMPT = PROMPT_ROOT + "system_prompts/rule_generator.txt"
DATA_FORMAT = PROMPT_ROOT + "data_formats/baseline_generation.txt"

# Fine-tuned checkpoint directory: <repo>/out/models/<run_name>-ruleft-<lora_rank>
_checkpoint_name = (
    config["wandb"]["run_name"]
    + "-ruleft-"
    + str(config["training_parameters"]["lora_rank"])
)
MODEL = DIRPATH + "/out/models/" + _checkpoint_name

# vLLM engine and sampling settings.
GPU_MEMORY_UTILIZATION = 0.90
MAX_NUM_SEQS = 64  # not passed to LLM(): its max_num_seqs kwarg is commented out
TEMPERATURE = 1.0
MAX_TOKENS = 2048
TOP_P = 0.95
# Used as the tensor-parallel degree (shard across every visible GPU).
# NOTE(review): this is 0 when no CUDA device is visible, which is presumably
# not a valid tensor_parallel_size. Confirm the intended behaviour on CPU hosts.
NUM_GPUS = torch.cuda.device_count()

# Hub dataset id, e.g. "<dataset>-20-split" for test_set_size=0.2.
# NOTE(review): int() truncates, and 0.29 * 100 evaluates to 28.999999999999996,
# so it yields "28". The id must match whatever naming the script that created
# the split used; confirm before switching to round().
dataset_name = config["dataset"] + f"-{int(config['test_set_size'] * 100)}-split"

# One vLLM engine for the fine-tuned checkpoint, shared by every split below.
# Disabled options are kept as comments for quick experimentation.
llm = LLM(
    MODEL,
    tensor_parallel_size=NUM_GPUS,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    enable_prefix_caching=True,
    # enforce_eager=True,
    # max_num_seqs=MAX_NUM_SEQS,
)
sampling_params = SamplingParams(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_tokens=MAX_TOKENS,
)
# The engine's own tokenizer supplies the chat template for prompt rendering.
tokenizer = llm.get_tokenizer()

# Fixed conversation prefix: a single system turn loaded from disk.
_system_prompt = read_text_file(SYS_PROMPT)
messages = [{"role": "system", "content": _system_prompt}]

# User-turn template, filled positionally with (metadata, input) per row.
data_format = read_text_file(DATA_FORMAT)


def make_prompts_for_data(data):
    """Render one chat-formatted prompt string per row of ``data``.

    Each row's ``"metadata"`` and ``"input"`` fields are substituted
    positionally into the module-level ``data_format`` template. The result
    becomes the user turn after the shared ``messages`` system prefix, and
    the conversation is rendered to text with ``tokenizer``'s chat template.
    A generation prompt is appended so the model answers as the assistant.

    Args:
        data: Iterable of mapping-like rows with ``"metadata"`` and
            ``"input"`` keys (presumably a Hugging Face ``Dataset`` split;
            verify against callers).

    Returns:
        List of prompt strings, in the same order as ``data``.
    """

    def _render(row):
        # Fill the user template, then extend a fresh copy of the system prefix.
        user_turn = {
            "role": "user",
            "content": data_format.format(row["metadata"], row["input"]),
        }
        conversation = messages + [user_turn]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    return [_render(row) for row in data]


def process_for_dataset_split(df, column_name="learned_rulegen_70b_rules"):
    """Generate model outputs for every row of ``df`` and store them as a column.

    Args:
        df: Dataset split to process. ``column_names``, ``remove_columns``
            and ``add_column`` are used, so presumably a Hugging Face
            ``Dataset``.
        column_name: Name of the column that receives the generations. Any
            existing column with this name is replaced.

    Returns:
        A dataset with ``column_name`` holding one output per row, as
        returned by ``run_inference``. The input ``df`` object itself is only
        reassigned locally, never mutated here.
    """
    print("Generating Prompts...")
    split_prompts = make_prompts_for_data(df)

    print("Running Inference")
    generations = run_inference(llm, sampling_params, split_prompts)
    print(len(generations))

    # Drop a stale copy of the target column first, so add_column does not
    # collide with a result from an earlier run.
    base = (
        df.remove_columns([column_name]) if column_name in df.column_names else df
    )
    return base.add_column(column_name, generations)


# Load the split dataset, attach generated rules to the train and test splits
# (train first, then test), and push the result back under the same Hub id.
dataset = load_dataset(dataset_name)
for split in ("train", "test"):
    dataset[split] = process_for_dataset_split(dataset[split])
dataset.push_to_hub(dataset_name)
