import os
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import os
import torch


# Weights & Biases settings. Host and project honor any value already present in
# the environment and only fall back to these defaults when unset; the run name
# is always forced.
os.environ.setdefault("WANDB_HOST", "stargate")
os.environ.setdefault("WANDB_PROJECT", "pagent")
os.environ["WANDB_NAME"] = "phi-3-naive-finetune"

# Run configuration: dataset, base checkpoint, LoRA and optimizer
# hyperparameters, and the local output directory.
# NOTE(review): "model_name" is not read anywhere in this script.
config = dict(
    dataset="preference-agents/naive-training-data-phi-3",
    model="unsloth/Phi-3-mini-4k-instruct",
    max_seq_len=4096,
    lora_rank=512,
    lora_alpha=512,
    batch_size=8,
    epochs=1,
    lr_scheduler="cosine",
    lr=2e-5,
    warmup_steps=100,
    model_outputs="/home/sumukshashidhar/foundry/enron/outputs/unsloth/",
    model_name="Phi-3-mini-4k-instruct",
)

# Load the base model (quantized to 4-bit) and its tokenizer. The checkpoint id is
# taken from config["model"] instead of repeating the literal, so changing the
# config actually changes which model is trained.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config["model"],
    max_seq_length=config.get("max_seq_len"),
    dtype=None,  # presumably lets unsloth auto-select the dtype — confirm in unsloth docs
    load_in_4bit=True,
)

# Projection layers that receive LoRA adapters.
_LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Adapter settings; rank and alpha come from config. The comments on dropout/bias
# mirror unsloth's guidance that 0 / "none" are its optimized paths.
_peft_kwargs = dict(
    r=config.get("lora_rank"),
    target_modules=_LORA_TARGET_MODULES,
    lora_alpha=config.get("lora_alpha"),
    lora_dropout=0,  # any value is supported; 0 is the optimized path
    bias="none",  # any value is supported; "none" is the optimized path
    use_gradient_checkpointing="unsloth",
    max_seq_length=config.get("max_seq_len"),
    random_state=3407,
    use_rslora=False,  # rank-stabilized LoRA is available but disabled here
    loftq_config=None,  # LoftQ initialization is available but disabled here
)

# Wrap the quantized base model with trainable LoRA adapters.
model = FastLanguageModel.get_peft_model(model, **_peft_kwargs)

# Fetch the training split of the dataset named in config.
_dataset_id = config.get("dataset")
train_dataset = load_dataset(_dataset_id, split="train")

# NOTE(review): EOS_TOKEN is computed but never appended to the "text" field in
# this script. Confirm the dataset text already ends with an EOS/end-of-turn
# token, otherwise the model may not learn to stop generating.
EOS_TOKEN = tokenizer.eos_token

# Probe bf16 support once so fp16 and bf16 are derived from the same answer and
# are always mutually exclusive.
_use_bf16 = torch.cuda.is_bf16_supported()

# Supervised fine-tuning trainer over the "text" column of the training split.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=config.get("max_seq_len"),
    dataset_num_proc=2,
    packing=False,  # Enabling packing can make training ~5x faster for short sequences.
    args=TrainingArguments(
        per_device_train_batch_size=config.get("batch_size"),
        # Per-device effective batch = batch_size * gradient_accumulation_steps.
        # Overridable via config; defaults to the previously hard-coded 4.
        gradient_accumulation_steps=config.get("gradient_accumulation_steps", 4),
        warmup_steps=config.get("warmup_steps"),
        num_train_epochs=config.get("epochs"),
        learning_rate=config.get("lr"),
        fp16=not _use_bf16,
        bf16=_use_bf16,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type=config.get("lr_scheduler"),
        seed=3407,
        save_strategy="epoch",
        output_dir=config.get("model_outputs"),
        dataloader_pin_memory=False,
        # NOTE(review): this differs from the WANDB_NAME environment variable
        # ("phi-3-naive-finetune"); the transformers W&B callback presumably uses
        # run_name as the W&B run name — confirm which name is intended.
        run_name="phi-3-naive-ft",
        report_to="wandb",  # enable logging to W&B
    ),
)

trainer_stats = trainer.train()

# Single name shared by the local merged checkpoint directory and the Hub repo.
_merged_name = "phi-3-naive-ft"

# Merge the LoRA weights into the base model and save in 16-bit. os.path.join
# avoids the doubled "/" that plain concatenation produced, because
# config["model_outputs"] already ends with a separator.
model.save_pretrained_merged(
    os.path.join(config["model_outputs"], _merged_name),
    tokenizer,
    save_method="merged_16bit",
)
model.push_to_hub_merged(
    f"preference-agents/{_merged_name}", tokenizer, save_method="merged_16bit"
)
