import sys
import os

# Repository root: two directory levels above this script. It is printed for
# visibility and appended to sys.path so the `src.*` imports below resolve
# when the script is run directly.
DIRPATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
print(DIRPATH)
sys.path.append(DIRPATH)

from datasets import load_dataset
from vllm import LLM, SamplingParams
from src.util import perform_inference
from src.util.load_config import get_config, print_config
from src.util.load_prompt import read_text_file
from transformers import AutoTokenizer

import torch

config = get_config()
print_config(config)

# Prompt files live under <repo>/data/prompts/.
PROMPT_ROOT = f"{DIRPATH}/data/prompts/"
SYS_PROMPT = f"{PROMPT_ROOT}system_prompts/llm_similarity_evaluation.txt"
EMAIL_FORMAT = f"{PROMPT_ROOT}data_formats/ground_truth_and_candidate.txt"

# Judge model and its short size tag (the tag is embedded in the output
# column name).
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
msize = "8b"

# vLLM engine and sampling settings.
GPU_MEMORY_UTILIZATION = 0.90
TEMPERATURE = 1.0
MAX_TOKENS = 2048
TOP_P = 0.95
# Passed to the engine as tensor_parallel_size: one shard per visible GPU.
NUM_GPUS = torch.cuda.device_count()

# Hub repo id of the train/test split, e.g. "<org>/<subset>-20-split" when
# test_set_size is 0.2. round() is used instead of int() because binary
# floats make e.g. 0.29 * 100 == 28.999999999999996, which int() would
# truncate to 28 and point at the wrong repository.
# NOTE(review): this name must match the one used by whichever script created
# and pushed the split; confirm that script computes the percentage the same way.
dataset_name = (
    f"{config['working_organization']}/{config['subset_name']}"
    f"-{round(config['test_set_size'] * 100)}-split"
)

# Dataset column whose text is judged against the ground-truth "email" column.
evaluation_candidate = "gold_rule_email"
# Output column name for the raw LLM judgments. The default argument of
# process_for_dataset_split computes the same value.
column_name = f"{evaluation_candidate}_llmeval_{msize}_raw"

# vLLM engine for the judge model, sharded across all visible GPUs. Prefix
# caching is presumably enabled because every prompt begins with the same
# system message, so that shared prefix can be reused between requests.
llm = LLM(
    model=MODEL,
    tensor_parallel_size=NUM_GPUS,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    enable_prefix_caching=True,
)

# Decoding settings applied to every prompt.
sampling_params = SamplingParams(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_tokens=MAX_TOKENS,
)

# Use the engine's own tokenizer so the chat template matches the model.
tokenizer = llm.get_tokenizer()

# Conversation prefix shared by all prompts: a single system turn whose text
# is read from SYS_PROMPT.
messages = [
    {"role": "system", "content": read_text_file(SYS_PROMPT)},
]

# Template for the user turn, filled per row with the candidate and the
# ground-truth email.
email_format = read_text_file(EMAIL_FORMAT)


def make_prompts_for_data(data):
    """Build one chat-formatted prompt string per row of ``data``.

    Each prompt is the shared system message followed by a user turn. The
    user turn is produced by filling ``email_format`` with the row's candidate
    text (``row[evaluation_candidate]``) and its ground truth
    (``row["email"]``). The conversation is rendered with the tokenizer's chat
    template as text (not token ids), with the generation prompt appended.

    Args:
        data: iterable of mapping-like rows, e.g. a ``datasets.Dataset``
            split. Each row must provide ``"email"`` and the
            ``evaluation_candidate`` column.

    Returns:
        list[str]: one prompt per row, in the same order as ``data``.
    """

    def render(row):
        ground_truth = row["email"]
        candidate = row[evaluation_candidate]
        # NOTE(review): the template is filled positionally as
        # (candidate, ground_truth), but the template file is named
        # "ground_truth_and_candidate.txt". Confirm its placeholders expect
        # this order.
        user_turn = {
            "role": "user",
            "content": email_format.format(candidate, ground_truth),
        }
        return tokenizer.apply_chat_template(
            messages + [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )

    return [render(row) for row in data]


def process_for_dataset_split(
    df, column_name=f"{evaluation_candidate}_llmeval_{msize}_raw"
):
    """Run the judge model over one dataset split and store its raw outputs.

    Args:
        df: a ``datasets.Dataset`` split. This function uses its
            ``column_names``, ``remove_columns`` and ``add_column``.
        column_name: name of the output column. It defaults to
            ``"<evaluation_candidate>_llmeval_<msize>_raw"``. Any existing
            column with this name is dropped and replaced.

    Returns:
        A new dataset equal to ``df`` plus ``column_name``, holding whatever
        ``perform_inference.run_inference`` returned. That is presumably one
        output per row; the helper's contract is not visible here.
    """
    print("Generating Prompts...")
    prompts = make_prompts_for_data(df)

    print("Running Inference")
    outputs = perform_inference.run_inference(llm, sampling_params, prompts)
    print(len(outputs))

    # Drop a stale column of the same name (e.g. left by an earlier run that
    # pushed to the same repo) before attaching the fresh results.
    base = (
        df.remove_columns([column_name])
        if column_name in df.column_names
        else df
    )
    return base.add_column(column_name, outputs)


# Load the split dataset from the Hub, annotate the train and test splits
# (in that order) with LLM judgments, then push the result back to the same
# repository.
dataset = load_dataset(dataset_name)
for split_name in ("train", "test"):
    dataset[split_name] = process_for_dataset_split(dataset[split_name])
dataset.push_to_hub(dataset_name)
