"""
This script generates baseline outputs with an LLM from the intents already
stored in the dataset, and adds them to the dataset as a new column.
"""

import sys
import os
import json

# Resolve the directory two levels above this file, echo it for debugging, and
# append it to sys.path before the project-local `src` import further down.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
print(_project_root)
sys.path.append(_project_root)

from datasets import load_dataset
from vllm import LLM, SamplingParams
from src.utils import iutils
import torch

# Constants

# Hugging Face dataset repository that is loaded and later pushed back to.
DATASET = "preference-agents/preference-enron"

# Model used for generation (the smaller alternative is kept for reference).
# MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL = "meta-llama/Meta-Llama-3-70B-Instruct"

# Engine settings handed to vLLM. MAX_NUM_SEQS is only referenced by a
# currently commented-out engine option.
NUM_GPUS = torch.cuda.device_count()
GPU_MEMORY_UTILIZATION = 0.90
MAX_NUM_SEQS = 64

# Sampling settings handed to SamplingParams.
TOP_P = 0.95
TEMPERATURE = 1.0
MAX_TOKENS = 2048

# Filesystem layout, resolved relative to the directory two levels above this
# file so the script does not depend on the current working directory.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
PROMPT_ROOT = os.path.join(_REPO_ROOT, "prompts")
DATA_ROOT = os.path.join(_REPO_ROOT, "data")
SYS_PROMPT = f"{PROMPT_ROOT}/system_prompts/baseline_generation.txt"
EMAIL_FORMAT = f"{PROMPT_ROOT}/formats/baseline_generation.txt"


def read_text_file(file_path, encoding="utf-8"):
    """Return the entire contents of a text file as a single string.

    Args:
        file_path (str): Path of the file to read.
        encoding (str): Text encoding used to decode the file. Defaults to
            UTF-8 so the result does not depend on the machine's locale.

    Returns:
        str: The full file contents.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid in ``encoding``.
    """
    # An explicit encoding avoids open()'s platform-dependent default.
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()


# Template for the user turn; make_prompts_for_data fills it positionally.
email_format = read_text_file(EMAIL_FORMAT)

# Every conversation starts from this single system message.
messages = [dict(role="system", content=read_text_file(SYS_PROMPT))]

# Instantiate the vLLM engine; tensor_parallel_size is set to the GPU count
# reported by torch.
llm = LLM(
    MODEL,
    tensor_parallel_size=NUM_GPUS,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    enable_prefix_caching=True,
    # Options currently left disabled:
    # enforce_eager=True,
    # max_num_seqs=MAX_NUM_SEQS,
)

sampling_params = SamplingParams(
    top_p=TOP_P,
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
)

# The tokenizer is used later to apply the chat template to each prompt.
tokenizer = llm.get_tokenizer()
print("Loaded Model")


def make_prompts_for_data(data):
    """Build one prompt per row of ``data``.

    For each row, the sender, recipient, date, subject, previous context and
    cleaned intent are substituted positionally into the module-level
    ``email_format`` template. The result becomes the user turn appended to
    the shared ``messages`` prefix, and the conversation is passed through the
    tokenizer's chat template (``tokenize=False``,
    ``add_generation_prompt=True``).

    Args:
        data: Iterable of rows that can be indexed by the keys ``"from"``,
            ``"to"``, ``"date"``, ``"subject"``, ``"previous_context"`` and
            ``"cleaned_intent_Meta-Llama-3-70B-Instruct"``.

    Returns:
        list: One rendered prompt per row, in row order.
    """
    # Order matters: it must match the positional placeholders in the template.
    field_order = (
        "from",
        "to",
        "date",
        "subject",
        "previous_context",
        "cleaned_intent_Meta-Llama-3-70B-Instruct",
    )

    def render(row):
        user_turn = email_format.format(*(row[key] for key in field_order))
        conversation = messages + [{"role": "user", "content": user_turn}]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    return [render(row) for row in data]


def write_outputs_to_json_file(outputs, filename):
    """
    Dump 'outputs' to 'filename' as JSON indented by four spaces.

    This is best-effort: any exception raised while opening or writing the
    file is caught and reported on stdout rather than propagated.

    Args:
    outputs (list of str): The values to serialise.
    filename (str): Destination path; an existing file is overwritten.
    """
    try:
        with open(filename, "w") as sink:
            json.dump(outputs, sink, indent=4)
        success_message = f"Successfully wrote outputs to {filename}"
        print(success_message)
    except Exception as err:
        failure_message = f"An error occurred while writing to the file: {err}"
        print(failure_message)


def process_for_dataset_split(
    df, column_name="generated_baseline_" + MODEL.split("/")[-1]
):
    """Generate outputs for every row of a dataset split and attach them.

    Prompts are built with ``make_prompts_for_data``, run through
    ``iutils.run_inference`` with the module-level ``llm`` and
    ``sampling_params``, saved to ``outputs.json`` inside ``DATA_ROOT``, and
    finally added to the dataset as a new column.

    Args:
        df: Dataset split to process. It is iterated row by row to build the
            prompts and must provide ``add_column``.
        column_name (str): Name of the column that receives the outputs.
            Defaults to ``"generated_baseline_"`` followed by the last path
            component of ``MODEL``.

    Returns:
        The result of ``df.add_column(column_name, outputs)``.
    """
    print("Generating Prompts...")
    prompts = make_prompts_for_data(df)
    print("Running Inference")
    outputs = iutils.run_inference(llm, sampling_params, prompts)
    print(len(outputs))

    # Join with a separator: plain concatenation produced
    # "<repo>/dataoutputs.json" next to the data directory instead of a file
    # inside it.
    outputs_path = os.path.join(DATA_ROOT, "outputs.json")
    # The JSON dump is only a backup of the inference results, so a problem
    # creating the directory is reported rather than raised.
    try:
        os.makedirs(DATA_ROOT, exist_ok=True)
    except OSError as e:
        print(f"Could not create directory {DATA_ROOT}: {e}")
    write_outputs_to_json_file(outputs, outputs_path)

    df = df.add_column(column_name, outputs)
    return df


# Load the training split and add the generated column to it.
dataset = load_dataset(DATASET, split="train")
dataset = process_for_dataset_split(dataset)

# Keep a local copy before uploading. This stays best-effort so a failed local
# save does not block the upload, but the failure is now reported instead of
# being silently discarded. Catching Exception (not a bare except) also lets
# KeyboardInterrupt and SystemExit propagate.
try:
    dataset.save_to_disk(DATA_ROOT)
except Exception as e:
    print(f"Could not save dataset to {DATA_ROOT}: {e}")

# push this to hf hub
dataset.push_to_hub(DATASET)
