#!/usr/bin/env python
# coding: utf-8

import argparse
import os
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer, DPOTrainer

def train_model(
    dataset_path: str,
    base_model_name: str,
    refined_model_path: str,
    epochs: int,
    learning_rate: float,
    sequential_training: bool = False,
) -> None:
    """Fine-tune a 4-bit quantized causal LM with LoRA using SFT, optionally followed by DPO.

    The base model is loaded with a BitsAndBytes 4-bit (NF4) configuration and
    trained with ``SFTTrainer`` on the CSV dataset. When ``sequential_training``
    is true, the SFT result is saved and then further trained with
    ``DPOTrainer``.

    Outputs written under ``refined_model_path``:
      * SFT only:   the trainer's model at ``refined_model_path`` itself and
                    at ``final_checkpoint/``.
      * Sequential: ``sft_checkpoint/``, ``dpo_checkpoint/`` and
                    ``final_checkpoint/``.
    Intermediate trainer output uses the fixed directory ``./results``.

    Args:
        dataset_path: Path to a CSV file, read with pandas and converted to a
            ``datasets.Dataset``.
        base_model_name: Name or path given to ``from_pretrained`` for both the
            tokenizer and the model.
        refined_model_path: Directory under which the trained models are saved.
        epochs: Value passed as ``num_train_epochs``.
        learning_rate: Value passed as the optimizer learning rate.
        sequential_training: If true, run DPO after SFT; otherwise SFT only.
    """
    # Tokenizer: reuse the EOS token for padding and pad on the right.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Quantization config: 4-bit NF4 weights, float16 compute, no double quantization.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )

    # Model
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quant_config,
        device_map="auto",  # Automatically map the model to available devices
    )
    model.config.use_cache = False

    # Load dataset.
    # NOTE(review): the same CSV is fed to both SFT and DPO. DPO normally
    # expects preference columns (prompt / chosen / rejected) -- confirm the
    # CSV schema suits both stages.
    dataset = Dataset.from_pandas(pd.read_csv(dataset_path))

    # PEFT (LoRA) configuration
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Training arguments shared by both training stages.
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=epochs,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        optim="paged_adamw_32bit",
        save_steps=50,
        logging_steps=10,
        learning_rate=learning_rate,
        weight_decay=0.01,
        fp16=False,
        bf16=False,
        max_grad_norm=1.0,
        max_steps=-1,
        warmup_ratio=0.1,
        group_by_length=False,
        lr_scheduler_type="linear",
        push_to_hub=False,
    )

    # SFT runs in both modes, so the trainer is built once.
    sft_trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )
    sft_trainer.train()

    final_checkpoint_dir = os.path.join(refined_model_path, "final_checkpoint")

    if not sequential_training:
        # Single training type: save the SFT result and stop.
        sft_trainer.save_model(refined_model_path)
        # Save through the trainer's model (not the pre-training `model`
        # handle) so the trained weights are what is written, matching how
        # the sequential branch saves its final checkpoint.
        sft_trainer.model.save_pretrained(final_checkpoint_dir)
        return

    # Save the SFT trained model
    sft_checkpoint_dir = os.path.join(refined_model_path, "sft_checkpoint")
    sft_trained_model = sft_trainer.model
    sft_trained_model.save_pretrained(sft_checkpoint_dir)

    # DPO training using the SFT trained model as a base; the saved SFT
    # checkpoint is reloaded to serve as the frozen reference model.
    # NOTE(review): this reload does not pass the quantization config, so the
    # reference model is presumably loaded unquantized -- confirm memory budget.
    ref_model = AutoModelForCausalLM.from_pretrained(sft_checkpoint_dir)
    ref_model.config.use_cache = False

    dpo_trainer = DPOTrainer(
        model=sft_trained_model,
        ref_model=ref_model,  # DPOTrainer's keyword is `ref_model`, not `model_ref`
        args=training_args,
        beta=0.1,  # Beta hyperparameter for DPO
        train_dataset=dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=1024,
        max_length=1024,
    )
    dpo_trainer.train()

    # Save the DPO trained model
    dpo_trainer.save_model(os.path.join(refined_model_path, "dpo_checkpoint"))
    dpo_trainer.model.save_pretrained(final_checkpoint_dir)

if __name__ == "__main__":
    # Command-line entry point: collect the training options and hand them to
    # train_model.
    cli = argparse.ArgumentParser(
        description="Train a model with options for SFT, DPO, or sequential SFT followed by DPO."
    )

    # Required inputs and output location.
    cli.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to the dataset CSV file.",
    )
    cli.add_argument(
        "--base_model_name",
        type=str,
        required=True,
        help="Name of the base model.",
    )
    cli.add_argument(
        "--refined_model_path",
        type=str,
        required=True,
        help="Path to save the refined model.",
    )

    # Optional hyperparameters and mode switch.
    cli.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Number of training epochs.",
    )
    cli.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Learning rate for training.",
    )
    cli.add_argument(
        "--sequential_training",
        action="store_true",
        help="Enable sequential SFT followed by DPO training.",
    )

    opts = cli.parse_args()

    train_model(
        dataset_path=opts.dataset_path,
        base_model_name=opts.base_model_name,
        refined_model_path=opts.refined_model_path,
        epochs=opts.epochs,
        learning_rate=opts.learning_rate,
        sequential_training=opts.sequential_training,
    )

