import math
import multiprocessing as mp
import random
import time
import gc

import tensorflow as tf
import numpy as np
from sklearn import preprocessing

import openea.modules.load.read as rd
import openea.modules.train.batch as bat
from openea.modules.finding.evaluation import early_stop
from openea.modules.utils.util import generate_out_folder
from openea.modules.utils.util import load_session
from openea.modules.utils.util import task_divide
from openea.modules.base.initializers import init_embeddings, orthogonal_init
from openea.modules.base.losses import margin_loss
from openea.modules.base.optimizers import generate_optimizer

from utils import MyKGs, find_alignment
from eval import valid, test, greedy_alignment, eval_margin, eval_dangling_binary_cls, get_results
from nn_search import generate_neighbours
from nn_search_faiss import generate_neighbours_faiss
import ipdb
import os


def search_kg1_to_kg2_1nn_neighbor(embeds1, embeds2, ents2, mapping_mat, return_sim=False, soft_nn=10):
    """Return, for every KG1 source embedding, the entity id of its first KG2 neighbour.

    Args:
        embeds1: source (KG1) embeddings, one row per source entity.
        embeds2: candidate (KG2) embeddings, one row per entry of ``ents2``.
        ents2: entity ids aligned with the rows of ``embeds2``.
        mapping_mat: optional matrix; when given, ``embeds1`` is multiplied by it and
            L2-normalised row-wise before the dot-product similarity is computed.
        return_sim: also return the similarity of each chosen neighbour.
        soft_nn: neighbour count forwarded to ``find_alignment``.

    Returns:
        ``nns`` or ``(nns, sim_list)``; both lists follow the order of ``find_alignment``'s
        output, using the first pair ``(source_row, target_col)`` it reports per source.
    """
    source = embeds1
    if mapping_mat is not None:
        source = preprocessing.normalize(np.matmul(source, mapping_mat))
    sim_mat = np.matmul(source, embeds2.T)
    nearest_pairs = find_alignment(sim_mat, soft_nn)
    nns = [ents2[candidates[0][1]] for candidates in nearest_pairs]
    if not return_sim:
        return nns
    # similarity between each source row and its first reported counterpart
    sim_list = [sim_mat[candidates[0][0], candidates[0][1]] for candidates in nearest_pairs]
    return nns, sim_list


def search_kg1_to_kg2_ordered_all_nns(embeds1, embeds2, ents2, mapping_mat, return_all_sim=False, soft_nn=10):
    """Find the ``soft_nn`` nearest KG2 candidates of every KG1 source embedding.

    Args:
        embeds1: source (KG1) embeddings, one row per source entity.
        embeds2: candidate (KG2) embeddings, one row per entry of ``ents2``.
        ents2: entity ids aligned with the rows of ``embeds2``.
        mapping_mat: optional matrix; when given, ``embeds1`` is multiplied by it and
            L2-normalised row-wise before computing dot-product similarities.
        return_all_sim: when False only the first reported neighbour per source is returned.
        soft_nn: neighbour count forwarded to ``find_alignment``.

    Returns:
        When ``return_all_sim`` is False: a flat list with one KG2 entity id per source.
        Otherwise ``(nns, sim_list)``: per source, the neighbour entity ids and the matching
        similarities, both sorted by decreasing similarity.
    """
    if mapping_mat is not None:
        embeds1 = preprocessing.normalize(np.matmul(embeds1, mapping_mat))
    sim_mat = np.matmul(embeds1, embeds2.T)
    nearest_pairs = find_alignment(sim_mat, soft_nn)
    nns = [ents2[candidates[0][1]] for candidates in nearest_pairs]
    if not return_all_sim:
        del sim_mat
        return nns

    sim_list = []
    for row, candidates in enumerate(nearest_pairs):
        # every candidate is (source_row, target_col) into sim_mat
        scores = np.array([sim_mat[cand[0], cand[1]] for cand in candidates])
        order = np.argsort(scores)[::-1]  # highest similarity first
        nearest_pairs[row] = np.array(candidates, dtype=np.int32)[order]
        sim_list.append(scores[order])
    nns = [[ents2[cand[1]] for cand in candidates] for candidates in nearest_pairs]
    del sim_mat
    return nns, sim_list


def search_kg2_to_kg1_ordered_all_nns(embeds2, embeds1, ents1, mapping_mat, return_all_sim=False, soft_nn=10):
    """Reverse search: for every retrieved KG2 embedding, find its nearest KG1 source entities.

    Args:
        embeds2: retrieved target embeddings. The caller in this file passes shape
            (batch_size, s2t_topk, dim); everything after the first axis is flattened so each
            retrieved target becomes one row of width ``embeds1.shape[1]``.
        embeds1: source (KG1) embeddings, one row per entry of ``ents1``.
        ents1: entity ids aligned with the rows of ``embeds1``.
        mapping_mat: optional matrix; when given, ``embeds1`` (the sources, not the targets)
            is multiplied by it and L2-normalised row-wise.
        return_all_sim: when False only the first reported neighbour per target is returned.
        soft_nn: neighbour count forwarded to ``find_alignment``.

    Returns:
        When ``return_all_sim`` is False: a flat list with one KG1 entity id per retrieved
        target. Otherwise ``(nns, sim_list)``. ``nns`` holds, per retrieved target, its neighbour
        source ids sorted by decreasing similarity. ``sim_list`` is an array reshaped to
        (batch_size, -1), i.e. s2t_topk * t2s_topk scores per source entity.
    """
    if mapping_mat is not None:
        embeds1 = np.matmul(embeds1, mapping_mat)
        embeds1 = preprocessing.normalize(embeds1)
    batch_size = embeds2.shape[0]
    # Flatten (batch_size, s2t_topk, dim) -> (batch_size * s2t_topk, dim).
    embeds2 = embeds2.reshape((-1, embeds1.shape[1]))
    # Rows index retrieved targets, columns index source entities.
    reversed_sim_mat = np.matmul(embeds1, embeds2.T).T
    nearest_pairs = find_alignment(reversed_sim_mat, soft_nn)
    nns = [ents1[x[0][1]] for x in nearest_pairs]
    if return_all_sim:
        sim_list = []
        for idx, pairs in enumerate(nearest_pairs):
            cur_sim_list = []
            for elements in pairs:
                t, s = elements[0], elements[1]  # (target_row, source_col)
                cur_sim_list.append(reversed_sim_mat[t, s])
            idx_sorted = np.argsort(cur_sim_list)[::-1]  # descending similarity
            nearest_pairs[idx] = np.array(pairs, dtype=np.int32)[idx_sorted]
            sim_list.append(np.array(cur_sim_list)[idx_sorted])
        # For each retrieved target, all of its soft_nn source neighbours in sorted order.
        nns = [[ents1[x[1]] for x in pairs] for pairs in nearest_pairs]
        # Regroup per source entity: (batch_size, s2t_topk * t2s_topk).
        sim_list = np.array(sim_list).reshape(batch_size, -1)
        del reversed_sim_mat
        return nns, sim_list
    del reversed_sim_mat
    return nns

def generate_model_folder(ckpt_folder, training_data_path, div_path, s2t_topk, t2s_topk, method_name):
    """Build (but do not create) a time-stamped checkpoint folder path for one run.

    Layout: ``<ckpt_folder><method_name>/<dataset>/<div_path><YYYYmmddHHMMSS> s_<s2t>_t<t2s>/``,
    where ``<dataset>`` is the last component of ``training_data_path``. No separator is
    inserted after ``ckpt_folder`` or ``div_path``; they are concatenated verbatim.
    """
    path_parts = training_data_path.strip('/').split('/')
    print(ckpt_folder, training_data_path, path_parts, div_path, method_name)
    dataset_name = path_parts[-1]
    run_stamp = time.strftime("%Y%m%d%H%M%S")
    topk_tag = f" s_{s2t_topk}_t{t2s_topk}"
    folder = "".join([ckpt_folder, method_name, '/', dataset_name, "/", div_path, run_stamp, topk_tag, "/"])
    print("ckpt folder:", folder)
    return folder

class MTransEV2:

    def set_kgs(self, kgs: MyKGs):
        """Attach the knowledge-graph container that the graph builders and evaluators read from."""
        self.kgs = kgs

    def set_args(self, args):
        """Store the run configuration and derive the output and checkpoint folders from it.

        ``out_folder`` comes from openea's ``generate_out_folder``. ``ckpt_folder`` additionally
        encodes ``args.s2t_topk`` / ``args.t2s_topk`` via ``generate_model_folder``.
        """
        self.args = args
        method_name = self.__class__.__name__
        self.out_folder = generate_out_folder(args.output, args.training_data, args.dataset_division,
                                              method_name)
        self.ckpt_folder = generate_model_folder(args.output, args.training_data, args.dataset_division,
                                                 args.s2t_topk, args.t2s_topk, method_name)

    def init(self):
        """Build every sub-graph enabled in ``self.args``, open a session and initialise variables.

        Must run after ``set_kgs`` and ``set_args``: the builders read ``self.kgs`` and ``self.args``.
        """
        for build in (self._define_variables, self._define_embed_graph, self._define_align_graph):
            build()
        if self.args.use_NCA_loss:
            self._define_NCA_loss_graph()

        # At most one detection-mode graph is built; an unknown mode builds none. Builders are
        # resolved by name only once selected, so an unused mode needs no implementation.
        detection_builders = {
            "classification": "_define_classification_graph",
            "margin": "_define_distance_margin_graph",
            "open": "_define_open_margin_graph",
        }
        builder_name = detection_builders.get(self.args.detection_mode)
        if builder_name is not None:
            getattr(self, builder_name)()

        if self.args.use_adver_training:
            print(f'self.args.use_adver_training: {self.args.use_adver_training}')
            self._define_adversarial_graph()
        if self.args.use_dangling_bin_cls:
            print(f'self.args.use_dangling_bin_cls: {self.args.use_dangling_bin_cls}')
            self._define_dangling_bin_cls()
        self.session = load_session()
        tf.global_variables_initializer().run(session=self.session)

    def __init__(self):
        """Create an empty model; call ``set_kgs``, ``set_args`` and then ``init`` before use."""
        # Run configuration, data and output locations (filled by set_args / set_kgs).
        self.out_folder = None
        self.ckpt_folder = None  # set by set_args alongside out_folder
        self.args = None
        self.kgs = None

        # TF session, created in init().
        self.session = None

        # Placeholders, created by the graph builders.
        self.seed_entities1 = None
        self.seed_entities2 = None
        self.neg_ts = None
        self.neg_rs = None
        self.neg_hs = None
        self.pos_ts = None
        self.pos_rs = None
        self.pos_hs = None

        # Trainable variables / constants, created in _define_variables().
        self.rel_embeds = None
        self.ent_embeds = None
        self.mapping_mat = None
        self.eye_mat = None

        # Losses and their optimizers.
        self.triple_optimizer = None
        self.triple_loss = None
        self.mapping_optimizer = None
        self.mapping_loss = None

        # Early-stopping state (presumably consumed with openea's early_stop helper).
        self.flag1 = -1
        self.flag2 = -1
        self.early_stop = False

    def _define_variables(self):
        """Create the entity/relation embedding tables and the alignment mapping matrix.

        The tables live under the variable scope 'KG' + 'embeddings' (i.e. 'KGembeddings'). The
        ``dim x dim`` mapping matrix and a constant identity ``eye_mat`` live under
        'alignment' + 'mapping'.
        """
        dim = self.args.dim
        init_method = self.args.init
        with tf.variable_scope('KG' + 'embeddings'):
            self.ent_embeds = init_embeddings([self.kgs.entities_num, dim], 'ent_embeds',
                                              init_method, self.args.ent_l2_norm)
            self.rel_embeds = init_embeddings([self.kgs.relations_num, dim], 'rel_embeds',
                                              init_method, self.args.rel_l2_norm)

        with tf.variable_scope('alignment' + 'mapping'):
            self.mapping_mat = orthogonal_init([dim, dim], 'mapping_matrix')
            self.eye_mat = tf.constant(np.eye(dim), dtype=tf.float32, name='eye')

    def _define_embed_graph(self):
        """Build the triple margin loss (``margin_loss``) over the embedding tables and its optimizer.

        Creates six int32 id placeholders: positive/negative heads, relations and tails. They are
        looked up in ``ent_embeds`` / ``rel_embeds`` and scored with ``args.embed_margin`` and
        ``args.loss_norm``.
        """
        print("build embedding learning graph...")
        with tf.name_scope('triple_placeholder'):
            (self.pos_hs, self.pos_rs, self.pos_ts,
             self.neg_hs, self.neg_rs, self.neg_ts) = [tf.placeholder(tf.int32, shape=[None]) for _ in range(6)]
        with tf.name_scope('triple_lookup'):
            lookup = tf.nn.embedding_lookup
            ent_table, rel_table = self.ent_embeds, self.rel_embeds
            phs = lookup(ent_table, self.pos_hs)
            prs = lookup(rel_table, self.pos_rs)
            pts = lookup(ent_table, self.pos_ts)
            nhs = lookup(ent_table, self.neg_hs)
            nrs = lookup(rel_table, self.neg_rs)
            nts = lookup(ent_table, self.neg_ts)
        with tf.name_scope('triple_loss'):
            args = self.args
            self.triple_loss = margin_loss(phs, prs, pts, nhs, nrs, nts, args.embed_margin, args.loss_norm)
            self.triple_optimizer = generate_optimizer(self.triple_loss, args.learning_rate, opt=args.optimizer)

    def _mapping_align_loss(self, seed_embed1, seed_embed2):
        """Seed-alignment loss without negative samples.

        Sum over seed pairs of ||normalize(seed_embed1 @ M) - seed_embed2||^2, plus the mean
        row-wise squared deviation of M @ M^T from the identity (an orthogonality penalty on M).
        """
        projected = tf.nn.l2_normalize(tf.matmul(seed_embed1, self.mapping_mat), 1)
        with tf.name_scope('mapping_distance'):
            diff = projected - seed_embed2
        with tf.name_scope('mapping_loss'):
            per_pair = tf.reduce_sum(tf.square(diff), axis=1)
            align_loss = tf.reduce_sum(per_pair)
        gram = tf.matmul(self.mapping_mat, self.mapping_mat, transpose_b=True)
        orthogonal_loss = tf.reduce_mean(tf.reduce_sum(tf.pow(gram - self.eye_mat, 2), 1))
        return align_loss + orthogonal_loss

    def _mapping_align_marginal_loss(self, seed_embed1, seed_embed2, pos_embed1, neg_embed2):
        """Seed-alignment loss with an optional negative-sample hinge and an orthogonality penalty.

        align = sum ||normalize(seed_embed1 @ M) - seed_embed2||^2
        When ``args.mapping_margin > 0``, it adds
        0.1 * sum relu(mapping_margin - ||normalize(pos_embed1 @ M) - neg_embed2||^2).
        The orthogonality term is mean_rows sum_cols (M @ M^T - I)^2.
        """
        mapped_seeds = tf.nn.l2_normalize(tf.matmul(seed_embed1, self.mapping_mat), 1)
        with tf.name_scope('mapping_distance'):
            seed_gap = mapped_seeds - seed_embed2
        with tf.name_scope('mapping_loss'):
            align_loss = tf.reduce_sum(tf.reduce_sum(tf.square(seed_gap), axis=1))
            margin = self.args.mapping_margin
            if margin > 0.0:
                mapped_pos = tf.nn.l2_normalize(tf.matmul(pos_embed1, self.mapping_mat), 1)
                neg_score = tf.reduce_sum(tf.square(mapped_pos - neg_embed2), axis=1)
                hinge = tf.nn.relu(tf.constant(margin) - neg_score)
                align_loss += 0.1 * tf.reduce_sum(hinge)
            ortho_residual = tf.matmul(self.mapping_mat, self.mapping_mat, transpose_b=True) - self.eye_mat
            orthogonal_loss = tf.reduce_mean(tf.reduce_sum(tf.square(ortho_residual), 1))
        return align_loss + orthogonal_loss

    def _define_align_graph(self):
        """Build the mapping-based alignment loss and its optimizer.

        Four int32 id placeholders are created: seed pairs (``seed_entities1/2``) and
        negative-sampling pairs (``pos_entities1`` / ``neg_entities2``). They feed
        ``_mapping_align_marginal_loss``.
        """
        print("build alignment learning graph...")
        with tf.name_scope('seed_links_placeholder'):
            (self.seed_entities1, self.seed_entities2,
             self.pos_entities1, self.neg_entities2) = [tf.placeholder(tf.int32, shape=[None]) for _ in range(4)]
        with tf.name_scope('seed_links_lookup'):
            id_inputs = (self.seed_entities1, self.seed_entities2, self.pos_entities1, self.neg_entities2)
            seed_embed1, seed_embed2, pos_embed1, neg_embed2 = [
                tf.nn.embedding_lookup(self.ent_embeds, ids) for ids in id_inputs]
        with tf.name_scope('mapping_loss'):
            self.mapping_loss = self._mapping_align_marginal_loss(seed_embed1, seed_embed2, pos_embed1, neg_embed2)
            self.mapping_optimizer = generate_optimizer(self.mapping_loss, self.args.learning_rate,
                                                        opt=self.args.optimizer)

    def _NCA_loss(self, embed1, embed2):
        """NCA-style batch loss between mapped KG1 embeddings and KG2 embeddings.

        ``embed1`` is multiplied by ``self.mapping_mat`` and L2-normalised. The diagonal of
        ``sim_scores = embed1 @ embed2^T`` is treated as the positive pairs. This assumes row i
        of ``embed1`` is aligned with row i of ``embed2`` (TODO confirm with the caller).

        Reads the scalar placeholders ``self.alpha``, ``self.beta`` and ``self.ep``. Stores these
        intermediates on ``self``: ``NCA_embed1``, ``S_``, ``pos_scores``, ``neg_scores`` and
        ``in_pos_scores``.

        Returns:
            Scalar: reduce_mean of
            log(1 + per-column sums of S_) / alpha + log(1 + per-row sums of S_) / alpha
            - log(1 + beta * relu(diag(sim_scores))),
            where S_ = exp(alpha * (sim_scores - ep)) with its diagonal zeroed.
        """
        embed1 = tf.matmul(embed1, self.mapping_mat)
        embed1 = tf.nn.l2_normalize(embed1, 1)
        self.NCA_embed1 = embed1
        with tf.name_scope('NCA_loss'):
            # Pairwise similarities: rows index embed1, columns index embed2.
            sim_scores = tf.matmul(embed1, tf.transpose(embed2))
            # S_diags = tf.eye(bsize) * sim_scores
            # Diagonal matrix holding only the (assumed) positive-pair similarities.
            sim_diags = tf.diag(tf.linalg.tensor_diag_part(sim_scores))
            S_ = tf.exp(self.alpha * (sim_scores - self.ep))
            S_ = S_ - tf.diag(tf.linalg.tensor_diag_part(S_)) # clear diagonal
            self.S_ = S_

            # Earlier variant (beta applied outside the log, see the commented loss below):
            # loss_diag = -tf.log(1 + tf.nn.relu(tf.reduce_sum(sim_diags, 0)))
            # reduce_sum of a diagonal matrix over axis 0 yields the diagonal itself.
            loss_diag = -tf.log(1 + self.beta * tf.nn.relu(tf.reduce_sum(sim_diags, 0)))
            # axis 0: sum down each column (one value per embed2 row).
            pos_scores = tf.log(1 + tf.reduce_sum(S_, 0)) / self.alpha
            # axis 1: sum along each row (one value per embed1 row).
            neg_scores = tf.log(1 + tf.reduce_sum(S_, 1)) / self.alpha
            self.pos_scores = pos_scores
            self.neg_scores = neg_scores
            self.in_pos_scores = 1 + tf.reduce_sum(S_, 0)
            # Earlier variant:
            # loss = tf.reduce_mean(pos_scores
            #                       + neg_scores
            #                       + loss_diag * self.beta)
            loss = tf.reduce_mean(pos_scores
                                    + neg_scores
                                    + loss_diag)
        return loss

    def _define_NCA_loss_graph(self):
        """Build the NCA alignment loss over a batch of KG1/KG2 entity ids and its optimizer.

        Placeholders: ``NCA_input_ents1`` / ``NCA_input_ents2`` (int32 ids, presumably
        row-aligned pairs), plus the scalar float hyper-parameters ``alpha``, ``beta`` and ``ep``
        read by ``_NCA_loss``.
        """
        print("build NCA loss graph...")
        with tf.name_scope('NCA_input_placeholder'):
            self.NCA_input_ents1, self.NCA_input_ents2 = [tf.placeholder(tf.int32, shape=[None]) for _ in range(2)]
            self.alpha, self.beta, self.ep = [tf.placeholder(tf.float32, shape=[]) for _ in range(3)]

        with tf.name_scope('NCA_lookup'):
            embed1, embed2 = [tf.nn.embedding_lookup(self.ent_embeds, ids)
                              for ids in (self.NCA_input_ents1, self.NCA_input_ents2)]
            self.NCA_embed2 = embed2
        with tf.name_scope('NCA_loss'):
            self.NCA_loss = self._NCA_loss(embed1, embed2)
            self.NCA_optimizer = generate_optimizer(self.NCA_loss, self.args.learning_rate,
                                                    opt=self.args.optimizer)


    def _define_distance_margin_graph(self):
        """Build the distance-margin loss used by the "margin" detection mode.

        Entities from ``input_ents1`` are mapped and L2-normalised when a mapping matrix exists.
        The loss is 0.1 * sum(relu(distance_margin - ||x1 - x2||^2)) against ``input_ents2``, so
        minimising it pushes each pair's squared distance above ``args.distance_margin``.
        The ``seed_pos_ents1`` / ``negative_ents2`` distance is still built, but the loss does
        not use it; its placeholders remain so existing feed dicts keep working.
        """
        print("build distance margin graph...")
        with tf.name_scope('entity_placeholder'):
            self.input_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.input_ents2 = tf.placeholder(tf.int32, shape=[None])
        with tf.name_scope('negative_alignment_entity_placeholder'):
            self.seed_pos_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.negative_ents2 = tf.placeholder(tf.int32, shape=[None])

        def to_mapped_space(embeds):
            return tf.nn.l2_normalize(tf.matmul(embeds, self.mapping_mat), 1)

        with tf.name_scope('entity_lookup'):
            src = tf.nn.embedding_lookup(self.ent_embeds, self.input_ents1)
            seed_src = tf.nn.embedding_lookup(self.ent_embeds, self.seed_pos_ents1)
            if self.mapping_mat is not None:
                src = to_mapped_space(src)
                seed_src = to_mapped_space(seed_src)
            tgt = tf.nn.embedding_lookup(self.ent_embeds, self.input_ents2)
            neg_tgt = tf.nn.embedding_lookup(self.ent_embeds, self.negative_ents2)
        with tf.name_scope('dis_margin_loss'):
            pair_dis = tf.reduce_sum(tf.square(src - tgt), axis=1)
            _unused_neg_dis = tf.reduce_sum(tf.square(seed_src - neg_tgt), axis=1)  # not part of the loss
        hinge = tf.nn.relu(self.args.distance_margin - pair_dis)
        self.dis_loss = 0.1 * tf.reduce_sum(hinge)
        self.dis_optimizer = generate_optimizer(self.dis_loss, self.args.learning_rate, opt=self.args.optimizer)

    def _define_open_margin_graph(self):
        """Build the loss used by the "open" detection mode.

        Each ``input_ents1`` embedding is mapped and L2-normalised when a mapping matrix exists.
        It is compared with its ``args.num_random`` entities in ``input_ents2``. With squared L2
        distances d_ij, the loss is 0.1 * sum_ij |d_ij - mean_j d_ij|: it penalises how far
        each distance deviates from its row mean.
        """
        print("build open margin graph...")
        with tf.name_scope('entity_placeholder'):
            self.input_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.input_ents2 = tf.placeholder(tf.int32, shape=[None, self.args.num_random])
        with tf.name_scope('entity_lookup'):
            anchors = tf.nn.embedding_lookup(self.ent_embeds, self.input_ents1)
            if self.mapping_mat is not None:
                anchors = tf.nn.l2_normalize(tf.matmul(anchors, self.mapping_mat), 1)
            anchors = tf.reshape(anchors, [-1, 1, self.args.dim])  # (batch, 1, dim) to broadcast
            others = tf.nn.embedding_lookup(self.ent_embeds, self.input_ents2)  # (batch, num_random, dim)
        with tf.name_scope('open_margin_loss'):
            sq_dis = tf.reduce_sum(tf.square(anchors - others), axis=2)
            centered = sq_dis - tf.reduce_mean(sq_dis, axis=1, keepdims=True)
        self.open_loss = 0.1 * tf.reduce_sum(tf.abs(centered))
        self.open_optimizer = generate_optimizer(self.open_loss, self.args.learning_rate, opt=self.args.optimizer)

    def _define_dangling_bin_cls(self):
        """Build the MLP that turns top-k similarity scores into a per-entity probability.

        Placeholders fed at run time:
          * ``is_training``: bool, enables the 0.5 dropout on the hidden layer.
          * ``topk_scores``: float, shape (None, s2t_topk * t2s_topk).
          * ``s2t_scores``: float, shape (None, s2t_topk). Only created when
            ``args.concat_forward_reverse`` is set; it is then concatenated to ``topk_scores``.
          * ``bin_cls_labels``: float, shape (None,).
        Outputs:
          * ``dangling_bin_cls_output``: sigmoid probability per row, shape (None,).
          * ``bin_cls_loss``: Keras binary cross-entropy. Keras averages over the last axis; with
            1-D labels/predictions that is the batch axis, so the loss is presumably a scalar
            batch mean (confirm for the TF version in use).
          * ``bin_cls_optimizer``: minimises ``bin_cls_loss``.
        """
        print('build dangling entity binary classification graph...')
        with tf.name_scope('is_training'):
            self.is_training = tf.placeholder(tf.bool, None)
        with tf.name_scope('topk_scores'):
            self.topk_scores = tf.placeholder(tf.float32, shape=[None, self.args.s2t_topk*self.args.t2s_topk]) # (#source_ent, k_sim_scores)
        if self.args.concat_forward_reverse:
            print('consider s2t nearest topk scores')
            with tf.name_scope('s2t_topk_scores'):
                self.s2t_scores = tf.placeholder(tf.float32, shape=[None, self.args.s2t_topk])
        with tf.name_scope('bin_cls_labels'):
            self.bin_cls_labels = tf.placeholder(tf.float32, shape=[None]) # (#source_ent)
        with tf.name_scope('bin_cls_mlp'):
            # `mlp_input` rather than `input`, which would shadow the builtin.
            if self.args.concat_forward_reverse:
                mlp_input = tf.concat([self.topk_scores, self.s2t_scores], 1)
            else:
                mlp_input = self.topk_scores
            hidden = tf.contrib.layers.fully_connected(mlp_input, self.args.n_hid, activation_fn=tf.nn.relu)
            hidden = tf.layers.dropout(hidden, rate=0.5, training=self.is_training)
            logits = tf.contrib.layers.fully_connected(hidden, 1, activation_fn=None)
            probs = tf.sigmoid(logits)
            self.dangling_bin_cls_output = tf.reshape(probs, (-1,))

        self.bin_cls_loss = tf.keras.losses.binary_crossentropy(self.bin_cls_labels, self.dangling_bin_cls_output)
        self.bin_cls_optimizer = generate_optimizer(self.bin_cls_loss, self.args.learning_rate, opt=self.args.optimizer)


    def _define_adversarial_graph(self):
        """Build a WGAN critic that separates mapped KG1 embeddings from KG2 embeddings.

        The critic is a one-hidden-layer MLP shared by three inputs:
        * mapped ``disc_input_ent1`` ("fake");
        * raw ``disc_input_ent2`` ("real");
        * mapped ``input_dangling_ent1``.

        Ops created:
        * ``gen_opt`` updates only ``self.mapping_mat`` to raise the critic score of mapped
          embeddings.
        * ``disc_opt`` updates only the critic weights.
        * ``disc_clip`` clips the critic weights to [-0.01, 0.01].
        * When ``args.max_dangling_to_target`` is set, ``max_dangling_opt`` updates only
          ``self.mapping_mat`` to lower the mean critic score of mapped dangling entities,
          scaled by ``args.lambda_loss``.
        """
        print('build graph for WGAN')

        def glorot_init(shape):
            # NOTE: despite the name, this is N(0, stddev) with stddev = sqrt(2 / shape[0]).
            return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))

        def discriminator(x, weights, biases):
            hidden_layer = tf.matmul(x, weights['disc_hidden1'])
            hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
            hidden_layer = tf.nn.relu(hidden_layer)
            out_layer = tf.matmul(hidden_layer, weights['disc_out'])
            out_layer = tf.add(out_layer, biases['disc_out'])
            # Unbounded critic score (no sigmoid), as WGAN requires.
            return out_layer

        with tf.name_scope('discriminator'):
            self.disc_input_ent1 = tf.placeholder(tf.int32, shape=[None])
            self.disc_input_ent2 = tf.placeholder(tf.int32, shape=[None])
            self.input_dangling_ent1 = tf.placeholder(tf.int32, shape=[None])
            weights1 = {
                'disc_hidden1': tf.Variable(glorot_init([self.args.dim, self.args.adver_n_hid])),
                'disc_out': tf.Variable(glorot_init([self.args.adver_n_hid, 1])),
            }
            biases1 = {
                'disc_hidden1': tf.Variable(tf.zeros([self.args.adver_n_hid])),
                'disc_out': tf.Variable(tf.zeros([1])),
            }
            input_embeds1 = tf.nn.embedding_lookup(self.ent_embeds, self.disc_input_ent1)
            embeds2 = tf.nn.embedding_lookup(self.ent_embeds, self.disc_input_ent2)
            embeds1 = tf.matmul(input_embeds1, self.mapping_mat)
            embeds_Mh = tf.nn.l2_normalize(embeds1, 1)

            input_dangling_ent1 = tf.nn.embedding_lookup(self.ent_embeds, self.input_dangling_ent1)
            dangling_embeds1 = tf.matmul(input_dangling_ent1, self.mapping_mat)
            dangling_embeds1 = tf.nn.l2_normalize(dangling_embeds1, 1)

            self.disc_fake = discriminator(embeds_Mh, weights1, biases1)
            self.disc_real = discriminator(embeds2, weights1, biases1)
            self.dangling_scores = discriminator(dangling_embeds1, weights1, biases1)

            self.gen_loss = -tf.reduce_mean(self.disc_fake)
            if self.args.max_dangling_to_target:
                # Kept as a separate objective with its own optimizer (not added to gen_loss).
                self.max_dangling_loss = self.args.lambda_loss * tf.reduce_mean(self.dangling_scores)
            self.disc_loss = - tf.reduce_mean(self.disc_real) + tf.reduce_mean(self.disc_fake)

            disc_vars = [weights1['disc_hidden1'], weights1['disc_out'],
                          biases1['disc_hidden1'], biases1['disc_out']]

            # Only the mapping matrix is adversarially trained. To also update the KG
            # embeddings, extend the var_list with
            # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='KGembeddings').
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                # var_list is documented as a list/tuple of Variables, so wrap the single variable.
                self.gen_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
                    self.gen_loss, var_list=[self.mapping_mat])
                self.disc_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
                    self.disc_loss, var_list=disc_vars)
                if self.args.max_dangling_to_target:
                    self.max_dangling_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
                        self.max_dangling_loss, var_list=[self.mapping_mat])
            # WGAN weight clipping of the critic parameters.
            self.disc_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in disc_vars]

    def _eval_valid_embeddings(self):
        """Fetch validation source embeddings, candidate target embeddings and the mapping matrix.

        The candidates are the gold validation targets first, followed by every KG2 entity that
        is neither a training nor a validation target. Training targets are presumably excluded
        because their alignment is already known; confirm against the evaluation protocol.

        Returns:
            (embeds1, embeds2, mapping), where mapping is None when there is no mapping matrix.
        """
        kgs = self.kgs
        remaining_targets = (kgs.kg2.entities_set
                             - set(kgs.train_entities2)
                             - set(kgs.valid_entities2))
        candidate_list = kgs.valid_entities2 + list(remaining_targets)
        sess = self.session
        embeds1 = tf.nn.embedding_lookup(self.ent_embeds, kgs.valid_entities1).eval(session=sess)
        embeds2 = tf.nn.embedding_lookup(self.ent_embeds, candidate_list).eval(session=sess)
        mapping = None if self.mapping_mat is None else self.mapping_mat.eval(session=sess)
        return embeds1, embeds2, mapping

    def _eval_test_embeddings(self):
        """Fetch test source embeddings, candidate target embeddings and the mapping matrix.

        The candidates are the gold test targets first, followed by every KG2 entity that is not
        a training, validation or test target.

        Returns:
            (embeds1, embeds2, mapping), where mapping is None when there is no mapping matrix.
        """
        kgs = self.kgs
        remaining_targets = (kgs.kg2.entities_set
                             - set(kgs.train_entities2)
                             - set(kgs.valid_entities2)
                             - set(kgs.test_entities2))
        candidate_list = kgs.test_entities2 + list(remaining_targets)
        sess = self.session
        embeds1 = tf.nn.embedding_lookup(self.ent_embeds, kgs.test_entities1).eval(session=sess)
        embeds2 = tf.nn.embedding_lookup(self.ent_embeds, candidate_list).eval(session=sess)
        mapping = None if self.mapping_mat is None else self.mapping_mat.eval(session=sess)
        return embeds1, embeds2, mapping

    def valid_alignment(self, stop_metric):
        """Evaluate alignment on the validation split (relaxed setting, no CSLS).

        Returns:
            ``hits[0]`` when ``stop_metric == 'hits1'`` (presumably hits@1 when args.top_k
            starts with 1); otherwise the MRR reported by ``valid``.
        """
        print("\nevaluating alignment (relaxed setting)...")
        embeds1, embeds2, mapping = self._eval_valid_embeddings()
        args = self.args
        hits, mrr_12, _ = valid(embeds1, embeds2, mapping, args.top_k,
                                args.test_threads_num, metric=args.eval_metric,
                                normalize=args.eval_norm, csls_k=0, accurate=False)
        print()
        if stop_metric == 'hits1':
            return hits[0]
        return mrr_12

    def get_source_and_candidates(self, source_ents_and_labels, is_test):
        """Collect source embeddings/labels and the KG2 candidate pool for dangling detection.

        Args:
            source_ents_and_labels: items whose ``[0]`` is a source entity id and ``[1]`` its label.
            is_test: when True, validation targets are also removed from the candidate pool
                (training targets are always removed).

        Returns:
            (source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat).
            All values are read from the current session; the embeddings are not mapped here.
        """
        embedding_table = self.ent_embeds.eval(session=self.session)
        mapping_mat = self.mapping_mat.eval(session=self.session)

        source_ents = [item[0] for item in source_ents_and_labels]
        source_embeds = embedding_table[np.array(source_ents)]

        all_targets = set(self.kgs.kg2.entities_list)
        if is_test:
            target_candidates = list(all_targets - set(self.kgs.train_entities2) - set(self.kgs.valid_entities2))
        else:
            target_candidates = list(all_targets - set(self.kgs.train_entities2))
        target_embeds = embedding_table[np.array(target_candidates)]
        source_ent_y = [item[1] for item in source_ents_and_labels]
        return source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat

    def evaluate_margin(self, source_ents_and_labels, margin, is_test=False):
        """Detect dangling source entities by thresholding their nearest-neighbour distance.

        Distance is ``1 - similarity`` to the first KG2 neighbour found by
        ``search_kg1_to_kg2_1nn_neighbor``.

        NOTE(review): ``margin`` is accepted but unused. ``eval_margin`` receives the mean
        distance over these source entities as its margin instead.

        Returns:
            Whatever ``eval_margin`` returns. In this file it is unpacked into the predicted
            matchable entities, their count and confusion counts.
        """
        print("dangling entity detection...")

        source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat = \
            self.get_source_and_candidates(source_ents_and_labels, is_test)
        _, nearest_sims = search_kg1_to_kg2_1nn_neighbor(source_embeds, target_embeds, target_candidates,
                                                         mapping_mat, return_sim=True)
        distances = 1 - np.array(nearest_sims)
        adaptive_margin = np.mean(distances)  # mean distance from each source to its nearest counterpart
        print(adaptive_margin, distances)
        return eval_margin(source_ents, distances.tolist(), source_ent_y, margin=adaptive_margin)


    def eval_dangling_detection_cls(self, source_ents_and_labels, threshold=0.5, is_test=False):
        """Classify source entities as dangling/matchable with the binary-classifier MLP.

        Steps:
        1. Take the top ``s2t_topk`` KG2 neighbours of every source entity.
        2. For each retrieved target, search back its top ``t2s_topk`` source neighbours.
        3. Feed the resulting (num_source, s2t_topk * t2s_topk) similarity matrix to
           ``dangling_bin_cls_output``. With ``args.concat_forward_reverse``, the forward
           s2t similarities are also fed.
        4. Evaluate the predictions with ``eval_dangling_binary_cls`` at ``threshold``.
        """
        print("using binary classification for determining whether it's dangling ")
        source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat = \
            self.get_source_and_candidates(source_ents_and_labels, is_test)
        args = self.args
        # Forward pass: neighbour ids here are real indices into the whole embedding table.
        ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(source_embeds, target_embeds, target_candidates,
                                                           mapping_mat, return_all_sim=True,
                                                           soft_nn=args.s2t_topk)
        embedding_table = self.ent_embeds.eval(session=self.session)
        retrieved_target_embeds = embedding_table[np.array(ent12, dtype=np.int32)]  # (batch_size, s2t_topk, d_emb)
        reversed_ent21, t2s_sim = search_kg2_to_kg1_ordered_all_nns(retrieved_target_embeds, source_embeds,
                                                                    source_ents, mapping_mat, return_all_sim=True,
                                                                    soft_nn=args.t2s_topk)
        t2s_sim = np.array(t2s_sim)
        print(f'len(reversed_ent21): {len(reversed_ent21)}, t2s_sim.shape: {t2s_sim.shape}')
        feed = {self.topk_scores: t2s_sim}
        if args.concat_forward_reverse:
            s2t_sim = np.array(s2t_sim)
            print(f's2t_sim.shape: {s2t_sim.shape}')
            feed[self.s2t_scores] = s2t_sim
        feed[self.is_training] = False
        y_predict = self.session.run(fetches=self.dangling_bin_cls_output, feed_dict=feed)
        print(f'type(y_predict): {type(y_predict)}, y_predict: {y_predict}')
        return eval_dangling_binary_cls(source_ents, y_predict, source_ent_y, threshold)

    def real_entity_alignment_evaluation(self, label11_ents, label1_num, matchable_source_ents1):
        """Second step of the two-step evaluation: align the correctly detected matchable sources.

        Args:
            label11_ents: source entities predicted matchable that really are matchable. Each must
                appear in ``valid_entities1 + test_entities1``; otherwise ValueError is raised.
            label1_num: number of source entities predicted as matchable (precision denominator).
            matchable_source_ents1: all truly matchable source entities (recall denominator).

        Returns:
            The end-to-end F1. Returns 0. when nothing was detected or when precision and recall
            are both zero.
        """
        if label11_ents is None or len(label11_ents) == 0:
            print("no predicated matchable entities")
            return 0.
        total_ent_embeds = self.ent_embeds.eval(session=self.session)
        mapping_mat = self.mapping_mat.eval(session=self.session)
        label11_source_embeds = np.matmul(total_ent_embeds[np.array(label11_ents),], mapping_mat)

        # Map each matchable source to its gold target once (O(n)) instead of calling
        # list.index per entity (O(n^2)). setdefault keeps the first occurrence, like list.index.
        matchable_ents1 = self.kgs.valid_entities1 + self.kgs.test_entities1
        matchable_ents2 = self.kgs.valid_entities2 + self.kgs.test_entities2
        gold_target = {}
        for src, tgt in zip(matchable_ents1, matchable_ents2):
            gold_target.setdefault(src, tgt)
        true_targets = []
        for e in label11_ents:
            if e not in gold_target:
                raise ValueError(f"{e} is not a valid/test matchable source entity")
            true_targets.append(gold_target[e])

        candidate_list = true_targets + list(self.kgs.kg2.entities_set
                                             - set(self.kgs.train_entities2)
                                             - set(self.kgs.valid_entities2)
                                             - set(true_targets))
        candidate_embeds = total_ent_embeds[np.array(candidate_list),]
        csls_k = 10
        if self.args.use_dangling_bin_cls:  # a quick fix
            # Shrink k when there are few source entities (presumably so CSLS has enough
            # neighbours); clamp at 0 so it can never become negative.
            csls_k = max(0, min(10, len(label11_ents) - 2))
        _, hits, _, _, _ = greedy_alignment(label11_source_embeds, candidate_embeds,
                                            self.args.top_k, self.args.test_threads_num,
                                            self.args.eval_metric, True, csls_k, False, False)
        # greedy_alignment treats every source as truly matchable, so hits are fractions over
        # len(label11_ents); multiply back to get absolute counts.
        hits1 = hits[0]
        hits10 = hits[2]  # assumes args.top_k is like [1, 5, 10, ...] — TODO confirm
        num_matchable = len(matchable_source_ents1)
        precision = hits1 * len(label11_ents) / label1_num if label1_num else 0.
        recall = hits1 * len(label11_ents) / num_matchable if num_matchable else 0.
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.
        recall10 = hits10 * len(label11_ents) / num_matchable if num_matchable else 0.
        print(f'two-step results, hits1 = {hits1:.3f}, hits10 = {hits10:.3f}')
        print("two-step results, precision = {:.3f}, recall = {:.3f}, f1 = {:.3f}, recall@10 = {:.3f}\n".format(
            precision, recall, f1, recall10))
        return f1

    def two_step_evaluation_margin(self, matchable_source_ents1, dangling_source_ents1, threshold=-1, is_test=False):
        """Evaluate dangling detection followed by entity alignment ("two-step" evaluation).

        The source entities (matchable first, then dangling) are passed to the configured
        dangling detector -- ``eval_dangling_detection_cls`` when ``use_dangling_bin_cls`` is
        set, otherwise ``evaluate_margin`` -- and what it keeps is scored by
        ``real_entity_alignment_evaluation``.

        On the test split with ``test_batch_num > 1`` the detector runs chunk by chunk
        (chunks produced by ``task_divide``); kept entities and confusion counts are
        accumulated over chunks and the overall detection results are printed via
        ``get_results``.

        Args:
            matchable_source_ents1: source entities that have a counterpart.
            dangling_source_ents1: source entities without a counterpart.
            threshold: forwarded to ``eval_dangling_detection_cls``.
            is_test: evaluating the test split; enables chunked evaluation and is forwarded
                to the detector.

        Returns:
            Whatever ``real_entity_alignment_evaluation`` returns.
        """
        print("evaluating two-step alignment (margin)...")

        def run_detector(source_ents):
            # Both detectors return an 8-tuple:
            # (label11_ents, label1_num, one_one, one_zero, zero_zero, zero_one, total_examples, one_labels)
            if self.args.use_dangling_bin_cls:
                return self.eval_dangling_detection_cls(source_ents, threshold=threshold, is_test=is_test)
            return self.evaluate_margin(source_ents, self.args.distance_margin, is_test=is_test)

        if not (is_test and self.args.test_batch_num > 1):
            kept_ents, kept_num, _, _, _, _, _, _ = run_detector(matchable_source_ents1 + dangling_source_ents1)
            return self.real_entity_alignment_evaluation(kept_ents, kept_num, matchable_source_ents1)

        chunk_count = self.args.test_batch_num
        print("test via batches...\n")
        matchable_chunks = task_divide(matchable_source_ents1, chunk_count)
        dangling_chunks = task_divide(dangling_source_ents1, chunk_count)
        kept_ents = []
        kept_num = 0
        # running sums of: one_one, one_zero, zero_zero, zero_one, total_examples, one_labels
        totals = [0, 0, 0, 0, 0, 0]
        for idx in range(chunk_count):
            (label11_ents, label1_num, one_one, one_zero, zero_zero, zero_one,
             total_examples, one_labels) = run_detector(matchable_chunks[idx] + dangling_chunks[idx])
            kept_ents += label11_ents
            kept_num += label1_num
            chunk_counts = (one_one, one_zero, zero_zero, zero_one, total_examples, one_labels)
            totals = [acc + val for acc, val in zip(totals, chunk_counts)]
            print()
        print("final test results:")
        get_results(*totals)
        return self.real_entity_alignment_evaluation(kept_ents, kept_num, matchable_source_ents1)

    def test(self):
        """Evaluate on the test split: synthetic alignment, then (optionally) two-step alignment.

        The synthetic-alignment evaluation is delegated to the module-level ``test`` helper
        imported from ``eval``; its return values are not used here. When
        ``detection_mode`` is "margin" or "open", the two-step evaluation
        (dangling detection + alignment) is also run on the test linked/unlinked entities.
        """
        print("\ntesting synthetic alignment...")
        # Both detection modes evaluate on the same test embeddings (the previous
        # if/else on detection_mode had two identical branches).
        embeds1, embeds2, mapping = self._eval_test_embeddings()
        _, _, _, _ = test(embeds1, embeds2, mapping, self.args.top_k, self.args.test_threads_num,
                          metric=self.args.eval_metric, normalize=self.args.eval_norm, csls_k=0, accurate=True)
        print()
        if self.args.detection_mode in ["margin", "open"]:
            self.two_step_evaluation_margin(self.kgs.test_linked_entities1,
                                            self.kgs.test_unlinked_entities1, is_test=True)
        print()

    def save(self):
        """Write entity and relation embeddings (plus the mapping matrix, when one exists) to ``self.out_folder``."""
        sess = self.session
        ent_embeds = self.ent_embeds.eval(session=sess)
        rel_embeds = self.rel_embeds.eval(session=sess)
        mapping_mat = None
        if self.mapping_mat is not None:
            mapping_mat = self.mapping_mat.eval(session=sess)
        rd.save_embeddings(self.out_folder, self.kgs, ent_embeds, rel_embeds, None, mapping_mat=mapping_mat)

    def eval_kg1_ent_embeddings(self):
        """Return the current embeddings of all KG1 entities (rows in ``kg1.entities_list`` order).

        The rows are gathered from the evaluated embedding table in NumPy instead of
        creating a ``tf.nn.embedding_lookup`` op: in graph mode every such call added new
        ops (including the id list as a constant) to the default graph, so repeated calls
        kept growing it.
        """
        table = self.ent_embeds.eval(session=self.session)
        return table[np.asarray(self.kgs.kg1.entities_list, dtype=np.int64)]

    def eval_kg2_ent_embeddings(self):
        """Return the current embeddings of all KG2 entities (rows in ``kg2.entities_list`` order).

        The rows are gathered from the evaluated embedding table in NumPy instead of
        creating a ``tf.nn.embedding_lookup`` op: in graph mode every such call added new
        ops (including the id list as a constant) to the default graph, so repeated calls
        kept growing it.
        """
        table = self.ent_embeds.eval(session=self.session)
        return table[np.asarray(self.kgs.kg2.entities_list, dtype=np.int64)]

    def eval_embeddings(self, entity_list):
        """Return the current embeddings of ``entity_list`` as a NumPy array (rows follow its order).

        The rows are gathered from the evaluated embedding table in NumPy instead of
        creating a ``tf.nn.embedding_lookup`` op: in graph mode every such call added new
        ops (including the id list as a constant) to the default graph, so repeated calls
        (e.g. each neighbour regeneration) kept growing it. An empty ``entity_list``
        yields an array with zero rows.
        """
        table = self.ent_embeds.eval(session=self.session)
        return table[np.asarray(entity_list, dtype=np.int64)]

    def launch_training_1epo(self, epoch, triple_steps, steps_tasks, training_batch_queue, neighbors1, neighbors2):
        """Run one full training epoch.

        Order: relation-triple embedding training, alignment (mapping) training, optional
        NCA training, the detection-mode-specific margin objective, then the optional
        dangling binary classifier and adversarial training.

        Raises:
            NotImplementedError: if ``detection_mode`` is neither "margin" nor "open"
                (raised after the embedding/alignment/NCA steps of this epoch have run).
        """
        self.launch_embed_training_1epo(epoch, triple_steps, steps_tasks, training_batch_queue, neighbors1, neighbors2)
        self.launch_align_training_1epo(epoch, triple_steps)
        if self.args.use_NCA_loss:
            self.launch_NCA_training_1epo(epoch, triple_steps)

        margin_trainers = {
            "margin": self.launch_distance_margin_training_1epo,
            'open': self.launch_open_margin_training_1epo,
        }
        mode = self.args.detection_mode
        if mode not in margin_trainers:
            raise NotImplementedError(f'cannot find detection_mode: {self.args.detection_mode}')
        margin_trainers[mode](epoch, triple_steps)

        optional_trainers = (
            (self.args.use_dangling_bin_cls, self.launch_dangling_detection_binary_cls_training_1epo),
            (self.args.use_adver_training, self.launch_adver_training_1epo),
        )
        for enabled, trainer in optional_trainers:
            if enabled:
                trainer(epoch, triple_steps)

    def launch_embed_training_1epo(self, epoch, triple_steps, steps_tasks, batch_queue, neighbors1, neighbors2):
        """Run one epoch of relation-triple embedding training.

        One producer process is started per entry of ``steps_tasks``; each runs
        ``bat.generate_relation_triple_batch_queue`` and puts (positive, negative) triple
        batches on ``batch_queue``. This process consumes ``triple_steps`` batches, applies
        ``self.triple_optimizer`` to each, then joins the producers.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of batches to consume from ``batch_queue``.
            steps_tasks: per-producer step partitions; one process is started per entry.
            batch_queue: queue shared with the producer processes.
            neighbors1, neighbors2: forwarded to the batch generator, together with
                ``self.args.neg_triple_num``.
        """
        start = time.time()
        producers = []
        for steps_task in steps_tasks:
            producer = mp.Process(target=bat.generate_relation_triple_batch_queue,
                                  args=(self.kgs.kg1.relation_triples_list, self.kgs.kg2.relation_triples_list,
                                        self.kgs.kg1.relation_triples_set, self.kgs.kg2.relation_triples_set,
                                        self.kgs.kg1.entities_list, self.kgs.kg2.entities_list,
                                        self.args.batch_size, steps_task,
                                        batch_queue, neighbors1, neighbors2, self.args.neg_triple_num))
            producer.start()
            producers.append(producer)
        epoch_loss = 0
        trained_samples_num = 0
        for _ in range(triple_steps):
            batch_pos, batch_neg = batch_queue.get()
            batch_loss, _ = self.session.run(fetches=[self.triple_loss, self.triple_optimizer],
                                             feed_dict={self.pos_hs: [x[0] for x in batch_pos],
                                                        self.pos_rs: [x[1] for x in batch_pos],
                                                        self.pos_ts: [x[2] for x in batch_pos],
                                                        self.neg_hs: [x[0] for x in batch_neg],
                                                        self.neg_rs: [x[1] for x in batch_neg],
                                                        self.neg_ts: [x[2] for x in batch_neg]})
            trained_samples_num += len(batch_pos)
            epoch_loss += batch_loss
        # All expected batches have been consumed: reap the producers explicitly instead of
        # leaving finished children as zombies until a later Process.start() cleans them up.
        for producer in producers:
            producer.join()
        epoch_loss /= trained_samples_num
        # The producers were handed these lists at start-up; shuffling here only affects the next epoch.
        random.shuffle(self.kgs.kg1.relation_triples_list)
        random.shuffle(self.kgs.kg2.relation_triples_list)
        print('epoch {}, avg. triple loss: {:.4f}, cost time: {:.1f}s'.format(epoch, epoch_loss, time.time() - start))

    def launch_distance_margin_training_1epo(self, epoch, triple_steps):
        """Run one epoch of distance-margin training for dangling-entity detection.

        Each step samples ``batch_size`` training dangling (unlinked) source entities and
        pairs them with their nearest KG2 candidate (``search_kg1_to_kg2_1nn_neighbor`` after
        applying ``mapping_mat``). Those pairs, together with randomly sampled "other" KG1
        entities (neither training-seed nor training-dangling) and random KG2 candidates,
        are fed to ``self.dis_loss`` / ``self.dis_optimizer``.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of triple-training steps this epoch; the dangling entities
                are spread over this many steps to choose the batch size.
        """
        start = time.time()
        epoch_loss = 0
        trained_samples_num = 0

        # Entries are (entity_id, label) pairs -- only x[0] is used as the id below.
        unlinked_entities1 = self.kgs.train_unlinked_entities1
        if not unlinked_entities1:
            print('epoch {}, margin training skipped: no training dangling entities'.format(epoch))
            return
        # Clamp to >= 1: with fewer dangling entities than steps the integer division gives 0,
        # which made the steps_num computation below divide by zero.
        batch_size = max(1, len(unlinked_entities1) // triple_steps)
        embeds = self.ent_embeds.eval(session=self.session)
        mapping_mat = self.mapping_mat.eval(session=self.session)
        steps_num = max(1, len(unlinked_entities1) // batch_size)

        # Candidate targets exclude the KG2 entities used in the training seed alignment.
        ents2_candidates = list(self.kgs.kg2.entities_set - set(self.kgs.train_entities2))
        embeds2 = embeds[ents2_candidates,]

        # Subtract the dangling entity *ids*: subtracting the (id, label) tuples themselves
        # from a set of ids removed nothing, letting dangling entities into the
        # "other" samples that are fed as seed_pos_ents1.
        unlinked_ids1 = {x[0] for x in unlinked_entities1}
        other_entities1 = list(self.kgs.kg1.entities_set - set(self.kgs.train_entities1) - unlinked_ids1)

        for _ in range(steps_num):
            # Each step draws an independent random batch, so entities may repeat across steps.
            batch_data1 = random.sample(unlinked_entities1, batch_size)
            unlinked_ent1 = [x[0] for x in batch_data1]
            unlinked_embeds1 = embeds[np.array(unlinked_ent1),]
            unlinked_ent12 = search_kg1_to_kg2_1nn_neighbor(unlinked_embeds1, embeds2, ents2_candidates, mapping_mat,
                                                            soft_nn=self.args.soft_nn)
            other_ents1 = random.sample(other_entities1, batch_size)
            other_ents12 = random.sample(ents2_candidates, batch_size)

            # NOTE(review): an earlier comment claimed dis_loss does not use negative_ents2 -- confirm in the graph.
            batch_loss, _ = self.session.run(fetches=[self.dis_loss, self.dis_optimizer],
                                             feed_dict={self.input_ents1: unlinked_ent1,
                                                        self.input_ents2: unlinked_ent12,
                                                        self.seed_pos_ents1: other_ents1,
                                                        self.negative_ents2: other_ents12})
            epoch_loss += batch_loss
            trained_samples_num += len(batch_data1)
        epoch_loss /= trained_samples_num
        print('epoch {}, margin loss: {:.4f}, cost time: {:.1f}s'.format(epoch, epoch_loss,
                                                                         time.time() - start))

    def launch_open_margin_training_1epo(self, epoch, triple_steps):
        """Run one epoch of "open" margin training for dangling-entity detection.

        Each step samples a batch of training dangling (unlinked) source entities and, for
        each of them, ``num_random`` random KG2 entities (an array of shape
        (batch_size, num_random)); both are fed to ``self.open_loss`` / ``self.open_optimizer``.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of triple-training steps this epoch; used to size batches.
        """
        start = time.time()
        epoch_loss = 0
        trained_samples_num = 0
        # Entries are (entity_id, label) pairs -- only x[0] is used below.
        training_data = self.kgs.train_unlinked_entities1
        if not training_data:
            print('epoch {}, open margin training skipped: no training dangling entities'.format(epoch))
            return
        batch_size = len(training_data) // triple_steps
        # Shrink by 1.1, then clamp to >= 1: small per-step counts used to round down to 0,
        # which made the steps_num computation below divide by zero.
        batch_size = max(1, int(min(batch_size, len(training_data)) / 1.1))
        steps_num = max(1, len(training_data) // batch_size)
        for _ in range(steps_num):
            batch_data1 = random.sample(training_data, batch_size)
            ent1 = [x[0] for x in batch_data1]
            ents = np.array([random.sample(self.kgs.kg2.entities_list, self.args.num_random) for _ in ent1])
            batch_loss, _ = self.session.run(fetches=[self.open_loss, self.open_optimizer],
                                             feed_dict={self.input_ents1: ent1,
                                                        self.input_ents2: ents})
            epoch_loss += batch_loss
            trained_samples_num += len(batch_data1)
        epoch_loss /= trained_samples_num
        print('epoch {}, open margin loss: {:.4f}, cost time: {:.1f}s'.format(epoch, epoch_loss,
                                                                              time.time() - start))

    def launch_dangling_detection_binary_cls_training_1epo(self, epoch, triple_steps):
        """Run one epoch of training for the binary dangling-entity classifier.

        Labels: matchable (linked) source entities -> 0, dangling (unlinked) -> 1.
        For each batch, the KG2 neighbours of the source entities are retrieved with
        ``search_kg1_to_kg2_ordered_all_nns`` (``s2t_topk``); then similarities from those
        neighbours back to the batch's source entities are computed with
        ``search_kg2_to_kg1_ordered_all_nns`` (``t2s_topk``). Those reverse similarities
        (plus the forward ones when ``concat_forward_reverse`` is set) are fed to the
        classifier, and ``self.bin_cls_optimizer`` is applied.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of triple-training steps this epoch; used to size batches.
        """
        start = time.time()
        epoch_loss = 0
        trained_samples_num = 0

        # Entries are (entity_id, label) pairs; only x[0] is used as the id.
        unlinked_entities1 = self.kgs.train_unlinked_entities1
        linked_entities1 = self.kgs.train_linked_entities1
        all_entities = unlinked_entities1 + linked_entities1
        if not all_entities:
            print('epoch {}, dangling detection binary cls training skipped: no training entities'.format(epoch))
            return
        unlinked_labels = np.ones(len(unlinked_entities1))
        linked_labels = np.zeros(len(linked_entities1))
        labels = np.hstack((unlinked_labels, linked_labels))
        print(f'# linked_ent1: {len(linked_entities1)}, unlinked_ent1: {len(unlinked_entities1)}')

        # Clamp to >= 1: with fewer entities than steps the integer division gives 0,
        # which made the steps_num computation below divide by zero.
        batch_size = max(1, len(all_entities) // triple_steps)
        embeds = self.ent_embeds.eval(session=self.session)
        mapping_mat = self.mapping_mat.eval(session=self.session)
        steps_num = max(1, len(all_entities) // batch_size)
        ents2_candidates = list(self.kgs.kg2.entities_set - set(self.kgs.train_entities2))
        embeds2 = embeds[ents2_candidates, ]

        # Shuffle entities and labels with the same permutation.
        perm = np.random.permutation(len(all_entities))
        all_entities = np.array(all_entities, dtype=np.int32)[perm]
        labels = labels[perm]
        epoch_correct = 0
        # Walk consecutive slices of the shuffled (dangling + matchable) entities; the last
        # len(all_entities) % batch_size entities are not visited this epoch.
        for i in range(steps_num):
            batch_ent1 = all_entities[i * batch_size:(i + 1) * batch_size]
            batch_labels = labels[i * batch_size:(i + 1) * batch_size]
            ent1 = [x[0] for x in batch_ent1]
            ent_embeds1 = embeds[np.array(ent1), ]
            ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(ent_embeds1, embeds2, ents2_candidates, mapping_mat,
                                                               return_all_sim=True, soft_nn=self.args.s2t_topk)
            retrieved_target_embeds = embeds[np.array(ent12, dtype=np.int32), ]  # shape (batch_size, s2t_topk, d_emb)
            reversed_ent21, t2s_sim = search_kg2_to_kg1_ordered_all_nns(retrieved_target_embeds, ent_embeds1, ent1,
                                                                        mapping_mat, return_all_sim=True,
                                                                        soft_nn=self.args.t2s_topk)
            feed_dict = {self.topk_scores: np.array(t2s_sim),
                         self.bin_cls_labels: batch_labels,
                         self.is_training: True}
            if self.args.concat_forward_reverse:
                feed_dict[self.s2t_scores] = np.array(s2t_sim)
            batch_loss, dangling_bin_cls_output, _ = self.session.run(
                fetches=[self.bin_cls_loss, self.dangling_bin_cls_output, self.bin_cls_optimizer],
                feed_dict=feed_dict)
            # NOTE(review): assumes the classifier output has the same shape as batch_labels -- confirm.
            predictions = (dangling_bin_cls_output > 0.5).astype(np.int32)
            epoch_correct += (predictions == batch_labels).astype(np.int32).sum()
            epoch_loss += batch_loss
            trained_samples_num += len(batch_ent1)
        epoch_loss /= trained_samples_num
        gc.collect()
        # Accuracy over the samples actually visited; dividing by len(labels) also counted the
        # never-visited tail and under-reported accuracy.
        accuracy = epoch_correct / trained_samples_num
        print('epoch {}, dangling detection binary cls loss: {:.4f}, detection acc: {:.4f}, cost time: {:.1f}s'.format(
            epoch, epoch_loss, accuracy, time.time() - start))


    def launch_adver_training_1epo(self, epoch, triple_steps):
        """Run one epoch of adversarial (discriminator / generator) training.

        Per step: ``n_critic`` discriminator updates (each preceded by running
        ``self.disc_clip``) on random KG1/KG2 entity batches, then one generator update on a
        random KG1 batch, and -- when ``max_dangling_to_target`` is set -- one
        ``self.max_dangling_opt`` update on a batch of training dangling entities. With
        ``mask_dangling_adver`` set, the KG1 batches exclude the training dangling entities.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of outer steps; also used to size batches.
        """
        # TODO
        start = time.time()
        disc_epoch_loss = 0
        gen_epoch_loss = 0
        max_dangling_epoch_loss = 0
        trained_samples_num_gen = 0
        trained_samples_num_disc = 0
        trained_samples_num_dangling = 0
        # Clamp to >= 1 so a tiny seed set cannot give empty batches and a 0/0 average.
        batch_size = max(1, 2 * len(self.kgs.train_links) // triple_steps)
        n_critic = 5
        # train_unlinked_entities1 holds (id, label) pairs; keep only the ids.
        train_unlinked_entities1 = [x[0] for x in self.kgs.train_unlinked_entities1]
        if self.args.mask_dangling_adver:
            # Loop-invariant, so build it once rather than on every sample.
            kg1_candidates = list(self.kgs.kg1.entities_set - set(train_unlinked_entities1))
        else:
            # NOTE: this pool also contains the training dangling entities.
            kg1_candidates = self.kgs.kg1.entities_list

        for _step in range(triple_steps):
            for _critic_iter in range(n_critic):
                # clip the discriminator parameters before each critic update
                self.session.run(fetches=[self.disc_clip])
                entities1 = random.sample(kg1_candidates, batch_size)
                entities2 = random.sample(self.kgs.kg2.entities_list, batch_size)
                disc_batch_loss, _ = self.session.run(fetches=[self.disc_loss, self.disc_opt],
                                                      feed_dict={self.disc_input_ent1: entities1,
                                                                 self.disc_input_ent2: entities2})
                disc_epoch_loss += disc_batch_loss
                trained_samples_num_disc += batch_size

            entities1 = random.sample(kg1_candidates, batch_size)
            gen_batch_loss, _ = self.session.run(fetches=[self.gen_loss, self.gen_opt],
                                                 feed_dict={self.disc_input_ent1: entities1})
            gen_epoch_loss += gen_batch_loss
            trained_samples_num_gen += batch_size

            if self.args.max_dangling_to_target:
                dangling_ent1 = random.sample(train_unlinked_entities1, batch_size)
                max_dangling_batch_loss, _ = self.session.run(fetches=[self.max_dangling_loss, self.max_dangling_opt],
                                                              feed_dict={self.input_dangling_ent1: dangling_ent1})
                max_dangling_epoch_loss += max_dangling_batch_loss
                trained_samples_num_dangling += batch_size
        disc_epoch_loss /= trained_samples_num_disc
        gen_epoch_loss /= trained_samples_num_gen

        # ".1f": a bare ".1" means one *significant* digit (e.g. 12.3 -> "1e+01").
        if self.args.max_dangling_to_target:
            max_dangling_epoch_loss /= trained_samples_num_dangling
            print(f'epoch {epoch}, avg. discriminator loss: {disc_epoch_loss}, generator loss: {gen_epoch_loss}, '
                  f'max_dangling loss: {max_dangling_epoch_loss}, cost time: {time.time() - start:.1f}s')
        else:
            print(f'epoch {epoch}, avg. discriminator loss: {disc_epoch_loss}, generator loss: {gen_epoch_loss}, '
                  f'cost time: {time.time() - start:.1f}s')

    def launch_align_training_1epo(self, epoch, triple_steps):
        """Run one epoch of mapping (alignment) training on the seed links.

        Each of ``triple_steps`` steps samples ``batch_size`` seed links, plus
        ``batch_size * mapping_neg_num`` random KG1 entities and as many KG2 entities that
        are not training targets, and applies ``self.mapping_optimizer``.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of steps; also used to size batches.
        """
        start = time.time()
        epoch_loss = 0
        trained_samples_num = 0
        # Clamp to >= 1 so a tiny seed set cannot give empty batches and a 0/0 average.
        batch_size = max(1, 2 * len(self.kgs.train_links) // triple_steps)
        kg_training_ents2 = list(self.kgs.kg2.entities_set - set(self.kgs.train_entities2))
        neg_batch_size = batch_size * self.args.mapping_neg_num

        for _ in range(triple_steps):
            links_batch = random.sample(self.kgs.train_links, batch_size)
            pos_entities1 = random.sample(self.kgs.kg1.entities_list, neg_batch_size)
            neg_entities2 = random.sample(kg_training_ents2, neg_batch_size)
            batch_loss, _ = self.session.run(fetches=[self.mapping_loss, self.mapping_optimizer],
                                             feed_dict={self.seed_entities1: [x[0] for x in links_batch],
                                                        self.seed_entities2: [x[1] for x in links_batch],
                                                        self.pos_entities1: pos_entities1,
                                                        self.neg_entities2: neg_entities2})
            epoch_loss += batch_loss
            trained_samples_num += len(links_batch)
        epoch_loss /= trained_samples_num
        # ".1f": a bare ".1" means one *significant* digit (e.g. 12.3 -> "1e+01").
        print('epoch {}, avg. mapping loss: {:.4f}, cost time: {:.1f}s'.format(epoch, epoch_loss, time.time() - start))

    def launch_NCA_training_1epo(self, epoch, triple_steps):
        """Run one epoch of NCA-loss training on the seed links.

        Each of ``triple_steps`` steps samples ``batch_size`` seed links and applies
        ``self.NCA_optimizer`` with ``NCA_alpha`` / ``NCA_beta`` from the args and ``ep = 0``.
        Only the loss is fetched; the intermediate debug tensors are no longer pulled to the
        host every step.

        Args:
            epoch: current epoch number (logging only).
            triple_steps: number of steps; also used to size batches.
        """
        start = time.time()
        epoch_loss = 0
        trained_samples_num = 0
        # Clamp to >= 1 so a tiny seed set cannot give empty batches and a 0/0 average.
        batch_size = max(1, 2 * len(self.kgs.train_links) // triple_steps)

        for _ in range(triple_steps):
            links_batch = random.sample(self.kgs.train_links, batch_size)
            ent1 = [x[0] for x in links_batch]
            ent2 = [x[1] for x in links_batch]
            batch_loss, _ = self.session.run(fetches=[self.NCA_loss, self.NCA_optimizer],
                                             feed_dict={self.NCA_input_ents1: ent1,
                                                        self.NCA_input_ents2: ent2,
                                                        self.alpha: self.args.NCA_alpha,
                                                        self.beta: self.args.NCA_beta,
                                                        self.ep: 0.0})
            epoch_loss += batch_loss
            trained_samples_num += len(links_batch)
        epoch_loss /= trained_samples_num
        # ".1f": a bare ".1" means one *significant* digit (e.g. 12.3 -> "1e+01").
        print('epoch {}, avg. NCA loss: {:.8f}, cost time: {:.1f}s'.format(epoch, epoch_loss, time.time() - start))


    def generate_neighbors(self):
        """Build the truncated negative-sampling neighbour tables for both KGs.

        For each KG, a working entity subset is chosen: half of the KG's entities (a third
        for KGs with more than 200000 entities), always containing every "useful" entity
        and topped up with random others when there is room. Neighbour lists of size
        ``(1 - truncated_epsilon) * subset_size`` are then computed with the faiss-based
        search for "fr_en" training data and the default search otherwise.

        Returns:
            (neighbors1, neighbors2) as produced by the neighbour-search helper.
        """
        t1 = time.time()
        assert 0.0 < self.args.truncated_epsilon < 1.0

        def pick_entities(all_ents, useful_ents):
            # Returns (target subset size, chosen entities).
            target = len(all_ents) // 3 if len(all_ents) > 200000 else len(all_ents) // 2
            if target <= len(useful_ents):
                return target, useful_ents
            others = list(set(all_ents) - set(useful_ents))
            return target, useful_ents + random.sample(others, target - len(useful_ents))

        num1, kg1_random_ents = pick_entities(self.kgs.kg1.entities_list, self.kgs.useful_entities_list1)
        embeds1 = self.eval_embeddings(kg1_random_ents)
        num2, kg2_random_ents = pick_entities(self.kgs.kg2.entities_list, self.kgs.useful_entities_list2)
        embeds2 = self.eval_embeddings(kg2_random_ents)

        keep_ratio = 1 - self.args.truncated_epsilon
        neighbors_num1 = int(keep_ratio * num1)
        neighbors_num2 = int(keep_ratio * num2)
        print("generating neighbors...")
        search = generate_neighbours_faiss if "fr_en" in self.args.training_data else generate_neighbours
        neighbors1 = search(embeds1, kg1_random_ents, neighbors_num1, frags_num=self.args.batch_threads_num)
        neighbors2 = search(embeds2, kg2_random_ents, neighbors_num2, frags_num=self.args.batch_threads_num)
        print("generating neighbors ({}, {}) costs {:.3f} s.".format(num1, num2, time.time() - t1))
        gc.collect()
        return neighbors1, neighbors2

    def run(self):
        """Train for up to ``max_epoch`` epochs with periodic validation and early stopping.

        Validation runs every ``eval_freq`` epochs once ``start_valid`` is reached. After
        ``start_class`` epochs (for "margin"/"open" detection), the value returned by
        ``two_step_evaluation_margin`` on the validation split drives ``early_stop``.
        With ``neg_sampling == 'truncated'``, the neighbour tables used for negative
        sampling are regenerated every ``truncated_freq`` epochs.
        """
        total_triples = self.kgs.kg1.relation_triples_num + self.kgs.kg2.relation_triples_num
        triple_steps = int(math.ceil(total_triples / self.args.batch_size))
        steps_tasks = task_divide(list(range(triple_steps)), self.args.batch_threads_num)
        # Keep a reference to the manager: the queue proxy needs it alive while training.
        batch_manager = mp.Manager()
        batch_queue = batch_manager.Queue()
        neighbors1, neighbors2 = {}, {}

        for epoch in range(1, self.args.max_epoch + 1):
            self.launch_training_1epo(epoch, triple_steps, steps_tasks, batch_queue, neighbors1, neighbors2)

            if epoch >= self.args.start_valid and epoch % self.args.eval_freq == 0:
                # Synthetic-alignment validation; its return value is not used for stopping.
                self.valid_alignment(self.args.stop_metric)
                if self.args.detection_mode in ["margin", "open"] and epoch > self.args.start_class:
                    score = self.two_step_evaluation_margin(self.kgs.valid_linked_entities1,
                                                            self.kgs.valid_unlinked_entities1)
                    self.flag1, self.flag2, self.early_stop = early_stop(self.flag1, self.flag2, score)
                if self.early_stop or epoch == self.args.max_epoch:
                    break

            refresh_neighbors = self.args.neg_sampling == 'truncated' and epoch % self.args.truncated_freq == 0
            if refresh_neighbors and neighbors1 is not None:
                # Release the old tables before building new ones to keep peak memory down.
                del neighbors1, neighbors2
                gc.collect()
                neighbors1, neighbors2 = self.generate_neighbors()




    ###################### Our own modification ##################
    def load_embedding(self):
        """Replace ``ent_embeds`` and ``mapping_mat`` with constants loaded from a saved run.

        The run directory is chosen by ``self.args.test_method`` and must contain
        ``ent_embeds.npy`` and ``mapping_mat.npy``; both arrays are converted to float32
        tensors.

        Raises:
            AttributeError: if ``test_method`` has no registered run directory.
        """
        # TODO: the run directories are hard-coded.
        saved_runs = {
            'MTransE': 'output/MTransEV2/zh_en/splits20211009154753',
            'NCA': 'output/MTransEV2/zh_en/splits20211223113049',
            'NCA_MR': 'output/MTransEV2/zh_en/splits20211227180410',
        }
        if self.args.test_method not in saved_runs:
            raise AttributeError(f'cannot find self.args.test_method: {self.args.test_method}')
        run_dir = saved_runs[self.args.test_method]

        loaded_ent = np.load(os.path.join(run_dir, 'ent_embeds.npy'))
        print(f'ent_embeds.shape: {loaded_ent.shape}')
        loaded_mapping = np.load(os.path.join(run_dir, 'mapping_mat.npy'))
        print(f'mapping_mat.shape: {loaded_mapping.shape}')
        self.mapping_mat = tf.convert_to_tensor(loaded_mapping, np.float32)
        self.ent_embeds = tf.convert_to_tensor(loaded_ent, np.float32)


    def get_reverse(self, reverse_k=5, return_sim_scores=True):
        """Analyse dangling detection by reverse nearest-neighbour search on the test split.

        Test source entities (linked first, then unlinked) are mapped with ``mapping_mat``,
        L2-normalised, and matched to their single nearest candidate target. For every
        retrieved target, the ``reverse_k`` most similar source entities are found. A source
        entity that does not appear in the reverse top-k of its own nearest target is
        predicted dangling. Counts, ratios, and precision/recall/F1/accuracy of that
        prediction (dangling = 1) are printed; nothing is returned.

        Args:
            reverse_k: number of source entities kept per target in the reverse search.
            return_sim_scores: also collect the reverse similarity scores and print their
                mean per linked/unlinked group plus the first few entries.
        """
        # TODO
        source_embeds, _, _, target_embeds, _, _ = \
            self.get_source_and_candidates(self.kgs.test_linked_entities1 +
                                           self.kgs.test_unlinked_entities1, is_test=True)
        mapping_mat = self.mapping_mat.eval(session=self.session)
        source_embeds = preprocessing.normalize(np.matmul(source_embeds, mapping_mat))
        sim_mat = np.matmul(source_embeds, target_embeds.T)
        print(f'sim_mat.shape: {sim_mat.shape}')
        nearest_pairs = find_alignment(sim_mat, k=1)
        # Column indices into target_embeds (not original entity ids).
        nns = [x[0][1] for x in nearest_pairs]
        print(f'Got nns from source to target, len(nns): {len(nns)}')
        reversed_target_embeds = target_embeds[nns, ]
        # (#nns, #source); #nns == #source. Row i is assumed to be the nearest target of
        # source i (the indexing by source_idx below relies on find_alignment keeping row order).
        reversed_sim_mat = np.matmul(source_embeds, reversed_target_embeds.T).T
        print(f'reversed_sim_mat.shape: {reversed_sim_mat.shape}')
        reversed_nearest_pairs = find_alignment(reversed_sim_mat, k=reverse_k)
        # topk_list[i]: indices of the source entities most similar to the nearest target of source i.
        topk_list = []
        sim_scores_list = []
        for pairs in reversed_nearest_pairs:
            topk = [pair[1] for pair in pairs]
            if return_sim_scores:
                cur_target = pairs[0][0]
                sim_scores_list.append(reversed_sim_mat[cur_target, topk])
            topk_list.append(topk)

        num_test_linked_ent1 = len(self.kgs.test_linked_entities1)
        num_test_unlinked_ent1 = len(self.kgs.test_unlinked_entities1)

        # Only meaningful when scores were collected; the per-source lookup of these scores
        # used to raise IndexError when return_sim_scores was False.
        if return_sim_scores:
            sim_scores_linked_ent_list = sim_scores_list[:num_test_linked_ent1]
            sim_scores_unlinked_ent_list = sim_scores_list[num_test_linked_ent1:]
            print(f'mean sim scores of linked ent: {np.array(sim_scores_linked_ent_list).mean()}')
            print(f'mean sim scores of unlinked ent: {np.array(sim_scores_unlinked_ent_list).mean()}')
            print(f'sim_scores_linked_list: {sim_scores_linked_ent_list[:5]}')
            print(f'sim_scores_unlinked_list: {sim_scores_unlinked_ent_list[:5]}')

        not_include_linked_ent1_list = []
        not_include_unlinked_ent1_list = []
        for source_idx, topk in enumerate(topk_list):
            if source_idx in topk:
                continue
            # Sources are ordered linked first, then unlinked.
            if source_idx < num_test_linked_ent1:
                not_include_linked_ent1_list.append(source_idx)
            else:
                not_include_unlinked_ent1_list.append(source_idx)
        num_not_include_linked_ent1 = len(not_include_linked_ent1_list)
        num_not_include_unlinked_ent1 = len(not_include_unlinked_ent1_list)

        linked_ratio = num_not_include_linked_ent1 / num_test_linked_ent1 if num_test_linked_ent1 else 0.0
        unlinked_ratio = num_not_include_unlinked_ent1 / num_test_unlinked_ent1 if num_test_unlinked_ent1 else 0.0
        print(f'# total_linked_ent: {num_test_linked_ent1}, # total_unlinked_ent: {num_test_unlinked_ent1}')
        print(f'# not_include_linked_ent1: {num_not_include_linked_ent1}, # not_include_unlinked_ent1: {num_not_include_unlinked_ent1}')
        print(f'ratio of not_include linked: {linked_ratio}, ratio of not include unlinked: {unlinked_ratio}')

        # Treat not-included entities as predicted dangling: 1 = dangling, 0 = matchable.
        num_one_labels = num_test_unlinked_ent1
        num_zero_labels = num_test_linked_ent1
        num_total_examples = num_one_labels + num_zero_labels
        num_one_one = num_not_include_unlinked_ent1  # (ground truth, prediction)
        num_zero_one = num_not_include_linked_ent1
        num_zero_zero = num_zero_labels - num_not_include_linked_ent1

        if num_one_one > 0:
            precision = num_one_one / (num_one_one + num_zero_one)
            recall = num_one_one / num_one_labels
            print('precision = {:.3f}, '.format(precision))
            print('recall = {:.3f}, '.format(recall))
            print('f1 = {:.3f}, '.format(2 * precision * recall / (precision + recall)))
        if num_one_one + num_zero_zero > 0:
            print('accuracy = {:.3f}, '.format((num_one_one + num_zero_zero) / num_total_examples))


    def get_high_order_neigh_sim_dist(self, concat_forward=True, return_sim_scores=True):
        """Dump source->target->source similarity distributions for the test entities.

        The test source entities (linked entities first, then unlinked ones) are matched to
        their ``self.args.s2t_topk`` nearest target candidates. The embeddings of those
        candidates are then searched back against the source side with
        ``self.args.t2s_topk`` neighbours. The resulting similarity scores are split into
        linked / unlinked parts, summarized on stdout and saved as ``.npy`` files.

        Args:
            concat_forward: if True, the forward (source->target) similarities are
                horizontally stacked after the reverse (target->source) ones.
            return_sim_scores: currently unused; kept for interface compatibility.

        Side effects:
            Creates ``concat_{concat_forward}_s{s2t_topk}_t{t2s_topk}_sim_dist`` (relative to
            the current working directory) if needed and writes ``sim_dist_linked.npy`` and
            ``sim_dist_unlinked.npy`` into it.
        """
        source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat = \
            self.get_source_and_candidates(self.kgs.test_linked_entities1 +
                                           self.kgs.test_unlinked_entities1, is_test=True)
        ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(source_embeds, target_embeds, target_candidates, mapping_mat,
                                                           return_all_sim=True,
                                                           soft_nn=self.args.s2t_topk)
        # The returned ent12 ids index rows of the whole entity embedding table.
        total_ent_embeds = self.ent_embeds.eval(session=self.session)
        # Presumably (num_source, s2t_topk, d_emb) -- TODO confirm against the search helper.
        retrieved_target_embeds = total_ent_embeds[np.array(ent12, dtype=np.int32),]
        reversed_ent21, t2s_sim = search_kg2_to_kg1_ordered_all_nns(retrieved_target_embeds, source_embeds, source_ents,
                                                                    mapping_mat, return_all_sim=True,
                                                                    soft_nn=self.args.t2s_topk)
        t2s_sim = np.array(t2s_sim)
        s2t_sim = np.array(s2t_sim)
        # np.hstack requires both arrays to agree on every axis except the second.
        total_sim = np.hstack((t2s_sim, s2t_sim)) if concat_forward else t2s_sim

        # The slicing assumes rows follow the order of the entity list passed above
        # (linked entities first, then unlinked ones).
        num_test_linked_ent1 = len(self.kgs.test_linked_entities1)
        sim_dist_linked_ent = total_sim[:num_test_linked_ent1]
        sim_dist_unlinked_ent = total_sim[num_test_linked_ent1:]
        print(f'mean sim scores of linked ent: {sim_dist_linked_ent.mean()}')
        print(f'mean sim scores of unlinked ent: {sim_dist_unlinked_ent.mean()}')
        print(f'sim_scores_linked_list: {sim_dist_linked_ent[:5]}')
        print(f'sim_scores_unlinked_list: {sim_dist_unlinked_ent[:5]}')

        folder = f'concat_{concat_forward}_s{self.args.s2t_topk}_t{self.args.t2s_topk}_sim_dist'
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and handles nested paths.
        os.makedirs(folder, exist_ok=True)
        np.save(os.path.join(folder, 'sim_dist_linked.npy'), sim_dist_linked_ent)
        np.save(os.path.join(folder, 'sim_dist_unlinked.npy'), sim_dist_unlinked_ent)


    def get_first_order_neigh_sim_dist(self, return_sim_scores=True):
        """Dump source->target top-k similarity distributions for the test entities.

        The test source entities (linked entities first, then unlinked ones) are matched to
        their ``self.args.s2t_topk`` nearest target candidates. The similarity scores are
        split into linked / unlinked parts, summarized on stdout and saved as ``.npy`` files.

        Args:
            return_sim_scores: currently unused; kept for interface compatibility.

        Side effects:
            Creates ``first_order_s{s2t_topk}_sim_dist`` (relative to the current working
            directory) if needed and writes ``sim_dist_linked.npy`` and
            ``sim_dist_unlinked.npy`` into it.
        """
        source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat = \
            self.get_source_and_candidates(self.kgs.test_linked_entities1 +
                                           self.kgs.test_unlinked_entities1, is_test=True)
        # Only the similarity scores are needed here; the neighbour ids are discarded.
        _, s2t_sim = search_kg1_to_kg2_ordered_all_nns(source_embeds, target_embeds, target_candidates, mapping_mat,
                                                       return_all_sim=True,
                                                       soft_nn=self.args.s2t_topk)
        total_sim = np.array(s2t_sim)

        # The slicing assumes rows follow the order of the entity list passed above
        # (linked entities first, then unlinked ones).
        num_test_linked_ent1 = len(self.kgs.test_linked_entities1)
        sim_dist_linked_ent = total_sim[:num_test_linked_ent1]
        sim_dist_unlinked_ent = total_sim[num_test_linked_ent1:]
        print(f'mean sim scores of linked ent: {sim_dist_linked_ent.mean()}')
        print(f'mean sim scores of unlinked ent: {sim_dist_unlinked_ent.mean()}')
        print(f'sim_scores_linked_list: {sim_dist_linked_ent[:5]}')
        print(f'sim_scores_unlinked_list: {sim_dist_unlinked_ent[:5]}')

        folder = f'first_order_s{self.args.s2t_topk}_sim_dist'
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and handles nested paths.
        os.makedirs(folder, exist_ok=True)
        np.save(os.path.join(folder, 'sim_dist_linked.npy'), sim_dist_linked_ent)
        np.save(os.path.join(folder, 'sim_dist_unlinked.npy'), sim_dist_unlinked_ent)

    def get_counts(self, input_array):
        """Return a dict mapping each element of ``input_array`` to its occurrence count.

        The array is flattened first, so every axis contributes. Keys appear in the order
        in which each element is first encountered.
        """
        frequency = {}
        for element in input_array.flatten():
            if element in frequency:
                frequency[element] += 1
            else:
                frequency[element] = 1
        return frequency

    def top_k_tot_values(self, input_dict, k):
        """Return the sum of the ``k`` largest values stored in ``input_dict``.

        Keys are ignored; if ``k`` exceeds the number of entries, all values are summed.
        """
        largest_first = sorted(input_dict.values(), reverse=True)
        return sum(largest_first[:k])

    def get_appear_times(self):
        """Report how often target entities recur among the sources' nearest neighbours.

        For every test source entity (linked + unlinked), the 10 nearest target candidates
        are retrieved. The method then counts how many times each target entity appears in
        the top-1 / top-5 / top-10 lists and prints summary statistics: proportions relative
        to the number of target candidates, and the summed counts of the most frequent
        repeated top-1 targets. Nothing is returned.
        """
        source_embeds, source_ents, source_ent_y, target_embeds, target_candidates, mapping_mat = \
            self.get_source_and_candidates(self.kgs.test_linked_entities1 +
                                           self.kgs.test_unlinked_entities1, is_test=True)
        ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(source_embeds, target_embeds, target_candidates, mapping_mat,
                                                           return_all_sim=True,
                                                           soft_nn=10)
        ent12 = np.array(ent12)
        s2t_sim = np.array(s2t_sim)
        print(f'ent12[:5, :]: {ent12[:5, :]}\ns2t_sim[:5, :]: {s2t_sim[:5,:]}')
        # Presumably ent12 is (#source, 10), ordered from nearest to farthest -- TODO confirm.
        top1_counts = self.get_counts(ent12[:, :1])
        top5_counts = self.get_counts(ent12[:, :5])
        top10_counts = self.get_counts(ent12)

        top1_more_than_1 = {k: v for k, v in top1_counts.items() if v > 1}
        top1_exactly_1 = {k: v for k, v in top1_counts.items() if v == 1}
        top5_more_than_1 = {k: v for k, v in top5_counts.items() if v > 1}
        top10_more_than_1 = {k: v for k, v in top10_counts.items() if v > 1}
        tot_num = target_embeds.shape[0]

        # The 100 most frequent repeated top-1 targets, kept as a dict for printing.
        top1_more_than_1_truncated = dict(
            sorted(top1_more_than_1.items(), key=lambda item: item[1], reverse=True)[:100])
        # Summed counts of the k most frequent repeated top-1 targets, for several k.
        tot_values = {k: self.top_k_tot_values(top1_more_than_1, k) for k in (3, 5, 10, 50, 100, 500)}

        print(f'top1_more_than_1_truncated: {top1_more_than_1_truncated}')
        print(f'tot_val_3: {tot_values[3]}, tot_val_5: {tot_values[5]}, tot_val_10: {tot_values[10]}, tot_val_50: {tot_values[50]}, tot_val_100: {tot_values[100]}, tot_val_500: {tot_values[500]}')
        print(f'tot_sum_top1: {sum(top1_counts.values())}')
        print(f'proportion of appearing more than 1 in top1: {len(top1_more_than_1) / tot_num}, number: {len(top1_more_than_1)}')
        print(f'proportion of appearing exactly once in top1: {len(top1_exactly_1) / tot_num}, number: {len(top1_exactly_1)}')
        print(f'proportion of appearing more than 1 in top5: {len(top5_more_than_1) / tot_num}')
        print(f'proportion of appearing more than 1 in top10: {len(top10_more_than_1) / tot_num}')
        print(f'number of test_linked_ent: {len(self.kgs.test_linked_entities1)}, number of test_unlinked_ent: {len(self.kgs.test_unlinked_entities1)}')




    def save_model(self):
        """Checkpoint the variables of ``self.session`` with a ``tf.train.Saver``.

        The checkpoint prefix is ``self.out_folder`` joined by plain string concatenation
        with ``s{s2t_topk}_t{t2s_topk}_model.ckpt``. ``self.out_folder`` is therefore
        presumably expected to end with a path separator (TODO confirm). The path returned
        by the saver is printed.
        """
        print('saving the model...')
        checkpoint_saver = tf.train.Saver()
        ckpt_prefix = self.out_folder + f's{self.args.s2t_topk}_t{self.args.t2s_topk}_model.ckpt'
        written_path = checkpoint_saver.save(self.session, ckpt_prefix)
        print(f'Model saved in {written_path}')

    def load_model(self):
        """Restore the variables of ``self.session`` from the checkpoint at ``self.args.model_dir``."""
        print('loading the model...')
        tf.train.Saver().restore(self.session, self.args.model_dir)

    def test_full_scope(self):
        """Run the synthetic-alignment test and then the two-step margin evaluations.

        First calls ``test`` on the embeddings from ``self._eval_test_embeddings()``. When
        ``self.args.detection_mode == 'margin'``, it then runs ``two_step_evaluation_margin``
        on the linked + unlinked test entities. That call is made once with
        ``threshold=-1`` and once for each threshold in ``np.arange(0.5, 1.00, 0.1)``.
        """
        print("\ntesting synthetic alignment...")
        # Both detection modes previously took identical branches here.
        embeds1, embeds2, mapping = self._eval_test_embeddings()
        # The return values of ``test`` are not used here.
        test(embeds1, embeds2, mapping, self.args.top_k, self.args.test_threads_num,
             metric=self.args.eval_metric, normalize=self.args.eval_norm, csls_k=0, accurate=True)
        print()

        print('==== use the mean prob as threshold ====')
        # -1 is the threshold used for the "mean prob" run announced above; how it is
        # interpreted is up to two_step_evaluation_margin (not visible here).
        mean_threshold = -1
        if self.args.detection_mode == 'margin':
            self.two_step_evaluation_margin(self.kgs.test_linked_entities1,
                                            self.kgs.test_unlinked_entities1, threshold=mean_threshold, is_test=True)
        print()

        # NOTE: float arange yields values such as 0.7999999999999999 rather than exact tenths.
        cands = list(np.arange(0.5, 1.00, 0.1))
        for threshold in cands:
            print(f'======= threshold: {threshold} =======')
            if self.args.detection_mode == "margin":
                self.two_step_evaluation_margin(self.kgs.test_linked_entities1,
                                                self.kgs.test_unlinked_entities1, threshold=threshold, is_test=True)
