import fire
import json
import pickle
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm


def main(
    database_path="",
    output_path="",
    batch_size=16,
    max_length=1024,
    model_path='Alibaba-NLP/gte-large-en-v1.5',
    device="cuda",
):
    """Embed the ``"function"`` field of every record in a JSON file.

    Each record's text is tokenized and encoded with a ``transformers``
    model. The hidden state of the first token is taken as the embedding
    and L2-normalized. The stacked embeddings are pickled as a NumPy array
    with one row per record, in input order.

    Args:
        database_path: Path to a JSON file containing a sequence of objects,
            each with a ``"function"`` text field.
        output_path: Destination file for the pickled embedding array.
        batch_size: Number of texts encoded per forward pass.
        max_length: Token limit per text; longer texts are truncated.
        model_path: Hugging Face model id or local path to load.
        device: Torch device the model and inputs are moved to.

    Raises:
        ValueError: If a path is empty, ``batch_size`` is not positive, or
            the database contains no records.
    """
    if not database_path or not output_path:
        raise ValueError("both database_path and output_path must be provided")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device)
    model.eval()

    with open(database_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not data:
        # Nothing to embed; torch.cat on an empty list would fail anyway.
        raise ValueError(f"no records found in {database_path!r}")

    all_embeddings = []
    with torch.no_grad():
        # Step by batch_size so every batch is non-empty, including when
        # len(data) is an exact multiple of batch_size.
        for start in tqdm(range(0, len(data), batch_size)):
            input_texts = [obj["function"] for obj in data[start:start + batch_size]]

            # Tokenize the input texts
            batch_dict = tokenizer(
                input_texts,
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors='pt',
            )
            batch_dict = {key: value.to(device) for key, value in batch_dict.items()}

            outputs = model(**batch_dict)
            # First-token hidden state as the text embedding (presumably the
            # CLS token for this model family — confirm if changing models).
            embeddings = outputs.last_hidden_state[:, 0]

            # L2-normalize so dot products equal cosine similarities.
            embeddings = F.normalize(embeddings, p=2, dim=1).cpu()
            all_embeddings.append(embeddings)

    all_embeddings = torch.cat(all_embeddings, dim=0)
    print(all_embeddings.size())
    with open(output_path, "wb") as f:
        pickle.dump(all_embeddings.numpy(), f)

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    fire.Fire(main)