"""
Generates a naive finetuned model for the given dataset. As the first step, this script generates the finetuning dataset and pushes it to the Hub.
"""

import sys
import os
import json

# Directory three levels above this file -- presumably the repository root,
# since `src` is imported from it right below. It is echoed for debugging and
# appended to the import search path.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
print(_REPO_ROOT)
sys.path.append(_REPO_ROOT)

from src.utils import genutils
from datasets import load_dataset
from datasets import DatasetDict
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
import os
import torch

# Names of the source dataset, the dataset to publish, and the base model whose
# tokenizer is used. Each can be overridden through its SCRIPTVAR_* environment
# variable; the second argument is the fallback.
DATASET = os.getenv("SCRIPTVAR_DATASET", "preference-agents/preference-enron-42K")
NEW_DATASET = os.getenv(
    "SCRIPTVAR_NEW_DATASET", "preference-agents/naive-training-data-phi-3"
)
MODEL = os.getenv("SCRIPTVAR_BASE_MODEL", "microsoft/Phi-3-mini-4k-instruct")
# Not referenced anywhere else in the visible part of this file.
NAIVE_FINETUNED_MODEL = "preference-agents/naive-fineturned-phi-3"

# Prompt templates live in a "prompts" directory three levels above this file.
PROMPT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..", "prompts")
)
# NOTE(review): this climbs only two levels, whereas PROMPT_ROOT climbs three --
# confirm that is intentional. DATA_ROOT is not used in the visible code.
DATA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "data"))

# Built with os.path.join (instead of "+" with a hard-coded "/") so the paths
# are assembled the same way as the roots above and use the platform separator.
SYS_PROMPT = os.path.join(PROMPT_ROOT, "system_prompts", "baseline_generation.txt")
EMAIL_FORMAT = os.path.join(PROMPT_ROOT, "formats", "baseline_generation.txt")

# The source dataset is loaded further down with split="train", right where it
# is consumed. An extra load of every split used to sit here; its result was
# overwritten before ever being read, so it has been dropped.

# Tokenizer of the base model; its chat template renders each training example.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Template for the per-email user message, filled positionally via str.format.
email_format = genutils.read_text_file(EMAIL_FORMAT)

# One-message conversation carrying the system prompt text under the "user"
# role. Not referenced elsewhere in the visible part of this file.
messages = [
    {"role": "user", "content": genutils.read_text_file(SYS_PROMPT)},
]


def make_prompts_for_data(data):
    """Render every row of *data* into a chat-formatted training string.

    Each row becomes a two-turn conversation: a user turn made of the system
    prompt text, a blank line, and the filled-in email template; and an
    assistant turn holding the row's reference email. The conversation is
    turned into plain text with the tokenizer's chat template.

    Args:
        data: Iterable of mappings providing the keys "from", "to", "date",
            "subject", "previous_context", "user_intent" and "email".

    Returns:
        A list of rendered strings, one per row, in the same order as *data*.
    """
    # The system prompt is identical for every row: read it once up front
    # instead of hitting the file system on each iteration.
    system_prompt = genutils.read_text_file(SYS_PROMPT)

    prompts = []
    for row in data:
        # The template is filled positionally, so the argument order matters.
        user_content = email_format.format(
            row["from"],
            row["to"],
            row["date"],
            row["subject"],
            row["previous_context"],
            row["user_intent"],
        )
        conversation = [
            # The system prompt is folded into the user turn rather than sent
            # under a separate role.
            {"role": "user", "content": system_prompt + "\n\n" + user_content},
            {"role": "assistant", "content": row["email"]},
        ]
        # tokenize=False yields text rather than token ids; no generation
        # prompt is appended because the assistant reply is already present.
        prompts.append(
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
        )
    return prompts


def process_for_dataset_split(df, column_name="text"):
    """Attach the rendered prompts to *df* as a new column.

    Args:
        df: Dataset-like object that can be iterated row by row and exposes
            ``add_column(name, values)``.
        column_name: Name of the column that receives the rendered prompts.

    Returns:
        Whatever ``df.add_column`` returns for the new column.
    """
    print("Generating Prompts...")
    rendered = make_prompts_for_data(df)
    return df.add_column(column_name, rendered)


# Only the train split of the source dataset is rendered into prompts.
dataset = load_dataset(DATASET, split="train")
dataset = process_for_dataset_split(dataset)

# Drop all columns except the rendered "text" column. The list is derived from
# the dataset itself rather than hard-coded, so nothing but "text" survives and
# a source dataset lacking one of the auxiliary columns no longer raises here.
dataset = dataset.remove_columns(
    [name for name in dataset.column_names if name != "text"]
)

# Split the dataset into 80% training and 20% testing. No seed is passed, so
# the assignment of rows to splits is presumably not reproducible across runs.
# The result is bound to `splits` so that it does not shadow the
# `train_test_split` function imported from sklearn at the top of the file.
splits = dataset.train_test_split(test_size=0.2)

# Create a DatasetDict to organize splits
split_datasets = DatasetDict({"train": splits["train"], "test": splits["test"]})

# Publish both splits under the configured dataset name.
split_datasets.push_to_hub(NEW_DATASET)
