import fire
import json
import faiss
import pickle
import numpy as np

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm


def load_model():
    """Load the GTE-large (v1.5) encoder and its tokenizer.

    Returns:
        ``(model, tokenizer)``: the model is moved to CUDA and switched to
        eval mode; the tokenizer comes from the same checkpoint.
    """
    checkpoint = 'Alibaba-NLP/gte-large-en-v1.5'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # The checkpoint ships custom modelling code, hence trust_remote_code.
    encoder = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    encoder = encoder.cuda()
    encoder.eval()
    return encoder, tok


def load_fun_database(path):
    """Load the function database and return the function entries.

    Args:
        path: Path to a JSON file holding a list of objects, each of which
            has a ``"function"`` key.

    Returns:
        List of the ``"function"`` values, in file order. Position ``i`` in
        this list is what a FAISS search result ``i`` is mapped back to.
    """
    # Context manager closes the handle; JSON is UTF-8 by specification,
    # so don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [obj["function"] for obj in data]


def load_fun_embeds(path):
    """Load the pickled function embeddings.

    Args:
        path: Path to a pickle file with the precomputed function embeddings.

    Returns:
        Whatever object was pickled. ``main`` expects it to convert to a 2-D
        array with one row per function.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only load files produced
        by a trusted pipeline.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def load_query(path):
    """Load the query objects from a JSON file.

    Args:
        path: Path to a JSON file. ``main`` expects a list of objects, each
            with an ``"extract_modular"`` list of strings.

    Returns:
        The decoded JSON value, unchanged.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _embed_queries(model, tokenizer, texts, max_length=1024):
    """Encode ``texts`` into L2-normalised first-token embeddings.

    Args:
        model: Encoder returned by ``load_model`` (already on CUDA).
        tokenizer: Tokenizer matching ``model``.
        texts: Non-empty list of strings to embed.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        C-contiguous float32 NumPy array with one row per input string,
        the layout ``faiss.Index.search`` requires.
    """
    with torch.no_grad():
        batch_dict = tokenizer(texts, max_length=max_length, padding=True,
                               truncation=True, return_tensors='pt')
        batch_dict = {key: batch_dict[key].cuda() for key in batch_dict}
        outputs = model(**batch_dict)
        # The last hidden state of the first token is the sentence embedding.
        first_tok = outputs.last_hidden_state[:, 0]
        first_tok = F.normalize(first_tok, p=2, dim=1).cpu().numpy()
    return np.ascontiguousarray(first_tok, dtype=np.float32)


def main(
    fun_path,
    embed_path,
    query_path,
    output_path
):
    """Attach the nearest database function to every extracted module.

    Each string in a query object's ``"extract_modular"`` list is embedded
    with ``_embed_queries``. It is then looked up in a flat L2 FAISS index
    built over the precomputed function embeddings. The closest function
    for each string is stored, in order, under ``"similar_modular"``.
    Objects with an empty ``"extract_modular"`` get an empty
    ``"similar_modular"`` list.

    Args:
        fun_path: JSON function database (see ``load_fun_database``).
        embed_path: Pickled function embeddings. Row ``i`` must correspond
            to function ``i`` of ``fun_path``.
        query_path: JSON list of query objects (see ``load_query``).
        output_path: Destination for the augmented query objects (JSON).

    Raises:
        ValueError: If the embeddings are not a 2-D matrix, or their row
            count does not match the number of functions.

    Objects whose embedding or search fails are reported and left out of
    the output. At the end the number of objects without any extracted
    module is printed.
    """
    funs = load_fun_database(fun_path)
    # FAISS needs a 2-D, C-contiguous float32 matrix. Take the index
    # dimension from the data instead of hard-coding it.
    embeds = np.ascontiguousarray(load_fun_embeds(embed_path), dtype=np.float32)
    if embeds.ndim != 2:
        raise ValueError(
            f"expected a 2-D embedding matrix, got shape {embeds.shape}")
    if embeds.shape[0] != len(funs):
        # Search results are mapped back by position, so a mismatch would
        # silently return the wrong functions.
        raise ValueError(
            f"{embeds.shape[0]} embeddings but {len(funs)} functions; "
            "row i of the embeddings must correspond to function i")
    data = load_query(query_path)
    model, tokenizer = load_model()

    index = faiss.IndexFlatL2(embeds.shape[1])
    index.add(embeds)

    new_data = []
    no_new = 0
    for obj_idx, obj in enumerate(tqdm(data)):
        querys = obj["extract_modular"]
        if not querys:
            obj["similar_modular"] = []
            no_new += 1
            new_data.append(obj)
            continue
        try:
            query_embeds = _embed_queries(model, tokenizer, querys)
            _, nearest = index.search(query_embeds, 1)
        except Exception as err:
            # Best-effort: skip this object, but report it instead of
            # silently dropping it. Unlike a bare except, this does not
            # swallow KeyboardInterrupt.
            tqdm.write(f"skipping query object {obj_idx}: {err!r}")
            continue
        obj["similar_modular"] = [funs[idx] for idx in nearest[:, 0]]
        new_data.append(obj)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, indent=4)
    print(no_new)

# Run the CLI only when executed as a script, so importing this module
# (e.g. to reuse the loaders) does not start argument parsing.
if __name__ == "__main__":
    fire.Fire(main)