import torch
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

from transformers import BitsAndBytesConfig

# Names exported by `from <module> import *`; `_find_all_linear_names` is
# deliberately left out because it is an internal helper.
__all__ = ["get_default_quantization_config", "prepare_model_for_qlora"]


def get_default_quantization_config():
    """Return the default 4-bit ``BitsAndBytesConfig`` used by this module.

    The configuration enables 4-bit loading with the ``"nf4"`` quantization
    type, turns on double quantization, and sets ``torch.bfloat16`` as the
    compute dtype.

    Returns:
        A freshly constructed ``BitsAndBytesConfig``; a new object is built
        on every call.
    """
    settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    return BitsAndBytesConfig(**settings)


def prepare_model_for_qlora(model, peft_setups):
    """Wrap a k-bit quantized model with LoRA adapters.

    Args:
        model: The model to adapt. LoRA targets are discovered by scanning
            it for ``bnb.nn.Linear4bit`` layers, so it is presumably already
            loaded in 4-bit -- TODO confirm against callers.
        peft_setups: Mapping of extra keyword arguments forwarded verbatim
            to ``LoraConfig``. It must not contain ``task_type``,
            ``inference_mode`` or ``target_modules``: those are supplied
            here, and a duplicate keyword raises ``TypeError``.

    Returns:
        The model returned by ``get_peft_model``.
    """
    # Step 1: k-bit training preparation, with gradient checkpointing off.
    # NOTE(review): the checkpointing kwargs are presumably unused while
    # use_gradient_checkpointing is False -- confirm against the installed
    # PEFT version.
    kbit_model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=False,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    # Step 2: build the LoRA config, targeting every 4-bit linear layer
    # found in the prepared model (the helper leaves out 'lm_head').
    fixed_settings = {
        "task_type": "CAUSAL_LM",
        "inference_mode": False,
        "target_modules": _find_all_linear_names(kbit_model),
    }
    lora_config = LoraConfig(**fixed_settings, **peft_setups)

    # Step 3: attach the adapters and report the trainable-parameter summary
    # for verification and debugging.
    lora_model = get_peft_model(kbit_model, lora_config)
    lora_model.print_trainable_parameters()
    return lora_model


def _find_all_linear_names(model):
    """Collect the leaf names of all 4-bit linear layers in ``model``.

    Every ``(name, module)`` pair from ``model.named_modules()`` whose module
    is a ``bnb.nn.Linear4bit`` contributes the last dot-separated component
    of its qualified name (e.g. ``"a.b.q_proj"`` -> ``"q_proj"``).

    Args:
        model: Any object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs.

    Returns:
        A sorted list of unique leaf names, never containing ``"lm_head"``.
        The list is empty when the model has no ``Linear4bit`` layers.
    """
    linear_class = bnb.nn.Linear4bit
    leaf_names = {
        # The last component is also correct for un-dotted names, so no
        # special case is needed for top-level modules.
        qualified_name.rsplit(".", 1)[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, linear_class)
    }
    # Exclude 'lm_head' for 16-bit handling. Done once after the scan rather
    # than on every iteration; discard() is a no-op when it is absent.
    leaf_names.discard("lm_head")
    # Sort so the result does not depend on set iteration order, which varies
    # between runs because of string-hash randomisation.
    return sorted(leaf_names)
