import sys
import os

# Project root = two directory levels above this script. It is appended to
# sys.path, presumably so the `src.util` import below resolves when the script
# is run directly.
_this_dir = os.path.dirname(__file__)
DIRPATH = os.path.abspath(os.path.join(_this_dir, "..", ".."))
print(DIRPATH)
sys.path.append(DIRPATH)

from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import os
import torch
from src.util import get_config, print_config, read_text_file

cfg = get_config()
print_config(cfg)

# This script builds the dataset, pushes it to the Hub, and trains on it.
config = {
    "dataset": cfg["dataset"],
    "model": cfg["models"]["ft_model"],
    "max_seq_len": 4096,
    "lora_rank": 256,
    "lora_alpha": 256,
    "batch_size": 4,
    "epochs": 3,
    "lr_scheduler": "cosine_with_restarts",
    "lr": 2e-5,
    "warmup_steps": 10,
    "wandb_project_name": cfg["wandb"]["project_name"],
    "wandb_host": cfg["wandb"]["host"],
    # Hub repo id for the merged model; only read when should_i_push is True.
    "model_to_push": "",
    "should_i_push": False,
    # Selects which system-prompt / data-format text files are loaded.
    "training_data_format": "generate_rules_from_intent",
}

# The run name encodes LoRA rank, alpha and epoch count,
# e.g. "rule_ft_r_256_a_256_e_3" for the values above.
run_name = (
    f"rule_ft_r_{config['lora_rank']}"
    f"_a_{config['lora_alpha']}"
    f"_e_{config['epochs']}"
)

# Values derived from the ones above.
extended_config = {
    "run_name": run_name,
    # The trailing "" component keeps the trailing path separator.
    "model_outputs": os.path.join(DIRPATH, "out", "models", run_name, ""),
}

# Settings that normally should not need changing.
standard_config = {"load_in_4bit": True}

# Merge the three dicts; on key collisions the later dict wins.
config = {**config, **extended_config, **standard_config}

# Fail fast: the push happens only after training finishes, so a missing repo id
# would otherwise be discovered after the whole run has been spent.
if config["should_i_push"] and not config["model_to_push"]:
    raise ValueError(
        "config['should_i_push'] is True but config['model_to_push'] is empty; "
        "set the Hub repo id to push the merged model to."
    )

# Both prompt texts share one file name, chosen by config["training_data_format"]:
# data/prompts/system_prompts/<fmt>.txt and data/prompts/data_formats/<fmt>.txt.
prompts_folder = os.path.join(DIRPATH, "data", "prompts")
_prompt_file = f"{config['training_data_format']}.txt"
system_prompt = read_text_file(
    os.path.join(prompts_folder, "system_prompts", _prompt_file)
)
data_format = read_text_file(
    os.path.join(prompts_folder, "data_formats", _prompt_file)
)

# W&B settings: a value already present in the environment takes precedence;
# otherwise the config value is used.
for _env_key, _config_key in (
    ("WANDB_HOST", "wandb_host"),
    ("WANDB_PROJECT", "wandb_project_name"),
    ("WANDB_NAME", "run_name"),
):
    os.environ.setdefault(_env_key, config.get(_config_key))


# Load the base model and its tokenizer using the context length and 4-bit
# flag from the merged config; dtype is deliberately passed as None.
_base_model_kwargs = {
    "model_name": config.get("model"),
    "max_seq_length": config.get("max_seq_len"),
    "dtype": None,
    "load_in_4bit": config.get("load_in_4bit"),
}
model, tokenizer = FastLanguageModel.from_pretrained(**_base_model_kwargs)


# Dataset preparation comes first: each raw row is rendered into a single
# chat-formatted "text" field that the trainer consumes.
def make_trainer(data):
    """Render one dataset row into a chat-formatted training example.

    The conversation has three turns: the shared ``system_prompt``, a user turn
    made by filling ``data_format`` positionally with the row's ``metadata`` and
    ``input`` values, and an assistant turn holding the row's
    ``rule_strategy_1`` target.

    Args:
        data: One dataset row (it is mapped with ``batched=False``); must
            provide the ``metadata``, ``input`` and ``rule_strategy_1`` fields.

    Returns:
        ``{"text": ...}`` where the value is the tokenizer's chat template
        applied untokenized, without a trailing generation prompt.
    """
    metadata, user_input, target = (
        data["metadata"],
        data["input"],
        data["rule_strategy_1"],
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": data_format.format(metadata, user_input)},
        {"role": "assistant", "content": target},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=False, tokenize=False
    )
    return {"text": rendered}


# Render every loaded split into chat-formatted "text" rows and publish the
# result under a derived repo id; training further down reloads it from there.
_processed_repo_id = config["dataset"] + "_rulegen_strategy_1_train"
standard_dataset = load_dataset(config["dataset"]).map(make_trainer, batched=False)
standard_dataset.push_to_hub(_processed_repo_id)

# Modules that receive LoRA adapters: presumably the attention (q/k/v/o) and
# MLP (gate/up/down) projections, judging by the names.
_lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Wrap the base model with LoRA adapters; rank and alpha come from config.
model = FastLanguageModel.get_peft_model(
    model,
    r=config.get("lora_rank"),
    target_modules=_lora_target_modules,
    lora_alpha=config.get("lora_alpha"),
    lora_dropout=0,  # Unsloth notes 0 is its optimized setting
    bias="none",  # Unsloth notes "none" is its optimized setting
    use_gradient_checkpointing="unsloth",
    max_seq_length=config.get("max_seq_len"),
    random_state=3407,
    use_rslora=False,  # rank-stabilized LoRA left off
    loftq_config=None,  # LoftQ left off
)

# Reload the processed "train" split that was pushed to the Hub above.
train_dataset = load_dataset(
    config.get("dataset") + "_rulegen_strategy_1_train", split="train"
)
# NOTE(review): EOS_TOKEN is not referenced anywhere in the visible script;
# end-of-turn tokens presumably come from the chat template in make_trainer — confirm.
EOS_TOKEN = tokenizer.eos_token

# Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
_bf16_ok = torch.cuda.is_bf16_supported()
_training_args = TrainingArguments(
    per_device_train_batch_size=config.get("batch_size"),
    gradient_accumulation_steps=2,
    warmup_steps=config.get("warmup_steps"),
    num_train_epochs=config.get("epochs"),
    learning_rate=config.get("lr"),
    fp16=not _bf16_ok,
    bf16=_bf16_ok,
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type=config.get("lr_scheduler"),
    seed=3407,
    save_strategy="epoch",  # checkpoint once per epoch into model_outputs
    output_dir=config.get("model_outputs"),
    dataloader_pin_memory=False,
    run_name=config.get("run_name"),
    report_to="wandb",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",  # the column produced by make_trainer
    max_seq_length=config.get("max_seq_len"),
    dataset_num_proc=2,
    packing=False,  # sequence packing is off
    args=_training_args,
)

trainer_stats = trainer.train()

# Save the adapter-merged model locally as 16-bit weights, and push the same
# merge to the Hub only when should_i_push is enabled.
_merge_method = "merged_16bit"
model.save_pretrained_merged(
    config.get("model_outputs"), tokenizer, save_method=_merge_method
)
if config.get("should_i_push"):
    model.push_to_hub_merged(
        config.get("model_to_push"), tokenizer, save_method=_merge_method
    )
