#!/usr/bin/env python
# coding: utf-8

import argparse
import os

import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer, DPOTrainer

from utils import check_matching

# Columns DPOTrainer reads from each training example.
_DPO_COLUMNS = ("prompt", "chosen", "rejected")


def _format_training_text(row):
    """Render one CSV row into the single ``text`` string used for SFT."""
    # NOTE(review): the question is repeated and the ```choices``` block sits
    # under "### Response:", so the model is trained to reproduce the `choices`
    # column as its answer — confirm this template is intended.
    return f"{row['question']}\n\n### Instruction:\nChoose the answer to the question only from options A, B, C, D.\n{row['question']}\n\n### Response:\n```{row['choices']}```"


def _load_tokenizer(base_model_name):
    """Load the base tokenizer and make its EOS token double as padding.

    Returns:
        ``(tokenizer, num_added)`` where ``num_added`` is how many tokens were
        appended to the vocabulary (non-zero only if the tokenizer had no EOS),
        so the caller knows the model's embeddings must be resized.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    num_added = 0
    # Only fall back to '</s>' when there is no EOS at all. Overriding an
    # existing, different EOS could register a token id that the model's
    # embedding matrix does not contain.
    if tokenizer.eos_token is None:
        num_added = tokenizer.add_special_tokens({"eos_token": "</s>"})
    # Assign padding after EOS is final so the two always agree.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, num_added


def _build_training_arguments(epochs, learning_rate):
    """Return the TrainingArguments shared by the SFT and DPO stages."""
    return TrainingArguments(
        output_dir="./results",
        num_train_epochs=epochs,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        optim="paged_adamw_8bit",
        save_steps=5000,
        logging_steps=30,
        learning_rate=learning_rate,
        weight_decay=0.001,
        fp16=False,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        # NOTE(review): the "constant" scheduler applies no warmup, so this
        # ratio currently has no effect; switch to "constant_with_warmup" if
        # a warmup phase is intended.
        warmup_ratio=0.3,
        group_by_length=True,
        lr_scheduler_type="constant",
    )


def train_model(dataset_path, base_model_name, refined_model_path, epochs, learning_rate, sequential_training=False):
    """Fine-tune a 4-bit quantized causal LM with LoRA SFT, optionally followed by DPO.

    Args:
        dataset_path: CSV file. SFT reads the ``question`` and ``choices``
            columns. When ``sequential_training`` is set, the file must also
            provide ``prompt``, ``chosen`` and ``rejected`` columns for DPO.
        base_model_name: Hugging Face model id or local path of the base model.
        refined_model_path: Output directory for the trained adapter(s).
        epochs: Number of training epochs (used by both stages).
        learning_rate: Learning rate (used by both stages).
        sequential_training: If True, run DPO after SFT, using the saved SFT
            model as the frozen reference model.

    Raises:
        ValueError: if the CSV is empty, or if ``sequential_training`` is set
            and a DPO column is missing.

    Outputs:
        SFT only: ``refined_model_path`` and ``refined_model_path/final_checkpoint``.
        SFT + DPO: ``sft_checkpoint``, ``dpo_checkpoint`` and ``final_checkpoint``
        under ``refined_model_path``. ``final_checkpoint`` also gets the tokenizer.
    """
    df = pd.read_csv(dataset_path)
    if df.empty:
        raise ValueError(f"{dataset_path!r} contains no rows to train on")
    # Fail fast, before the expensive model load: DPO needs preference pairs,
    # and the SFT "text" column alone cannot provide them.
    if sequential_training:
        missing = [col for col in _DPO_COLUMNS if col not in df.columns]
        if missing:
            raise ValueError(
                f"DPO training requires columns {list(_DPO_COLUMNS)} in {dataset_path!r}; missing: {missing}"
            )

    df["text"] = df.apply(_format_training_text, axis=1)
    sft_dataset = Dataset.from_pandas(df[["text"]])

    tokenizer, num_added_tokens = _load_tokenizer(base_model_name)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",  # Automatically map the model to available CUDA devices
    )
    model.config.use_cache = False
    if num_added_tokens:
        # A newly added EOS id would otherwise index past the embedding matrix.
        model.resize_token_embeddings(len(tokenizer))

    model = prepare_model_for_kbit_training(model)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = get_peft_model(model, peft_config)

    training_arguments = _build_training_arguments(epochs, learning_rate)

    # Stage 1: supervised fine-tuning (always runs).
    sft_trainer = SFTTrainer(
        model=model,
        args=training_arguments,
        train_dataset=sft_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        dataset_text_field="text",
        packing=False,
    )
    sft_trainer.train()

    final_checkpoint = os.path.join(refined_model_path, "final_checkpoint")
    if not sequential_training:
        sft_trainer.save_model(refined_model_path)
        sft_trainer.model.save_pretrained(final_checkpoint)
        tokenizer.save_pretrained(final_checkpoint)
        return

    # Stage 2: DPO, with the SFT result as the frozen reference model.
    sft_checkpoint = os.path.join(refined_model_path, "sft_checkpoint")
    sft_trainer.model.save_pretrained(sft_checkpoint)

    # Load the reference with the same 4-bit config and placement as the
    # policy. A plain from_pretrained would load it in full precision.
    ref_model = AutoModelForCausalLM.from_pretrained(
        sft_checkpoint,
        quantization_config=bnb_config,
        device_map="auto",
    )
    ref_model.config.use_cache = False

    dpo_dataset = Dataset.from_pandas(df[list(_DPO_COLUMNS)])
    # NOTE(review): `model` is already a PeftModel; how DPOTrainer treats an
    # additional `peft_config` depends on the TRL version — confirm it does not
    # stack or merge adapters unexpectedly.
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,  # TRL's keyword is `ref_model`; `model_ref` raises TypeError.
        args=training_arguments,
        beta=0.1,  # Beta hyperparameter for DPO
        train_dataset=dpo_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=1024,
        max_length=1024,
    )
    dpo_trainer.train()

    dpo_trainer.save_model(os.path.join(refined_model_path, "dpo_checkpoint"))
    dpo_trainer.model.save_pretrained(final_checkpoint)
    tokenizer.save_pretrained(final_checkpoint)

def evaluate_model(dataset_path, refined_model_path, max_new_tokens=32):
    """Generate an answer for every CSV row with a trained adapter and report accuracy.

    A prediction counts as correct when
    ``check_matching(generated_text, row['correct_answer'])`` is truthy.

    Args:
        dataset_path: CSV with ``question``, ``choices`` and ``correct_answer``
            columns.
        refined_model_path: Directory containing a saved PEFT adapter and its
            tokenizer files.
        max_new_tokens: Upper bound on tokens generated per question.

    Returns:
        Accuracy as a percentage (0-100). It is also printed.

    Raises:
        ValueError: if the CSV has no rows.
    """
    # Imported locally: AutoPeftModelForCausalLM only exists in newer peft
    # releases, and keeping it here leaves training usable on older ones.
    from peft import AutoPeftModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(refined_model_path)
    # PeftModel.from_pretrained needs the base model instance as its first
    # argument. AutoPeftModelForCausalLM reads the base model id from the
    # adapter config and loads base + adapter in one call.
    model = AutoPeftModelForCausalLM.from_pretrained(refined_model_path, device_map="auto")
    model.eval()

    df = pd.read_csv(dataset_path)
    if df.empty:
        raise ValueError(f"{dataset_path!r} contains no rows to evaluate")

    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    correct_preds = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating"):
        # NOTE(review): this is the full training-style text, including the
        # ```choices``` segment after "### Response:"; confirm the model should
        # be conditioned on it rather than on the text up to "### Response:\n".
        prompt = f"{row['question']}\n\n### Instruction:\nChoose the answer to the question only from options A, B, C, D.\n{row['question']}\n\n### Response:\n```{row['choices']}```"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding so scores are reproducible
            pad_token_id=pad_token_id,
        )
        # generate() returns prompt + continuation. Score only the continuation,
        # otherwise the echoed question/choices can spuriously "match".
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        if check_matching(answer, row['correct_answer']):
            correct_preds += 1

    acc = (correct_preds * 100) / len(df)
    print(f"Accuracy: {acc}%")
    return acc

def build_parser():
    """Return the command-line parser for the train/evaluate entry point."""
    parser = argparse.ArgumentParser(description="Train and evaluate a model with options for SFT and sequential SFT followed by DPO.")
    # main() dispatches on ``args.action``. Without this option every
    # invocation failed with AttributeError.
    parser.add_argument("--action", choices=("train", "evaluate"), default="train", help="Whether to train a model or evaluate a trained one.")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the dataset CSV file.")
    # evaluate_model() takes no base model name, so this is only enforced for
    # training (see main()).
    parser.add_argument("--base_model_name", type=str, default=None, help="Name of the base model (required for --action train).")
    parser.add_argument("--refined_model_path", type=str, required=True, help="Path to save the refined model.")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate for training.")
    parser.add_argument("--sequential_training", action='store_true', help="Enable sequential SFT followed by DPO training.")
    return parser


def main(argv=None):
    """Parse command-line arguments and run training or evaluation.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.action == "train":
        if not args.base_model_name:
            parser.error("--base_model_name is required when --action is train")
        train_model(args.dataset_path, args.base_model_name, args.refined_model_path, args.epochs, args.learning_rate, args.sequential_training)
    else:
        evaluate_model(args.dataset_path, args.refined_model_path)


if __name__ == "__main__":
    main()

