from transformers import AutoModelForCausalLM
import torch
import argparse

# Alternative base model (uncomment these and comment out the constants below to use it):
# MODEL_NAME = "EleutherAI/gpt-neox-20b"
# WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/gptneox"


# WHEN DOWNLOADING NEW WEIGHTS, run the following afterwards so other users can access them:
# chmod -R 775 /localdata1/EmEx

# Model id handed to AutoModelForCausalLM.from_pretrained() by save_weights().
MODEL_NAME: str = "togethercomputer/GPT-NeoXT-Chat-Base-20B"
# Local directory that save_weights() writes the model into via save_pretrained().
WEIGHTS_FOLDER: str = "/localdata1/EmEx/model_weights/gptneoxchatbase"


def save_weights(fp16: bool = False, bf16: bool = False, model_name=None, weights_folder=None):
    """Load a causal language model and save its weights to a local folder.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` and
    written back out with ``save_pretrained``.

    Args:
        fp16: Load the weights as ``torch.float16``. Takes precedence over
            ``bf16`` when both are set.
        bf16: Load the weights as ``torch.bfloat16``.
        model_name: Model id (or path) passed to ``from_pretrained``.
            Defaults to the module-level ``MODEL_NAME``.
        weights_folder: Directory passed to ``save_pretrained``.
            Defaults to the module-level ``WEIGHTS_FOLDER``.
    """
    # Fall back to the module constants at call time (same lookup the
    # original hard-coded version performed).
    if model_name is None:
        model_name = MODEL_NAME
    if weights_folder is None:
        weights_folder = WEIGHTS_FOLDER

    print("Load and save weights...")

    # Only pass torch_dtype when a reduced precision was requested, so the
    # default path calls from_pretrained exactly as before.
    load_kwargs = {}
    if fp16:
        if bf16:
            # Previously bf16 was ignored silently in this case; make it visible.
            print("Both fp16 and bf16 requested; using fp16.")
        load_kwargs["torch_dtype"] = torch.float16
    elif bf16:
        load_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.save_pretrained(weights_folder)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=f"Download {MODEL_NAME} and save its weights to {WEIGHTS_FOLDER}."
    )
    # The precisions are alternatives: argparse now rejects passing both,
    # instead of silently dropping --bf16 as the if/elif dispatch did.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument("--fp16", action="store_true", help="load and save the weights as float16")
    precision.add_argument("--bf16", action="store_true", help="load and save the weights as bfloat16")
    args = parser.parse_args()
    # At most one flag is set, so forwarding both is equivalent to the old branching.
    save_weights(fp16=args.fp16, bf16=args.bf16)