import os
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig, AutoPeftModelForCausalLM
from trl import SFTTrainer
import transformers
import evaluate
import json
from tqdm import tqdm
from huggingface_hub import login

# Authenticate with the Hugging Face Hub. When the HF_TOKEN environment
# variable holds a token it is passed explicitly so the script can run
# unattended (batch job, CI); when it is unset or empty, token=None is passed
# and login() behaves exactly like the bare call, i.e. it prompts for a token.
login(token=os.environ.get("HF_TOKEN") or None)

# Directory of the SFT run whose adapter is merged below.
output_dir = "./results_7b_llama_sft_newfinal_20_01_24"
# Alternative run directory (SFT + DPO), kept for quick switching:
# output_dir = "./results_llama_sft_dpo"

# Load the entire model on the GPU 0
device_map = {"": 0}

# The adapter checkpoint is read from <output_dir>/final_checkpoint.
final_checkpoint_dir = os.path.join(output_dir, "final_checkpoint")
print(final_checkpoint_dir)

# Reload the PEFT (LoRA) model in bfloat16.
# NOTE(review): cache_dir presumably controls where the base-model files are
# cached when they have to be downloaded -- confirm against the PEFT docs.
reloaded_model = AutoPeftModelForCausalLM.from_pretrained(
    final_checkpoint_dir,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    cache_dir="llama_model_7b",
)
# The tokenizer is loaded from the same checkpoint directory so that it is
# saved next to the merged weights below.
reloaded_tokenizer = AutoTokenizer.from_pretrained(
    final_checkpoint_dir,
    add_eos_token=True,
    use_fast=True,
)
print("Reloaded model")

# Merge the LoRA and the base model
merged_model = reloaded_model.merge_and_unload()

# Save the merged model together with its tokenizer under
# <output_dir>/final_merged_checkpoint.
merged_dir = os.path.join(output_dir, "final_merged_checkpoint")
merged_model.save_pretrained(merged_dir)
reloaded_tokenizer.save_pretrained(merged_dir)
print("Merged model and saved it")