import sys
import os

# Resolve the directory two levels above this file's own directory and put it
# on sys.path; the ``src.util`` import that follows is looked up from there.
DIRPATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
print(DIRPATH)
sys.path.append(DIRPATH)

from src.util import get_config, print_config, read_text_file
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer


# Load the run configuration and echo it so the settings appear in the output.
config = get_config()
print_config(config)

# Locations of the prompt assets, all relative to DIRPATH.
PROMPT_ROOT = f"{DIRPATH}/data/prompts/"
SYS_PROMPT = f"{PROMPT_ROOT}system_prompts/baseline_generation.txt"
DATA_FORMAT = f"{PROMPT_ROOT}data_formats/baseline_generation.txt"


# Only the "train" split of the configured dataset is used as the source.
data = load_dataset(config["dataset"], split="train")

# The tokenizer and the user-turn template are read by the prompt-building
# function defined further down in this file.
tokenizer = AutoTokenizer.from_pretrained(config["models"]["small_model"])
data_format = read_text_file(DATA_FORMAT)


# Carve a test portion of size config["test_set_size"] out of the data.
# NOTE(review): no seed is passed here, so presumably every run produces a
# different split -- confirm whether that is intended.
train_test_split = data.train_test_split(test_size=config["test_set_size"])

split_dataset = DatasetDict(
    {name: train_test_split[name] for name in ("train", "test")}
)

# Publish the split under "<dataset>-<size * 100>-split".
# NOTE(review): int() truncates, and assuming test_set_size is a fractional
# float, products such as 0.29 * 100 evaluate to 28.999999999999996, giving a
# suffix of 28. The reload later in this file builds the name with the same
# expression, so the two stay in agreement.
split_dataset.push_to_hub(
    f"{config['dataset']}-{int(config['test_set_size'] * 100)}-split"
)

# Shared conversation prefix: a single system turn holding the system prompt.
messages = [dict(role="system", content=read_text_file(SYS_PROMPT))]


def make_prompts_for_naive_ft_data(data):
    """Build one chat-formatted training prompt per row of ``data``.

    Each row is read for its ``"metadata"``, ``"input"`` and ``"output"``
    fields. The module-level ``data_format`` template is filled positionally
    with those three values to form the user turn, the row's ``"output"``
    becomes the assistant turn, and both are appended to the shared system
    ``messages`` before being passed to ``tokenizer.apply_chat_template``
    with ``tokenize=False`` and ``add_generation_prompt=False``.

    Args:
        data: Iterable of row mappings exposing the three keys above.

    Returns:
        list: The templated prompts, in the same order as the rows.
    """

    def _render(row):
        # Render a single row into its templated conversation.
        meta = row["metadata"]
        source = row["input"]
        target = row["output"]

        # NOTE(review): the expected output is also handed to the user-turn
        # template as its third positional argument. str.format ignores
        # surplus positional arguments, so whether the answer shows up in the
        # user turn depends on the template file -- confirm this is intended.
        user_turn = {
            "role": "user",
            "content": data_format.format(meta, source, target),
        }
        assistant_turn = {"role": "assistant", "content": target}

        # A fresh list per row keeps the shared ``messages`` prefix untouched.
        conversation = [*messages, user_turn, assistant_turn]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )

    return [_render(row) for row in data]


def process_for_dataset_split(df, column_name="text"):
    """Return ``df`` extended with a column of generated prompts.

    Prints a progress message, builds the prompts for every row via
    ``make_prompts_for_naive_ft_data`` and attaches them through
    ``df.add_column``.

    Args:
        df: Dataset split to process; it must be iterable row by row and
            provide ``add_column(name, values)``.
        column_name: Name of the column that receives the prompts.

    Returns:
        Whatever ``df.add_column`` returns for the new column.
    """
    print("Generating Prompts...")
    rendered = make_prompts_for_naive_ft_data(df)
    return df.add_column(column_name, rendered)


# Fetch the train/test split that was pushed earlier in this script; the repo
# id is rebuilt with the same expression that was used for that push.
data = load_dataset(
    f"{config['dataset']}-{int(config['test_set_size'] * 100)}-split"
)

# Attach the prompt column to each split, train first and then test.
for split_name in ("train", "test"):
    data[split_name] = process_for_dataset_split(data[split_name])

# Publish the result under the split's repo id plus a training-data suffix.
data.push_to_hub(
    f"{config['dataset']}-{int(config['test_set_size'] * 100)}-split"
    "-naiveft-training-data"
)
