#!/usr/bin/env python
# coding: utf-8

import argparse
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
from datasets import Dataset
from tqdm import tqdm
from utils import check_matching

def train_model(dataset_path, base_model_name, refined_model_path, epochs, learning_rate):
    """Fine-tune a 4-bit quantized causal LM with LoRA adapters on a CSV dataset.

    The CSV is read with pandas and must provide ``question`` and ``choices``
    columns. Each row is rendered into an instruction-style training text,
    the base model is loaded with NF4 4-bit quantization on device 0, LoRA
    adapters are attached, and the model is trained with ``SFTTrainer``.
    The trained adapter and the tokenizer are written to
    ``refined_model_path``.

    Args:
        dataset_path: Path to the training CSV file.
        base_model_name: Model name or path passed to ``from_pretrained``.
        refined_model_path: Directory where the adapter and tokenizer are saved.
        epochs: Number of training epochs.
        learning_rate: Learning rate passed to ``TrainingArguments``.

    Raises:
        ValueError: If the dataset contains no rows.
    """
    instruction = "Choose the answer to the question only from options A, B, C, D."

    # Data Preparation
    def convert_to_format(row):
        # Coerce to str so numeric or otherwise non-string cells cannot break
        # the concatenation below.
        input_str = str(row['question'])
        answer = row['choices']
        # Literal braces must be doubled inside an f-string; single braces
        # would be parsed as a replacement field and raise NameError.
        response = f"```{{Question: '{input_str}', Answer: '{answer}'}}```"
        text = (
            input_str
            + "\n\n### Instruction:\n"
            + instruction
            + input_str
            + "\n"
            + "\n### Response:\n"
            + response
        )
        return pd.Series([instruction, input_str, response, text])

    df = pd.read_csv(dataset_path)
    if df.empty:
        # apply() on an empty frame keeps the original columns, so the column
        # rename below would fail with an unhelpful length-mismatch error.
        raise ValueError(f"Dataset at {dataset_path!r} contains no rows.")
    new_df = df.apply(convert_to_format, axis=1)
    new_df.columns = ['instruction', 'input', 'output', 'text']
    dataset = Dataset.from_pandas(new_df)

    # Model and Tokenizer Initialization
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_eos_token = True

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map={"": 0},
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = get_peft_model(model, peft_config)

    # Training Arguments
    # NOTE(review): with lr_scheduler_type="constant" the warmup_ratio below
    # is presumably ignored; switch to "constant_with_warmup" if a warmup
    # phase is intended -- confirm against the transformers scheduler docs.
    training_arguments = TrainingArguments(
        output_dir="./results",
        num_train_epochs=epochs,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        optim="paged_adamw_8bit",
        save_steps=5000,
        logging_steps=30,
        learning_rate=learning_rate,
        weight_decay=0.001,
        fp16=False,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        warmup_ratio=0.3,
        group_by_length=True,
        lr_scheduler_type="constant",
    )

    # Setting sft parameters
    # NOTE(review): the model is already wrapped by get_peft_model above, so
    # passing peft_config here looks redundant -- confirm how the installed
    # trl version treats this combination.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        max_seq_length=None,
        dataset_text_field="text",
        tokenizer=tokenizer,
        args=training_arguments,
        packing=False,
    )

    trainer.train()
    trainer.model.save_pretrained(refined_model_path)
    # Save the tokenizer next to the adapter so the output directory can be
    # loaded on its own (evaluate_model reads the tokenizer from this path).
    tokenizer.save_pretrained(refined_model_path)

def _format_eval_row(row):
    """Build the inference prompt and the expected answer for one dataset row."""
    # Mirrors the prompt layout used by train_model, stopping right after the
    # response header so the model produces the answer.
    # NOTE(review): assumes the evaluation CSV has the same 'question' and
    # 'choices' columns as the training CSV -- TODO confirm.
    instruction = "Choose the answer to the question only from options A, B, C, D."
    question = str(row['question'])
    prompt = (
        question
        + "\n\n### Instruction:\n"
        + instruction
        + question
        + "\n"
        + "\n### Response:\n"
    )
    return {'text': prompt, 'target': row['choices']}


def _generate_answer(model, tokenizer, prompt, max_new_tokens):
    """Generate a completion for ``prompt`` and return only the new text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the echoed prompt tokens so only the generated answer is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)


def evaluate_model(dataset_path, refined_model_path, max_new_tokens=64):
    """Evaluate a saved LoRA adapter on a CSV dataset and report accuracy.

    The base model name is read from the adapter config stored in
    ``refined_model_path``; the base model is loaded with the same 4-bit
    settings used by ``train_model`` and the adapter is applied on top.
    Each row is turned into a prompt, an answer is generated, and
    ``check_matching`` decides whether it matches the expected target.

    Args:
        dataset_path: Path to the evaluation CSV file.
        refined_model_path: Directory containing the saved adapter.
        max_new_tokens: Maximum number of tokens generated per example.

    Returns:
        The accuracy as a percentage (0-100).

    Raises:
        ValueError: If the dataset contains no rows.
    """
    # Local import keeps this fix self-contained; peft is already a
    # dependency of this file.
    from peft import PeftConfig

    # PeftModel.from_pretrained needs the base model as its first argument,
    # so resolve it from the adapter config before loading the adapter.
    adapter_config = PeftConfig.from_pretrained(refined_model_path)
    base_model_name = adapter_config.base_model_name_or_path

    # Load the model and tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(refined_model_path)
    except (OSError, ValueError):
        # The adapter directory may hold no tokenizer files; fall back to the
        # base model's tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # train_model enables add_eos_token; a trailing EOS on the prompt is not
    # wanted when generating, so switch it off here.
    tokenizer.add_eos_token = False

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map={"": 0},
    )
    our_model = PeftModel.from_pretrained(base_model, refined_model_path)
    our_model.eval()

    # Load the dataset
    df = pd.read_csv(dataset_path)
    if df.empty:
        # Avoid a ZeroDivisionError when computing the accuracy below.
        raise ValueError(f"Dataset at {dataset_path!r} contains no rows.")
    examples = [_format_eval_row(row) for _, row in df.iterrows()]

    correct_preds = 0
    for example in tqdm(examples):
        answer = _generate_answer(our_model, tokenizer, example['text'], max_new_tokens)
        if check_matching(answer, example['target']):
            correct_preds += 1

    acc = (correct_preds * 100) / len(examples)
    print(f"Accuracy: {acc}%")
    return acc

if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the chosen action.
    cli = argparse.ArgumentParser(description="Train and evaluate a model.")

    # Required string options, registered in the order shown in --help.
    required_text_options = (
        ("--dataset_path", "Path to the dataset CSV file."),
        ("--base_model_name", "Name of the base model."),
        ("--refined_model_path", "Path to save the refined model."),
    )
    for flag, description in required_text_options:
        cli.add_argument(flag, type=str, required=True, help=description)

    # Optional training hyperparameters.
    cli.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Number of training epochs.",
    )
    cli.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Learning rate for training.",
    )
    cli.add_argument(
        "--action",
        type=str,
        choices=["train", "evaluate"],
        required=True,
        help="Whether to train or evaluate the model.",
    )

    options = cli.parse_args()

    # argparse restricts --action to the two values handled below.
    if options.action == "evaluate":
        evaluate_model(options.dataset_path, options.refined_model_path)
    elif options.action == "train":
        train_model(
            options.dataset_path,
            options.base_model_name,
            options.refined_model_path,
            options.epochs,
            options.learning_rate,
        )
