import gc
import ray
import faiss
import time
import numpy as np

from openea.modules.utils.util import task_divide, merge_dic

# ray.init()


def generate_neighbours_faiss(entity_embeds, entity_list, neighbors_num, frags_num=8):
    """Find the ``neighbors_num`` nearest neighbours of every entity with faiss.

    The entities are split into ``frags_num`` fragments (via ``task_divide``).
    Each fragment is searched in its own Ray task against a single exact L2
    index (``faiss.IndexFlatL2``) that holds all of ``entity_embeds``.

    Args:
        entity_embeds: 2-D array with one embedding row per entry of
            ``entity_list`` (row ``i`` belongs to ``entity_list[i]``).
        entity_list: entity identifiers aligned with the rows of
            ``entity_embeds``.
        neighbors_num: number of neighbours to retrieve per entity.
        frags_num: number of fragments / Ray tasks to split the queries into.

    Returns:
        dict mapping each entity to the list of its neighbouring entities, as
        produced by ``find_neighbours_faiss_batch`` and combined with
        ``merge_dic``. The queries are the indexed vectors themselves, so an
        entity will typically appear in its own neighbour list.
    """
    # Build the numpy view of the entities once. The remote task fancy-indexes
    # it, so it must be an ndarray rather than a plain list.
    entity_array = np.array(entity_list)
    ent_frags = task_divide(entity_array, frags_num)
    # np.arange always yields an integer array, even when empty; np.array(range(0))
    # would be float and unusable as an index.
    ent_frag_indexes = task_divide(np.arange(len(entity_list)), frags_num)

    # faiss operates on C-contiguous float32 matrices. Convert once so the index
    # and every per-fragment query slice below share that layout. This is a
    # no-op when the input is already float32 and contiguous.
    entity_embeds = np.ascontiguousarray(entity_embeds, dtype=np.float32)

    index = faiss.IndexFlatL2(entity_embeds.shape[1])  # exact (brute-force) L2 index
    index.add(entity_embeds)  # add vectors to the index

    # Store the shared, potentially large arguments in the Ray object store once.
    # Passing them by value to every task would serialize a fresh copy per call.
    index_ref = ray.put(index)
    entity_array_ref = ray.put(entity_array)

    pending = [
        find_neighbours_faiss_batch.remote(frag, entity_array_ref, entity_embeds[frag_indexes, :],
                                           index_ref, neighbors_num)
        for frag, frag_indexes in zip(ent_frags, ent_frag_indexes)
    ]

    dic = dict()
    for res in ray.get(pending):
        dic = merge_dic(dic, res)

    # Drop local references (including the object-store handles) before collecting.
    del index, index_ref, entity_array_ref, entity_embeds
    gc.collect()
    return dic


@ray.remote(num_cpus=1)
def find_neighbours(frags, entity_list, sub_embed, embed, k):
    """Find the ``k`` most similar candidates for each query row by inner product.

    Similarity is the dot product ``sub_embed @ embed.T``; a larger value means
    more similar.

    Args:
        frags: identifiers of the query rows; ``frags[i]`` is the key stored
            for row ``i`` of ``sub_embed``.
        entity_list: numpy array of candidate entities aligned with the rows of
            ``embed``. It is fancy-indexed, so a plain list will not work.
        sub_embed: 2-D array of query embeddings.
        embed: 2-D array of candidate embeddings.
        k: number of neighbours to return per query. Values larger than the
            number of candidates are clamped to that number.

    Returns:
        dict mapping ``frags[i]`` to a list of up to ``k`` entities, ordered by
        descending similarity.
    """
    dic = dict()
    sim_mat = np.matmul(sub_embed, embed.T)
    # np.argpartition requires 0 <= kth < n_candidates. Partitioning at kth=k
    # raises ValueError whenever k >= n_candidates, so clamp k and partition
    # at k - 1 instead.
    k = min(k, sim_mat.shape[1])
    for i in range(sim_mat.shape[0]):
        if k <= 0:
            dic[frags[i]] = []
            continue
        neg_sims = -sim_mat[i, :]
        # After partitioning around kth = k - 1, the first k slots hold the k
        # most similar candidates, but in arbitrary order ...
        top_k = np.argpartition(neg_sims, k - 1)[:k]
        # ... so rank them explicitly, most similar first. This matches the
        # ranked output of the faiss-based lookups.
        top_k = top_k[np.argsort(neg_sims[top_k], kind="stable")]
        dic[frags[i]] = entity_list[top_k].tolist()
    del sim_mat
    gc.collect()
    return dic


@ray.remote(num_cpus=1)
def find_neighbours_faiss(frags, entity_list, sub_embed, index, k):
    """Look up the ``k`` nearest neighbours of each query row, one faiss search per row.

    Args:
        frags: identifiers of the query rows; ``frags[i]`` is the key stored
            for row ``i`` of ``sub_embed``.
        entity_list: numpy array of all entities aligned with the vectors added
            to ``index`` (faiss label ``j`` maps to ``entity_list[j]``).
        sub_embed: 2-D array of query embeddings, one row per entry of ``frags``.
        index: faiss index to search.
        k: number of neighbours requested per query.

    Returns:
        dict mapping ``frags[i]`` to the list of neighbouring entities, in the
        order returned by ``index.search``. A list is shorter than ``k`` when
        the index cannot supply ``k`` results.
    """
    dic = dict()
    for i in range(sub_embed.shape[0]):
        # faiss expects a 2-D (n_queries, dim) matrix even for a single query.
        query = sub_embed[i, :].reshape(1, -1)
        _, labels = index.search(query, k)
        row_labels = labels[0]
        # faiss pads missing results with label -1. Used as a numpy index, -1
        # would silently resolve to the *last* entity, so drop those slots.
        row_labels = row_labels[row_labels >= 0]
        dic[frags[i]] = entity_list[row_labels].tolist()
    return dic


@ray.remote(num_cpus=1)
def find_neighbours_faiss_batch(frags, entity_list, sub_embed, index, k):
    """Look up the ``k`` nearest neighbours of all query rows with one faiss search.

    Args:
        frags: identifiers of the query rows; ``frags[i]`` is the key stored
            for row ``i`` of ``sub_embed``.
        entity_list: numpy array of all entities aligned with the vectors added
            to ``index`` (faiss label ``j`` maps to ``entity_list[j]``).
        sub_embed: 2-D array of query embeddings, one row per entry of ``frags``.
        index: faiss index to search.
        k: number of neighbours requested per query.

    Returns:
        dict mapping ``frags[i]`` to the list of neighbouring entities, in the
        order returned by ``index.search``. A list is shorter than ``k`` when
        the index cannot supply ``k`` results.
    """
    dic = dict()
    _, index_mat = index.search(sub_embed, k)
    for i, labels in enumerate(index_mat):
        # faiss pads missing results with label -1. Used as a numpy index, -1
        # would silently resolve to the *last* entity, so drop those slots.
        valid_labels = labels[labels >= 0]
        dic[frags[i]] = entity_list[valid_labels].tolist()
    del index_mat
    return dic
