import sys
import os

# Adding the project root to the system path to ensure proper module resolution
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from datasets import load_dataset
import re
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from typing import List
import logging
from src.utils.iutils import run_inference, read_text_file

# Constants defining links and local directory paths
DATASET_LINK = "preference-agents/enron-jeff-02"
MODEL_LINK = "casperhansen/llama-3-70b-instruct-awq"

# Directories for system prompts and formats
PROMPT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "prompts")
)
SYS_PROMPTS_DIR = os.path.join(PROMPT_ROOT, "system_prompts")
FORMAT_DIR = os.path.join(PROMPT_ROOT, "formats")
SYSTEM_PROMPT_EVAL = "system_prompt_eval.txt"
FORMAT_PROMPT = "eval_generation.txt"

# HF dataset column names
HF_COLUMN_TO_EVAL = "gold_generated_baseline_zeroshot"
HF_GENERATED_EVAL_NAME = "generated_eval_gold_rules"
COLUMN_NAME_POST_CLEAN = HF_GENERATED_EVAL_NAME + "_scores"

# Raw environment-variable values (lower-cased, stripped) that mean "off".
_FALSY_ENV_VALUES = {"", "0", "false", "no", "off"}


def _env_flag(name, default):
    """Return the boolean value of environment variable ``name``.

    ``os.getenv`` always yields a string when the variable is set, so using its
    result directly as a boolean makes values like ``"False"`` or ``"0"``
    truthy. This helper returns ``default`` when the variable is unset,
    ``False`` for an empty string or "0"/"false"/"no"/"off" (case-insensitive,
    surrounding whitespace ignored), and ``True`` for any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in _FALSY_ENV_VALUES


# Environment variables or default values
DATASET = os.getenv("SCRIPTVAR_DATASET", DATASET_LINK)
MODEL = os.getenv("SCRIPTVAR_MODEL", MODEL_LINK)
QUANTIZED = _env_flag("SCRIPTVAR_QUANTIZED", True)
# A tensor-parallel size of 0 is meaningless; fall back to 1 when no CUDA
# devices are visible.
NUM_GPUS = max(1, torch.cuda.device_count())

# The local copy is named after the dataset actually loaded (the last path
# component of DATASET), so overriding SCRIPTVAR_DATASET does not clobber the
# default dataset's local copy. For the default link this is "enron-jeff-02".
LOCAL_DATASET_DIR = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
        "data",
        DATASET.rstrip("/").rsplit("/", 1)[-1],
    )
)

# Load dataset
dataset = load_dataset(DATASET)

llm = LLM(
    MODEL,
    enable_prefix_caching=True,
    gpu_memory_utilization=0.85,
    quantization="awq" if QUANTIZED else None,
    tensor_parallel_size=NUM_GPUS,
    max_num_seqs=16,
)
tokenizer = llm.get_tokenizer()

# Read prompts and email format
system_prompt_content = read_text_file(
    os.path.join(SYS_PROMPTS_DIR, SYSTEM_PROMPT_EVAL)
)
# Every evaluation prompt starts with the same system message.
prompt_base = [{"role": "system", "content": system_prompt_content}]
email_format = read_text_file(os.path.join(FORMAT_DIR, FORMAT_PROMPT))


def make_prompts_for_data(data, email_to_evaluate=HF_COLUMN_TO_EVAL):
    """Build one chat-templated evaluation prompt string per row of ``data``.

    Each prompt is the shared system message (``prompt_base``) followed by a
    user message produced by filling ``email_format``. The generated email
    (``row[email_to_evaluate]``) is the first positional argument and the
    ground-truth email (``row["content"]``) is the second.

    Args:
        data: Iterable of rows supporting ``row[key]`` access.
        email_to_evaluate: Column holding the generated email to be judged.

    Returns:
        List of prompt strings rendered by ``tokenizer.apply_chat_template``
        (not tokenized, with the generation prompt appended).
    """

    def _render(row):
        reference = row["content"]
        candidate = row[email_to_evaluate]
        user_turn = {
            "role": "user",
            "content": email_format.format(candidate, reference),
        }
        return tokenizer.apply_chat_template(
            prompt_base + [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )

    return [_render(row) for row in data]


def process_for_dataset_split(dataset, column_name=HF_GENERATED_EVAL_NAME):
    """Run the evaluator LLM over one split and store its raw outputs.

    Prompts are built with ``make_prompts_for_data`` (default evaluated
    column) and sent through ``run_inference`` with temperature 1, top-p 0.95
    and up to 4096 new tokens.

    Args:
        dataset: A single split exposing ``column_names``, ``remove_columns``
            and ``add_column``.
        column_name: Column that will hold the raw evaluator outputs.

    Returns:
        The split with ``column_name`` holding the inference outputs. Any
        existing column of that name is replaced.
    """
    split_prompts = make_prompts_for_data(dataset)
    sampling = SamplingParams(temperature=1, max_tokens=4096, top_p=0.95)
    generations = run_inference(llm, sampling, split_prompts)

    # Remove any existing column of the same name (e.g. from an earlier run)
    # before adding the fresh outputs.
    base = (
        dataset.remove_columns([column_name])
        if column_name in dataset.column_names
        else dataset
    )
    return base.add_column(column_name, generations)


def parse_out_token(data_subset, column_name, tag="<score>") -> List[str]:
    """Extract the text between ``tag`` and its closing tag for every row.

    The closing tag is formed by inserting ``/`` after the first character of
    ``tag`` (``"<score>"`` -> ``"</score>"``). Whitespace just inside the tags
    is dropped, and a span may cross newlines.

    Args:
        data_subset: Iterable of rows supporting ``row[column_name]``.
        column_name: Key whose value, coerced with ``str``, is searched.
        tag: Opening tag to look for.

    Returns:
        One string per row: the contents of the first tagged span, or ``"-1"``
        when the row has no such span.
    """
    closing = "</" + tag[1:]
    matcher = re.compile(
        re.escape(tag) + r"\s*(.*?)\s*" + re.escape(closing), re.DOTALL
    )

    def _first_or_sentinel(text):
        found = matcher.search(text)
        return "-1" if found is None else found.group(1)

    return [_first_or_sentinel(str(row[column_name])) for row in data_subset]


def clean_and_add_column(dataset, column_name, source_column=None):
    """Parse ``<score>`` values out of raw evaluator output into a new column.

    Args:
        dataset: A single split exposing ``column_names``, ``remove_columns``
            and ``add_column``.
        column_name: Column to (re)create with the parsed scores. Any existing
            column of that name is replaced.
        source_column: Column holding the raw evaluator generations. Defaults
            to ``HF_GENERATED_EVAL_NAME`` so existing callers are unaffected.
            Pass the same name given to ``process_for_dataset_split`` when a
            non-default output column was used.

    Returns:
        The split with ``column_name`` set. Each entry is ``{"scores": text}``,
        where ``text`` is whatever ``parse_out_token`` extracted for that row.
    """
    # Resolved lazily so the default follows the module-level constant.
    if source_column is None:
        source_column = HF_GENERATED_EVAL_NAME
    scores = parse_out_token(dataset, source_column, "<score>")
    if column_name in dataset.column_names:
        dataset = dataset.remove_columns([column_name])
    return dataset.add_column(column_name, [{"scores": value} for value in scores])


# Apply evaluation and cleaning to each split: generate raw evaluator outputs,
# then parse the <score> tags out of them into a separate column.
for split in ["train", "test"]:
    dataset[split] = process_for_dataset_split(dataset[split])
    dataset[split] = clean_and_add_column(dataset[split], COLUMN_NAME_POST_CLEAN)

# Save locally on a best-effort basis: a failed local save is logged with its
# traceback but does not prevent the push to the Hub. Catching Exception
# (not a bare except) lets KeyboardInterrupt/SystemExit still stop the script.
try:
    dataset.save_to_disk(LOCAL_DATASET_DIR)
except Exception:
    logging.exception("Failed to save dataset to disk")

# Push back to the dataset that was actually loaded (DATASET honours the
# SCRIPTVAR_DATASET override) instead of always overwriting the default link.
dataset.push_to_hub(DATASET)
