import os
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
from dotenv import load_dotenv
from uuid import uuid4
import os
import torch
import wandb

# Read a .env file (python-dotenv) so its entries are available through
# os.environ before any of the settings below are looked up.
load_dotenv()


def ensure_env(var):
    """Look up a required environment variable.

    Args:
        var: Name of the environment variable to read.

    Returns:
        The variable's value as a string. An empty string counts as set and
        is returned unchanged.

    Raises:
        ValueError: If the variable is absent from the environment.
    """
    if var in os.environ:
        return os.environ[var]
    raise ValueError(f"You must set the {var} environment variable.")


# Defaults for the W&B environment variables; a value that is already present
# in the environment is left as it is.
os.environ.setdefault("WANDB_HOST", "stargate")
os.environ.setdefault("WANDB_PROJECT", "enron_naive_ft_42k")

# Short run identifier: the first 8 hex digits of a random UUID, i.e. the
# leading group of its canonical dashed string form.
experiment_id = uuid4().hex[:8]
os.environ["WANDB_NOTES"] = f"Experiment ID: {experiment_id}"

# Run configuration. Each row is (config key, required environment variable,
# converter applied to its string value). ensure_env raises ValueError as soon
# as one of the variables is missing; int()/float() raise ValueError on a
# value that does not parse.
config = {
    key: convert(ensure_env(env_name))
    for key, env_name, convert in (
        ("dataset", "SCRIPTVAR_DATASET", str),
        ("model", "SCRIPTVAR_MODEL", str),
        ("max_seq_len", "SCRIPTVAR_MAX_SEQ_LEN", int),
        ("lora_rank", "SCRIPTVAR_LORA_RANK", int),
        ("lora_alpha", "SCRIPTVAR_LORA_ALPHA", int),
        ("batch_size", "SCRIPTVAR_BATCH_SIZE", int),
        ("epochs", "SCRIPTVAR_EPOCHS", int),
        ("lr_scheduler", "SCRIPTVAR_LR_SCHEDULER", str),
        ("lr", "SCRIPTVAR_LR", float),
        ("warmup_steps", "SCRIPTVAR_WARMUP_STEPS", int),
        ("save_steps", "SCRIPTVAR_SAVE_STEPS", int),
        ("model_outputs", "SCRIPTVAR_MODEL_OUTPUTS", str),
    )
}
# Last "/"-separated component of the model identifier.
config["model_name"] = config["model"].split("/")[-1]

print(config)

# Forwarded to FastLanguageModel.from_pretrained further down.
load_in_4bit = True

# Per-run directory: <model_outputs>/<model_name>/<experiment_id>.
save_path = os.path.join(config["model_outputs"], config["model_name"], experiment_id)
os.makedirs(save_path, exist_ok=True)


# Load the base model and tokenizer selected by the run configuration.
# The model id and context length come from SCRIPTVAR_MODEL and
# SCRIPTVAR_MAX_SEQ_LEN (previously a fixed checkpoint and 8192 were used
# here, silently ignoring both settings), so the loaded model agrees with the
# sequence length given to the LoRA wrapper and the trainer, and with the
# model name used for the output directory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config["model"],
    max_seq_length=config["max_seq_len"],
    dtype=None,  # NOTE(review): presumably lets Unsloth choose the dtype — confirm
    load_in_4bit=load_in_4bit,
)

# Wrap the base model with LoRA adapters on the attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r=config["lora_rank"],  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=config["lora_alpha"],
    lora_dropout=0,  # Supports any, but = 0 is optimized
    bias="none",  # Supports any, but = "none" is optimized
    use_gradient_checkpointing="unsloth",
    max_seq_length=config["max_seq_len"],
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)

# Training split of the dataset named by SCRIPTVAR_DATASET; the trainer reads
# its "text" column (see dataset_text_field below).
train_dataset = load_dataset(config["dataset"], split="train")
# NOTE(review): EOS_TOKEN is not referenced anywhere in the visible code;
# kept in case other code relies on the name.
EOS_TOKEN = tokenizer.eos_token

# Query the GPU capability once: train in bf16 when supported, fp16 otherwise.
bf16_supported = torch.cuda.is_bf16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=config["max_seq_len"],
    dataset_num_proc=2,
    packing=False,  # Can make training 5x faster for short sequences.
    args=TrainingArguments(
        per_device_train_batch_size=config["batch_size"],
        gradient_accumulation_steps=4,
        warmup_steps=config["warmup_steps"],
        num_train_epochs=config["epochs"],
        learning_rate=config["lr"],
        fp16=not bf16_supported,
        bf16=bf16_supported,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type=config["lr_scheduler"],
        seed=3407,
        # NOTE(review): config["save_steps"] is read from the environment but
        # never used; checkpoints are written once per epoch. Confirm which
        # behavior is intended before switching to save_strategy="steps".
        save_strategy="epoch",
        # Write checkpoints into the per-run directory created above
        # (<model_outputs>/<model_name>/<experiment_id>) instead of the shared
        # model_outputs root, so separate runs cannot overwrite each other's
        # checkpoint-N folders.
        output_dir=save_path,
        dataloader_pin_memory=False,
        report_to="wandb",  # enable logging to W&B
    ),
)

# Run the fine-tuning loop; the returned statistics are kept for inspection.
trainer_stats = trainer.train()

# Best-effort local export of the merged 16-bit model. A failure here must not
# prevent the hub upload below, so it is reported and training results are
# still pushed. Only Exception is caught: Ctrl-C / SystemExit still abort.
merged_dir = os.path.join(config["model_outputs"], "naive-ft-baseline")
try:
    model.save_pretrained_merged(
        merged_dir,
        tokenizer,
        save_method="merged_16bit",
    )
except Exception as exc:
    print(f"Warning: could not save merged model to {merged_dir}: {exc!r}")

# Upload the merged 16-bit model to the Hugging Face Hub.
model.push_to_hub_merged(
    "preference-agents/enron_42k_naive_ft", tokenizer, save_method="merged_16bit"
)
