#!/bin/python


import argparse
import logging
import torch
import sys
import pdb
import os


sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from utils.module.lora import LinearLayer_LoRA, convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters

def _positive_int(value):
    """Argparse ``type`` callable: parse *value* as an int and require it to be > 0.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is <= 0,
            so argparse reports a clean usage error instead of the script
            failing later (``main`` divides by this value).
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}")
    return number


def parse_args():
    """Parse the command-line arguments of the LoRA -> HF conversion script.

    Returns:
        argparse.Namespace with:
            model_path (str): path passed to ``torch.load``.
            save_dir (str): directory the converted weights are written to.
            lora_dim (int): LoRA rank; validated to be a positive integer.
    """
    parser = argparse.ArgumentParser(description="Convert the lora model to hf model")
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model",
        required=True,
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Path to converted model",
        required=True,
    )
    parser.add_argument(
        "--lora_dim",
        # Must be > 0: the merge scales the LoRA product by 1 / lora_dim.
        type=_positive_int,
        help="lora_dim",
        required=True,
    )
    args = parser.parse_args()
    return args

def convert_lora_to_linear_layer(model, lora_scaling=1):
    """Fold LoRA weight pairs in a state dict back into their base weights.

    NOTE: this definition shadows the ``convert_lora_to_linear_layer`` imported
    from ``utils.module.lora`` at the top of this file; this version works on a
    mapping of parameter name -> tensor.

    For every key whose last dotted component is exactly ``lora_left_weight``,
    the sibling ``lora_right_weight`` and ``weight`` entries (same prefix) are
    looked up, and ``weight`` is updated in place with::

        weight += lora_scaling * lora_left_weight.T @ lora_right_weight.T

    Both LoRA entries are then removed from the mapping.

    Args:
        model: Mutable mapping of parameter name -> tensor (for example the
            object returned by ``torch.load`` on a saved state dict).
            Modified in place.
        lora_scaling: Multiplier applied to the LoRA product before it is added.

    Returns:
        The same mapping, with LoRA entries merged and removed.

    Raises:
        KeyError: if a ``lora_left_weight`` entry has no matching
            ``lora_right_weight`` or base ``weight`` entry.
    """
    left_leaf = "lora_left_weight"
    # Snapshot matching keys first: entries are deleted below, and deleting
    # from a dict while iterating over it raises RuntimeError.
    left_keys = [name for name in model if name.rsplit(".", 1)[-1] == left_leaf]
    for left_key in left_keys:
        # Prefix keeps its trailing '.', or is '' for a top-level parameter,
        # so sibling keys are built correctly in both cases.
        prefix = left_key[: -len(left_leaf)]
        right_key = prefix + "lora_right_weight"
        weight_key = prefix + "weight"
        weight = model[weight_key]
        # In-place add: the tensor stored under weight_key is updated directly.
        weight += lora_scaling * torch.matmul(
            model[left_key].t(), model[right_key].t())
        del model[left_key]
        del model[right_key]
    return model



def main():
    """Merge the LoRA weights of a saved checkpoint and write ``pytorch_model.bin``.

    Loads ``--model_path`` with ``torch.load`` and folds each LoRA pair into its
    base weight using a scaling of ``1 / --lora_dim``. The result is saved to
    ``<save_dir>/pytorch_model.bin``, and ``save_dir`` is created if needed.

    Raises:
        ValueError: if ``--lora_dim`` is not a positive integer.
    """
    args = parse_args()
    # Guard the 1 / lora_dim below against ZeroDivisionError and negative ranks.
    if args.lora_dim <= 0:
        raise ValueError(
            f"--lora_dim must be a positive integer, got {args.lora_dim}")

    # Load onto the CPU. This keeps the script runnable without a GPU, and the
    # saved checkpoint then holds CPU tensors rather than tensors tied to
    # "cuda:0".
    model = torch.load(args.model_path, map_location="cpu")
    model = convert_lora_to_linear_layer(model, lora_scaling=1 / args.lora_dim)

    os.makedirs(args.save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(args.save_dir, WEIGHTS_NAME)
    torch.save(model, output_model_file)

                       



# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

