import sys
import os
import logging
from datetime import datetime

# Configure the root logger: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# The project root sits two directories above this file. It is appended to
# sys.path, presumably so the `src` package imported further down resolves
# when this script is launched directly.
_script_dir = os.path.dirname(__file__)
DIRPATH = os.path.abspath(os.path.join(_script_dir, os.pardir, os.pardir))
logging.info(f"Directory path set to {DIRPATH}")
sys.path.append(DIRPATH)

from datasets import load_dataset
from vllm import LLM, SamplingParams
from src.util import run_inference, get_config, print_config, read_text_file

# Load the run configuration and hand it to print_config for display.
# NOTE(review): if print_config uses print(), consider switching it to logging
# inside its own definition so the output matches the rest of this script.
config = get_config()
print_config(config)

# Locations of the prompt templates under <DIRPATH>/data/prompts/.
PROMPT_ROOT = f"{DIRPATH}/data/prompts/"
SYS_PROMPT = f"{PROMPT_ROOT}system_prompts/email_generation_with_rules.txt"
DATA_FORMAT = f"{PROMPT_ROOT}data_formats/email_generation_with_rules.txt"

# Dataset identifier taken from the config, plus the columns this script
# touches: RULES_COLUMN is read from every row, COLUMN_NAME receives the
# generated outputs.
DATASET = config["dataset"]
COLUMN_NAME = "large_model_with_generated_70b_rules"
RULES_COLUMN = "learned_rulegen_70b_rules"

# Model name from the config and the engine / sampling settings used below.
MODEL = config["models"]["large_model"]
GPU_MEMORY_UTILIZATION = 0.90
MAX_NUM_SEQS = 64
TEMPERATURE = 1.0
MAX_TOKENS = 2048
TOP_P = 0.95
# Hard-coded tensor-parallel degree; it could alternatively be derived from
# torch.cuda.device_count().
NUM_GPUS = 4

logging.info("Starting VLLM Model Load")

# Engine options are gathered in one place and splatted into the constructor.
# Every prompt built later starts with the same system message, which is
# presumably why prefix caching is switched on.
_engine_options = {
    "enable_prefix_caching": True,
    "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
    "tensor_parallel_size": NUM_GPUS,
}
llm = LLM(MODEL, **_engine_options)

# Decoding settings shared by every generation request in this script.
sampling_params = SamplingParams(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_tokens=MAX_TOKENS,
)

# The engine's own tokenizer is reused to apply the chat template to prompts.
tokenizer = llm.get_tokenizer()

# Fixed conversation prefix (the system prompt) and the template string that
# is later filled in with per-row values to form the user message.
_system_message = {"role": "system", "content": read_text_file(SYS_PROMPT)}
messages = [_system_message]
data_format = read_text_file(DATA_FORMAT)


def make_prompts_for_data(data):
    """Build one chat-templated prompt per row of ``data``.

    For every row, the ``"metadata"``, ``"input"`` and ``RULES_COLUMN`` values
    are passed positionally (in that order) to the module-level
    ``data_format`` template to produce the user message. That message is
    appended to a copy of the module-level ``messages`` list and the resulting
    conversation is run through ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``.

    Args:
        data: Iterable of mapping-like rows exposing the three keys above.

    Returns:
        A list with one rendered prompt per row, in iteration order.
    """

    def render(row):
        # Row fields are read in the same order the template receives them.
        user_turn = {
            "role": "user",
            "content": data_format.format(
                row["metadata"], row["input"], row[RULES_COLUMN]
            ),
        }
        # A fresh list is built for each row, so `messages` is never mutated.
        conversation = [*messages, user_turn]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    return [render(row) for row in data]


def process_for_dataset_split(df, column_name=COLUMN_NAME):
    """Generate model outputs for one dataset split and store them in a column.

    Prompts are built with ``make_prompts_for_data`` and passed to
    ``run_inference`` together with the module-level ``llm`` and
    ``sampling_params``. The outputs are then attached to the split under
    ``column_name``.

    Args:
        df: The split to process. It must expose ``column_names``,
            ``remove_columns`` and ``add_column``, and be iterable row by row.
        column_name: Name of the column that receives the outputs; defaults
            to ``COLUMN_NAME``.

    Returns:
        The split with ``column_name`` holding the new outputs.
    """
    logging.info("Generating Prompts...")
    split_prompts = make_prompts_for_data(df)

    logging.info("Running Inference")
    generations = run_inference(llm, sampling_params, split_prompts)
    logging.info(f"Generated outputs for {len(generations)} entries")

    # Drop any existing column of the same name so the new outputs replace it.
    base = (
        df.remove_columns([column_name])
        if column_name in df.column_names
        else df
    )
    return base.add_column(column_name, generations)


# Load the dataset, add generated outputs to the train and test splits (in
# that order), then push the updated dataset back under the same identifier.
dataset = load_dataset(DATASET)
for split_name in ("train", "test"):
    dataset[split_name] = process_for_dataset_split(dataset[split_name])
dataset.push_to_hub(DATASET)
