from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import load_dataset
import torch.nn as nn
import transformers
import torch


def create_configs():
    """Build the quantization and LoRA configurations used for QLoRA fine-tuning.

    Returns:
        A ``(bnb_config, lora_config)`` tuple:

        * ``bnb_config`` -- a ``BitsAndBytesConfig`` that loads weights in
          4-bit NF4 with double quantization and computes in ``torch.bfloat16``.
        * ``lora_config`` -- a ``LoraConfig`` with rank 8, ``lora_alpha=32``,
          dropout 0.05, no bias training, for a causal-LM task.
    """
    # 4-bit quantization settings passed to from_pretrained(quantization_config=...).
    quantization_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # LoRA adapter settings.
    # NOTE(review): the target module names presumably refer to GPT-NeoX's
    # attention projections; verify against model.named_modules() before
    # reusing this with another architecture.
    adapter_kwargs = dict(
        r=8,
        lora_alpha=32,
        target_modules=["query_key_value", "dense"],
        lora_dropout=0.05,
        bias='none',
        task_type="CAUSAL_LM",
    )

    return BitsAndBytesConfig(**quantization_kwargs), LoraConfig(**adapter_kwargs)
    
    
def print_trainable_parameters(model: nn.Module):
    """Print the total and trainable parameter counts of ``model``.

    bitsandbytes stores 4-bit weights packed into ``Params4bit`` tensors, where
    each storage element holds ``2 * element_size()`` 4-bit values. ``numel()``
    therefore under-reports the logical parameter count of quantized layers.
    Such tensors are scaled back up here, mirroring peft's own
    ``print_trainable_parameters``.

    Args:
        model: Any ``torch.nn.Module``, e.g. the PEFT-wrapped model.

    Prints a single line with the total parameter count, the trainable count and
    the trainable percentage. A module with no parameters reports 0%.
    """
    trainable_params, all_params = 0, 0

    for _, param in model.named_parameters():
        num_params = param.numel()
        # Compare by class name so bitsandbytes does not need to be imported here.
        if type(param).__name__ == "Params4bit":
            num_params *= 2 * param.element_size()
        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params

    # Guard against ZeroDivisionError for a module without parameters.
    percentage = float(trainable_params) / all_params * 100 if all_params else 0.0

    # print the number of parameters
    args = (all_params, trainable_params, percentage)
    print("Number of params: {}, Trainable parameters: {}, Percentage: {}%".format(*args))
    

def load_data(tokenizer: AutoTokenizer):
    """Load the ``Abirate/english_quotes`` dataset and tokenize its quotes.

    Args:
        tokenizer: Tokenizer applied, in batches, to the ``"quote"`` column.

    Returns:
        The dataset returned by ``load_dataset``, mapped so that the tokenizer's
        outputs are added as columns. No ``remove_columns`` is passed, so the
        original columns are kept alongside them.
    """
    def tokenize_quotes(batch):
        return tokenizer(batch["quote"])

    raw_dataset = load_dataset('Abirate/english_quotes')
    return raw_dataset.map(tokenize_quotes, batched=True)


def prepare_trainer(model: nn.Module, tokenizer: AutoTokenizer, data, output_path: str):
    """Build a Hugging Face ``Trainer`` for causal-LM fine-tuning.

    Args:
        model: The (PEFT-wrapped) model to train.
        tokenizer: Tokenizer handed to the causal-LM data collator (``mlm=False``).
        data: Dataset container; only its ``"train"`` split is used.
        output_path: ``output_dir`` for the trainer's checkpoints and logs.

    Returns:
        A configured ``transformers.Trainer`` that has not been started yet.
    """
    # Effective batch size per device is 1 * 4 via gradient accumulation.
    # NOTE(review): fp16=True here while create_configs sets
    # bnb_4bit_compute_dtype=torch.bfloat16; confirm the mixed precision setup
    # is intended for the target GPU.
    training_args = transformers.TrainingArguments(
        output_dir=output_path,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        optim="paged_adamw_8bit",
    )
    collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

    return transformers.Trainer(
        model=model,
        train_dataset=data["train"],
        args=training_args,
        data_collator=collator,
    )


if __name__ == '__main__':

    # name of the pretrained model
    model_name = 'EleutherAI/gpt-neox-20b'

    # path of the experiment
    output_path = '/data1/flo/llm/gpt_neox/'

    # extract the configurations for lora training and quantization
    bnb_config, lora_config = create_configs()

    # load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # load the model
    model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")

    # enable gradient checkpoining
    model.gradient_checkpointing_enable()

    # prepare the model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # prepare the model for parameter efficient finetuning using Lora
    print('Converting model into its parameter efficient version...')
    model = get_peft_model(model, lora_config)

    # print the number of trainable parameters
    print_trainable_parameters(model)

    # load the dataset
    print('Loading training data...')
    data = load_data(tokenizer)

    # prepare a huggingface trainer
    print('Preparing the trainer...')
    trainer = prepare_trainer(model, tokenizer, data, output_path)

    # silence the warnings (re-enable for inference)
    model.config.use_cache = False

    # train the model
    trainer.train()

    # save the model to disc
    model_to_save = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model
    model_to_save.save_pretrained(output_path)
