import gc
import time
import os
import pickle
import random
import math

import numpy as np
import pandas as pd

import ray
import scipy
import tensorflow as tf
import scipy.sparse as sp
from sklearn import preprocessing

from scipy.sparse.linalg.eigen.arpack import eigsh
from openea.modules.utils.util import generate_out_folder

from openea.modules.bootstrapping.alignment_finder import find_alignment
from openea.modules.finding.alignment import greedy_alignment
from openea.modules.utils.util import task_divide, merge_dic
from openea.modules.finding.similarity import sim
from openea.modules.finding.evaluation import early_stop
import openea.modules.load.read as rd
from openea.models.basic_model import BasicModel
from openea.modules.base.optimizers import generate_optimizer
from openea.modules.base.initializers import init_embeddings, orthogonal_init
from mtranse import search_kg1_to_kg2_1nn_neighbor, search_kg1_to_kg2_ordered_all_nns, search_kg2_to_kg1_ordered_all_nns
from utils import compute_my_class_weight, viz_sim_list
import ipdb
from eval import valid, test, greedy_alignment, eval_margin, eval_dangling_binary_cls, get_results


# from nn_search import generate_neighbours


# ***************************adj & sparse**************************
def sparse_to_tuple(sparse_mx):
    """Convert a scipy sparse matrix (or a list of them) to ``(coords, values, shape)`` form.

    ``coords`` is an ``(nnz, 2)`` array of ``[row, col]`` pairs, ``values`` the matching
    non-zero data and ``shape`` the matrix shape.

    :param sparse_mx: a scipy sparse matrix, or a list of them.
    :return: one tuple, or a NEW list of tuples for list input (the caller's list is left
        untouched; it used to be overwritten in place).
    """
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        return coords, mx.data, mx.shape

    if isinstance(sparse_mx, list):
        return [to_tuple(mx) for mx in sparse_mx]
    return to_tuple(sparse_mx)


def normalize_adj(adj):
    """Return ``D^{-1/2} A^T D^{-1/2}`` as a COO matrix, with ``D`` built from the row sums of ``adj``.

    For a symmetric adjacency this is the usual symmetric GCN normalisation.
    Rows whose sum is zero get a scaling factor of 0 instead of ``inf``.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    scale = np.power(degrees, -0.5).flatten()
    scale[np.isinf(scale)] = 0.
    scaling = sp.diags(scale)
    right_scaled = coo.dot(scaling)
    return right_scaled.transpose().dot(scaling).tocoo()


def preprocess_adj(adj):
    """Add self-loops to ``adj``, normalise it with ``normalize_adj`` and return tuple form.

    This is the preprocessing used for a simple GCN.
    """
    with_self_loops = adj + sp.eye(adj.shape[0])
    normalized = normalize_adj(with_self_loops)
    return sparse_to_tuple(normalized)


def chebyshev_polynomials(adj, k):
    """Compute Chebyshev polynomials T_0..T_k of the rescaled graph Laplacian.

    ``L = I - normalize_adj(adj)`` is rescaled as ``2 L / lambda_max - I``. Then
    ``T_0 = I``, ``T_1 = L~`` and ``T_n = 2 L~ T_{n-1} - T_{n-2}``.
    At least T_0 and T_1 are always returned, even for ``k < 1``.

    :return: list of tuple-form sparse matrices (see ``sparse_to_tuple``).
    """
    print("Calculating Chebyshev polynomials up to order {}...".format(k))
    n = adj.shape[0]
    laplacian = sp.eye(n) - normalize_adj(adj)
    largest_eigval, _ = eigsh(laplacian, 1, which='LM')
    scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(n)
    # CSR copy of the rescaled Laplacian used for the recurrence products.
    scaled_csr = sp.csr_matrix(scaled_laplacian, copy=True)
    polys = [sp.eye(n), scaled_laplacian]
    for _ in range(2, k + 1):
        polys.append(2 * scaled_csr.dot(polys[-1]) - polys[-2])
    return sparse_to_tuple(polys)


def func(triples):
    """Return ``{relation: #distinct heads / #triples}`` (relation "functionality").

    Keys appear in order of each relation's first occurrence in ``triples``.
    """
    heads_by_rel = {}
    uses_by_rel = {}
    for tri in triples:
        rel = tri[1]
        uses_by_rel[rel] = uses_by_rel.get(rel, 0) + 1
        heads_by_rel.setdefault(rel, set()).add(tri[0])
    return {rel: len(heads_by_rel[rel]) / uses_by_rel[rel] for rel in uses_by_rel}


def ifunc(triples):
    """Return ``{relation: #distinct tails / #triples}`` (inverse relation functionality).

    Keys appear in order of each relation's first occurrence in ``triples``.
    """
    tails_by_rel = {}
    uses_by_rel = {}
    for tri in triples:
        rel = tri[1]
        uses_by_rel[rel] = uses_by_rel.get(rel, 0) + 1
        tails_by_rel.setdefault(rel, set()).add(tri[2])
    return {rel: len(tails_by_rel[rel]) / uses_by_rel[rel] for rel in uses_by_rel}


def get_weighted_adj(e, triples):
    """Build an ``e x e`` float32 COO adjacency weighted by relation (inverse) functionality.

    For each non-self-loop triple (h, r, t):
      * pair (h, t) accumulates ``max(ifunc[r], 0.3)``;
      * pair (t, h) accumulates ``max(func[r], 0.3)``.
    A pair (a, b) is stored at row ``b``, column ``a``.
    """
    head_func = func(triples)
    tail_func = ifunc(triples)
    weights = {}
    for tri in triples:
        h, r, t = tri[0], tri[1], tri[2]
        if h == t:
            continue
        weights[(h, t)] = weights.get((h, t), 0) + max(tail_func[r], 0.3)
        weights[(t, h)] = weights.get((t, h), 0) + max(head_func[r], 0.3)
    row = [pair[1] for pair in weights]
    col = [pair[0] for pair in weights]
    data = np.array(list(weights.values()), dtype='float32')
    return sp.coo_matrix((data, (row, col)), shape=(e, e))


def generate_rel_ht(triples):
    """Group ``(head, tail)`` pairs by relation, keeping triple order within each group."""
    rel_ht_dict = dict()
    for h, r, t in triples:
        rel_ht_dict.setdefault(r, []).append((h, t))
    return rel_ht_dict


def diag_adj(adj):
    """Row-normalise ``adj`` (``D^{-1} A``) and return it in tuple form.

    Rows whose sum is zero keep a scaling factor of 0 instead of ``inf``.
    """
    row_sums = np.array(adj.sum(1)).flatten()
    inv_row_sums = 1. / row_sums
    inv_row_sums[np.isinf(inv_row_sums)] = 0
    return sparse_to_tuple(sp.diags(inv_row_sums).dot(adj))


def no_weighted_adj(total_ent_num, triple_list, is_two_adj=False):
    """Build normalised, unweighted adjacency matrices from a triple list.

    Every triple (h, r, t) contributes an undirected edge h <-> t. Relations and duplicate
    edges are ignored. The 0/1 matrix then goes through ``preprocess_adj``
    (self-loops + symmetric normalisation).

    :param total_ent_num: matrix size; only entity ids in ``range(total_ent_num)`` get rows.
    :param triple_list: iterable of (head, relation, tail) triples.
    :param is_two_adj: also build a two-hop matrix. It links each entity to the neighbours of
        its neighbours that are neither the entity itself nor one of its direct neighbours.
    :return: ``(one_adj, two_adj)`` in tuple form; ``two_adj`` is None unless ``is_two_adj``.
    """
    start = time.time()
    edge = dict()
    for item in triple_list:
        edge.setdefault(item[0], set())
        edge.setdefault(item[2], set())
        edge[item[0]].add(item[2])
        edge[item[2]].add(item[0])
    row = list()
    col = list()
    for key in range(total_ent_num):
        if key not in edge:
            continue
        neighbours = edge[key]
        row.extend([key] * len(neighbours))
        col.extend(neighbours)
    data = np.ones(len(row))
    one_adj = sp.coo_matrix((data, (row, col)), shape=(total_ent_num, total_ent_num))
    one_adj = preprocess_adj(one_adj)
    print('generating one-adj costs time: {:.4f}s'.format(time.time() - start))
    if not is_two_adj:
        return one_adj, None
    expend_edge = dict()
    row = list()
    col = list()
    for key, values in edge.items():
        two_hop = expend_edge.setdefault(key, set())
        for value in values:
            for item in edge[value]:
                # Emit each two-hop pair exactly once. The previous version compared the set
                # size with a counter shared across keys. It therefore dropped a key's first
                # two-hop neighbour whenever that carried-over counter happened to be 1.
                if item in values or item == key or item in two_hop:
                    continue
                two_hop.add(item)
                row.append(key)
                col.append(item)
    data = np.ones(len(row))
    two_adj = sp.coo_matrix((data, (row, col)), shape=(total_ent_num, total_ent_num))
    two_adj = preprocess_adj(two_adj)
    print('generating one- and two-adj costs time: {:.4f}s'.format(time.time() - start))
    return one_adj, two_adj


def relation_adj_list(kg1, kg2, adj_number, all_rel_num, all_ent_num, linked_ents, rel_id_mapping):
    """Build one binary adjacency matrix per relation, for the ``adj_number`` most frequent relations.

    Each relation id is first passed through ``rel_id_mapping``. A mapped value of None or ""
    keeps the original id. Relations are ranked by triple count using ``np.argsort``.
    ``linked_ents`` is accepted but not used.

    :return: list of ``all_ent_num x all_ent_num`` adjacencies in tuple form, most frequent first.
    """
    edges_by_rel = dict()
    edge_length = np.zeros(all_rel_num)
    for triple in kg1.triple_list + kg2.triple_list:
        mapped = rel_id_mapping[triple[1]]
        edge_id = mapped if (mapped is not None and mapped != "") else triple[1]
        edges_by_rel.setdefault(edge_id, list()).append([triple[0], triple[2]])
        edge_length[edge_id] += 1
    ranked_rels = np.argsort(-edge_length)

    adj_list = list()
    for rank in range(adj_number):
        rel = ranked_rels[rank]
        row, col = np.transpose(np.array(edges_by_rel[rel]))
        data = np.ones(shape=int(edge_length[rel]))
        adj = sp.coo_matrix((data, (row, col)), shape=(all_ent_num, all_ent_num))
        adj_list.append(sparse_to_tuple(adj))
    return adj_list


def generate_2hop_triples(kg, linked_ents=None, num_skipped_patterns=5):
    """Create synthetic triples that link entities to their 2-hop neighbours.

    A path ``h -r_x-> m -r_y-> t`` is considered only when h and t are not already directly
    connected, per ``kg.out_related_ents_dict`` / ``kg.in_related_ents_dict``. Paths are
    grouped by relation pattern ``(r_x, r_y)`` and the ``num_skipped_patterns`` most frequent
    patterns are discarded. Every remaining path yields a triple ``(h, r_x + r_y, t)`` plus a
    self-loop ``(h, 0, h)``.

    :param kg: object exposing ``triples``, ``out_related_ents_dict`` and ``in_related_ents_dict``.
    :param linked_ents: accepted for API compatibility; not used.
    :param num_skipped_patterns: number of most frequent patterns to drop (previously hard-coded 5).
    :return: set of (head, relation, tail) triples; empty for an empty ``kg.triples``.
    """
    # Building from tuples (instead of np.array) also works when kg.triples is empty.
    triple_df = pd.DataFrame([(tr[0], tr[1], tr[2]) for tr in kg.triples], columns=['h', 'r', 't'])
    # Self-join: tail of the first hop == head of the second hop (columns get _x / _y suffixes).
    two_hop_triple_df = pd.merge(triple_df, triple_df, left_on='t', right_on='h')
    two_step_quadruples = set()
    relation_patterns = dict()
    paths = two_hop_triple_df[['h_x', 'r_x', 'r_y', 't_y']]
    for head, r_x, r_y, tail in paths.itertuples(index=False, name=None):
        if tail not in kg.out_related_ents_dict.get(head, set()) and \
                head not in kg.in_related_ents_dict.get(tail, set()):
            relation_patterns[(r_x, r_y)] = relation_patterns.get((r_x, r_y), 0) + 1
            two_step_quadruples.add((head, r_x, r_y, tail))
    print("total 2-hop neighbors:", len(two_step_quadruples))
    print("total 2-hop relation patterns:", len(relation_patterns))
    ranked_patterns = sorted(relation_patterns.items(), key=lambda x: x[1], reverse=True)
    selected_patterns = {pattern for pattern, _ in ranked_patterns[num_skipped_patterns:]}
    print("selected relation patterns:", len(selected_patterns))
    two_step_triples = set()
    for head, rx, ry, tail in two_step_quadruples:
        if (rx, ry) in selected_patterns:
            two_step_triples.add((head, 0, head))
            two_step_triples.add((head, rx + ry, tail))
    print("selected 2-hop neighbors:", len(two_step_triples))
    return two_step_triples


def transloss_add2hop(kg1, kg2, sup_ent1, sup_ent2, ref_ent1, ref_ent2, total_e_num):
    """Build an unnormalised, undirected 0/1 adjacency from the 2-hop triples of both KGs.

    :return: ``total_e_num x total_e_num`` adjacency in tuple form (no self-loops added,
        no normalisation).
    """
    linked_ents = set(sup_ent1 + sup_ent2 + ref_ent1 + ref_ent2)
    triples = (generate_2hop_triples(kg1, linked_ents=linked_ents)
               | generate_2hop_triples(kg2, linked_ents=linked_ents))
    neighbours = dict()
    for triple in triples:
        head, tail = triple[0], triple[2]
        neighbours.setdefault(head, set())
        neighbours.setdefault(tail, set())
        neighbours[head].add(tail)
        neighbours[tail].add(head)
    row, col = list(), list()
    for ent in range(total_e_num):
        if ent in neighbours:
            adjacent = neighbours[ent]
            row.extend([float(ent)] * len(adjacent))
            col.extend(list(adjacent))
    data = np.ones(len(row))
    one_adj = sp.coo_matrix((data, (row, col)), shape=(total_e_num, total_e_num))
    return sparse_to_tuple(one_adj)


def get_neighbor_dict(out_dict, in_dict):
    """Merge out- and in-neighbour maps into one ``entity -> set of neighbours`` map.

    The returned sets are fresh copies. The previous version stored ``out_dict``'s sets by
    reference and then ``|=``-merged into them, silently mutating the caller's data.
    """
    dic = {key: set(value) for key, value in out_dict.items()}
    for key, value in in_dict.items():
        dic.setdefault(key, set()).update(value)
    return dic


def get_neighbor_counterparts(neighbors, alignment_dic):
    """Return the aligned counterparts of those ``neighbors`` that appear in ``alignment_dic``."""
    return {alignment_dic.get(n) for n in neighbors if n in alignment_dic}


def check_new_alignment(aligned_pairs, context="check align"):
    """Print how many pairs map an entity to the same id (reported as "right align").

    Prints an "empty aligned pairs" line and returns early for None or empty input.
    """
    if aligned_pairs is None or len(aligned_pairs) == 0:
        print("{}, empty aligned pairs".format(context))
        return
    total = len(aligned_pairs)
    right = sum(1 for x, y in aligned_pairs if x == y)
    print("{}, right align: {}/{}={:.3f}".format(context, right, total, right / total))


def update_labeled_alignment_x(pre_labeled_alignment, curr_labeled_alignment, sim_mat):
    """Merge newly labeled pairs into the previous alignment, keyed by source entity.

    New sources are added directly. A source ``i`` already mapped to ``pre_j`` is re-pointed
    to ``j`` only if ``sim_mat[i, j] >= sim_mat[i, pre_j]``. Two diagnostics are printed:
      * n1: accepted updates that replaced a self-mapping (pre_j == i) with j != i;
      * n2: incoming pairs j != i for sources currently mapped to themselves.

    :return: the merged alignment as a set of ``(i, j)`` pairs.
    """
    check_new_alignment(pre_labeled_alignment, context="before editing (<-)")
    aligned = dict(pre_labeled_alignment)
    n1, n2 = 0, 0
    for i, j in curr_labeled_alignment:
        if aligned.get(i, -1) == i and j != i:
            n2 += 1
        if i not in aligned:
            aligned[i] = j
            continue
        pre_j = aligned[i]
        if pre_j == j:
            continue
        if sim_mat[i, j] >= sim_mat[i, pre_j]:
            if pre_j == i and j != i:
                n1 += 1
            aligned[i] = j
    print("update wrongly: ", n1, "greedy update wrongly: ", n2)
    merged = set(aligned.items())
    check_new_alignment(merged, context="after editing (<-)")
    return merged


def update_labeled_alignment_y(labeled_alignment, sim_mat):
    """Resolve many-to-one conflicts in ``labeled_alignment``.

    For every target ``j``, only the source ``i`` with the highest ``sim_mat[i, j]`` is kept.

    :return: set of ``(i, j)`` pairs with at most one source per target.
    """
    sources_by_target = dict()
    for i, j in labeled_alignment:
        sources_by_target.setdefault(j, set()).add(i)
    updated_alignment = set()
    for j, i_set in sources_by_target.items():
        # max() keeps the first maximal candidate, the same tie-break as the old strict ">"
        # scan. Unlike the old -10 sentinel, it never emits a bogus (-1, j) pair when every
        # similarity is <= -10.
        best_i = max(i_set, key=lambda i: sim_mat[i, j])
        updated_alignment.add((best_i, j))
    check_new_alignment(updated_alignment, context="after editing (->)")
    return updated_alignment


def enhance_triples(kg1, kg2, ents1, ents2):
    """Transfer triples across KGs through the seed alignment ``ents1[k] <-> ents2[k]``.

    A triple whose head and tail are both aligned is copied to the other KG, unless that KG
    already has an edge between the mapped head and tail.

    :return: ``(new triples for kg1, new triples for kg2)`` as sets.
    """
    assert len(ents1) == len(ents2)
    print("before enhanced:", len(kg1.triples), len(kg2.triples))
    links1 = dict(zip(ents1, ents2))
    links2 = dict(zip(ents2, ents1))

    def transfer(src_kg, links, dst_kg):
        moved = set()
        for h, r, t in src_kg.triples:
            mapped_h = links.get(h, None)
            mapped_t = links.get(t, None)
            if mapped_h is None or mapped_t is None:
                continue
            if mapped_t in dst_kg.out_related_ents_dict.get(mapped_h, set()):
                continue
            moved.add((mapped_h, r, mapped_t))
        return moved

    enhanced_triples2 = transfer(kg1, links1, kg2)
    enhanced_triples1 = transfer(kg2, links2, kg1)
    print("after enhanced:", len(enhanced_triples1), len(enhanced_triples2))
    return enhanced_triples1, enhanced_triples2


def dropout(inputs, drop_rate, noise_shape, is_sparse):
    """Apply dropout with drop probability ``drop_rate`` to a dense or sparse tensor.

    :param noise_shape: only used on the sparse path (passed to ``sparse_dropout``).
    :param is_sparse: route ``inputs`` through ``sparse_dropout`` instead of ``tf.nn.dropout``.
    """
    if is_sparse:
        return sparse_dropout(inputs, drop_rate, noise_shape)
    # TF1's tf.nn.dropout takes keep_prob as its second positional argument. Passing drop_rate
    # there inverted the probability and disagreed with sparse_dropout (keep_prob = 1 - drop_rate).
    return tf.nn.dropout(inputs, keep_prob=1 - drop_rate)


def sparse_dropout(x, drop_rate, noise_shape):
    """Randomly drop stored entries of a SparseTensor and rescale the survivors.

    Each entry is kept with probability ``1 - drop_rate``. Kept entries are scaled by
    ``1 / keep_prob``. ``noise_shape`` must match the number of stored values.
    """
    keep_prob = 1 - drop_rate
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob and 0 otherwise.
    keep_mask = tf.cast(tf.floor(keep_prob + tf.random.uniform(noise_shape)), dtype=tf.bool)
    kept = tf.sparse.retain(x, keep_mask)
    return kept * (1. / keep_prob)


def generate_neighbours(entity_embeds1, entity_list1, entity_embeds2, entity_list2, neighbors_num, threads_num=4):
    """For each entity in ``entity_list1``, find ``neighbors_num`` top-scoring entities of ``entity_list2``.

    Scores are inner products of the embedding rows (see ``find_neighbours``). The work is
    split into ``threads_num`` chunks run as ray tasks, and the per-chunk dicts are combined
    with ``merge_dic``.

    :return: dict ``entity1 id -> list of entity2 ids``.
    """
    ent_frags = task_divide(np.array(entity_list1, dtype=np.int32), threads_num)
    ent_frag_indexes = task_divide(np.array(range(len(entity_list1)), dtype=np.int32), threads_num)
    candidates = np.array(entity_list2, dtype=np.int32)
    futures = list()
    for chunk_no, frag in enumerate(ent_frags):
        rows = ent_frag_indexes[chunk_no]
        futures.append(find_neighbours.remote(frag, entity_embeds1[rows, :], candidates,
                                              entity_embeds2, neighbors_num))
    neighbours = dict()
    for partial in ray.get(futures):
        neighbours = merge_dic(neighbours, partial)
    gc.collect()
    return neighbours


@ray.remote(num_cpus=1)
def find_neighbours(frags, sub_embed1, entity_list2, embed2, k):
    """Return ``{frags[i]: ids of the k highest-scoring rows of embed2}`` for every row of ``sub_embed1``.

    Scores are ``sub_embed1 @ embed2.T``. The k neighbours per row come out unordered. ``k``
    is clamped to the number of candidates; the old ``kth=k`` partition raised when k
    equalled or exceeded it.
    """
    dic = dict()
    sim_mat = np.matmul(sub_embed1, embed2.T)
    k = min(k, sim_mat.shape[1])
    for i in range(sim_mat.shape[0]):
        if k <= 0:
            dic[frags[i]] = []
            continue
        # kth = k - 1 still places the k largest similarities in the first k slots, and it
        # stays a valid index when k equals the number of candidates.
        neighbors_index = np.argpartition(-sim_mat[i, :], k - 1)[:k]
        dic[frags[i]] = entity_list2[neighbors_index].tolist()
    del sim_mat
    return dic


class AKG:
    """In-memory view of one KG's relation triples plus derived lookup tables.

    Triples are ``(head, relation, tail)`` tuples; duplicates are removed.
    """

    def __init__(self, triples, ori_triples=None):
        self.triples = set(triples)
        self.triple_list = list(self.triples)
        self.triples_num = len(self.triples)

        self.heads = {triple[0] for triple in self.triple_list}
        self.props = {triple[1] for triple in self.triple_list}
        self.tails = {triple[2] for triple in self.triple_list}
        self.ents = self.heads | self.tails

        print("triples num", self.triples_num)

        print("head ent num", len(self.heads))
        print("total ent num", len(self.ents))

        self.prop_list = sorted(self.props)
        self.ent_list = sorted(self.ents)

        self.ori_triples = None if ori_triples is None else set(ori_triples)

        self._generate_related_ents()
        self._generate_triple_dict()
        self._generate_ht()
        self.__generate_weight()

    def _generate_related_ents(self):
        """Build ``head -> {tails}`` (out) and ``tail -> {heads}`` (in) neighbour maps."""
        self.out_related_ents_dict = dict()
        self.in_related_ents_dict = dict()
        for h, r, t in self.triple_list:
            self.out_related_ents_dict.setdefault(h, set()).add(t)
            self.in_related_ents_dict.setdefault(t, set()).add(h)

    def _generate_triple_dict(self):
        """Build ``head -> {(relation, tail)}`` and ``tail -> {(head, relation)}`` maps."""
        self.rt_dict, self.hr_dict = dict(), dict()
        for h, r, t in self.triple_list:
            self.rt_dict.setdefault(h, set()).add((r, t))
            self.hr_dict.setdefault(t, set()).add((h, r))

    def _generate_ht(self):
        """Collect the set of (head, tail) pairs, ignoring relations."""
        self.ht = {(h, t) for h, r, t in self.triples}

    def __generate_weight(self):
        """Give weight 2.0 to triples whose head has few head-to-head links.

        Only triples whose tail is also a head count. ``ave`` is the ceiling of their number
        divided by the number of heads. A triple (h, r, t) with t a head and
        ``triple_num[h] <= ave`` gets weight 2.0 and is listed in ``additional_triples``;
        all others get weight 1.
        """
        triple_num = dict()
        n = 0
        for h, r, t in self.triples:
            if t in self.heads:
                n += 1
                triple_num[h] = triple_num.get(h, 0) + 1
                triple_num[t] = triple_num.get(t, 0) + 1
        self.weighted_triples = list()
        self.additional_triples = list()
        # An empty graph has no heads: use 0 instead of dividing by zero.
        ave = math.ceil(n / len(self.heads)) if self.heads else 0
        print("ave outs:", ave)

        for h, r, t in self.triples:
            w = 1
            if t in self.heads and triple_num[h] <= ave:
                w = 2.0
                self.additional_triples.append((h, r, t))
            self.weighted_triples.append((h, r, t, w))
        print("additional triples:", len(self.additional_triples))


class GraphConvolution:
    """GCN layer that sums one graph convolution per adjacency matrix.

    For every adjacency ``A_i`` in ``adj`` a kernel ``W_i`` is created. ``call`` computes
    ``activation(sum_i A_i @ (X @ W_i) + bias)`` after batch-normalising (and, if
    ``dropout_rate > 0``, dropping out) the input ``X``.
    """

    def __init__(self, input_dim, output_dim, adj,
                 num_features_nonzero,
                 dropout_rate=0.0,
                 name='GCN',
                 is_sparse_inputs=False,
                 activation=tf.tanh,
                 use_bias=True):
        # Layer configuration.
        self.name = name
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation
        self.use_bias = use_bias
        self.dropout_rate = dropout_rate
        self.is_sparse_inputs = is_sparse_inputs
        # Forwarded to dropout() as noise_shape on the sparse path.
        self.num_features_nonzero = num_features_nonzero
        self.data_type = tf.float32
        # adj: list of (coords, values, shape) tuples, e.g. from sparse_to_tuple().
        self.adjs = self._as_sparse_tensors(adj)
        self.kernels = list()
        self.bias = list()
        self._get_variable()

    @staticmethod
    def _as_sparse_tensors(adj):
        """Wrap tuple-form sparse matrices as tf.SparseTensor objects."""
        return [tf.SparseTensor(indices=m[0], values=m[1], dense_shape=m[2]) for m in adj]

    def _get_variable(self):
        """Create the batch-norm layer, one L2-regularised kernel per adjacency and an optional bias."""
        self.batch_normalization = tf.keras.layers.BatchNormalization()
        for idx in range(len(self.adjs)):
            kernel = tf.get_variable(self.name + '_kernel_' + str(idx),
                                     shape=(self.input_dim, self.output_dim),
                                     initializer=tf.glorot_uniform_initializer(),
                                     regularizer=tf.contrib.layers.l2_regularizer(scale=0.01),
                                     dtype=self.data_type)
            self.kernels.append(kernel)
        if not self.use_bias:
            return
        self.bias = tf.get_variable(self.name + '_bias', shape=[self.output_dim, ],
                                    initializer=tf.zeros_initializer(),
                                    dtype=self.data_type)

    def call(self, inputs):
        """Apply the layer to ``inputs`` and return the (activated) output tensor."""
        inputs = self.batch_normalization(inputs)
        if self.dropout_rate > 0.0:
            inputs = dropout(inputs, self.dropout_rate, self.num_features_nonzero, self.is_sparse_inputs)
        hidden_vectors = list()
        for idx, adj in enumerate(self.adjs):
            projected = tf.matmul(inputs, self.kernels[idx], a_is_sparse=self.is_sparse_inputs)
            hidden_vectors.append(tf.sparse_tensor_dense_matmul(tf.cast(adj, tf.float32), projected))
        outputs = tf.add_n(hidden_vectors)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        return outputs if self.activation is None else self.activation(outputs)

    def update_adj(self, adj):
        """Replace the adjacency matrices (tuple form); the existing kernels are kept."""
        print("gcn update adj...")
        self.adjs = self._as_sparse_tensors(adj)


class HighwayLayer:
    """Gated mix of two inputs.

    ``call`` computes ``gate = relu(dropout(tanh(BN(input1) @ W)))`` and returns
    ``tanh(BN(input2) * (1 - gate) + BN(input1) * gate)``.
    """

    def __init__(self, input_dim, output_dim, dropout_rate=0.0, name="highway"):
        # Shape of the gate kernel W.
        self.input_shape = (input_dim, output_dim)
        self.name = name
        self.data_type = tf.float32
        self.dropout_rate = dropout_rate
        self._get_variable()

    def _get_variable(self):
        """Create the L2-regularised gate kernel, the tanh activation and the batch-norm layer."""
        self.weight = tf.get_variable(self.name + 'kernel', shape=self.input_shape,
                                      initializer=tf.glorot_uniform_initializer(),
                                      regularizer=tf.contrib.layers.l2_regularizer(scale=0.01),
                                      dtype=self.data_type)
        self.activation = tf.tanh
        self.batch_normal = tf.keras.layers.BatchNormalization()

    def call(self, input1, input2):
        """Return the gated combination of ``input1`` and ``input2``."""
        # NOTE(review): one BatchNormalization layer normalises both inputs, so they share its
        # parameters and moving statistics; confirm this is intended.
        input1 = self.batch_normal(input1)
        input2 = self.batch_normal(input2)
        gate = self.activation(tf.matmul(input1, self.weight))
        if self.dropout_rate > 0:
            # TF1's tf.nn.dropout takes keep_prob (not the drop rate) as its second argument.
            gate = tf.nn.dropout(gate, keep_prob=1 - self.dropout_rate)
        gate = tf.keras.activations.relu(gate)
        output = tf.add(tf.multiply(input2, 1 - gate), tf.multiply(input1, gate))
        return self.activation(output)


class AliNetGraphAttentionLayer:
    """Single-head graph attention over the first adjacency matrix in ``adj``.

    With ``x = BN(inputs)``, the per-node scores are ``s1 = tanh(sum(x W1 * x, axis=1))`` and
    ``s2 = tanh(sum(x W2 * x, axis=1))``. Each stored edge (i, j) with adjacency value a_ij gets
    logit ``leaky_relu(a_ij * (s1_i + s2_j))``. Logits are softmax-normalised per row, and the
    resulting attention matrix multiplies ``x @ W``.
    ``use_bias`` and ``is_sparse_input`` are stored but not used by ``call``.
    """

    def __init__(self, input_dim, output_dim, adj, nodes_num,
                 dropout_rate, is_sparse_input=False, use_bias=True,
                 activation=None, name="alinet"):
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Only the first tuple-form adjacency (coords, values, shape) is used.
        self.adjs = [tf.SparseTensor(indices=adj[0][0], values=adj[0][1], dense_shape=adj[0][2])]
        self.dropout_rate = dropout_rate
        self.is_sparse_input = is_sparse_input
        self.nodes_num = nodes_num
        self.use_bias = use_bias
        self.activation = activation
        self.name = name
        self.data_type = tf.float32
        self._get_variable()

    def _get_variable(self):
        """Create the projection kernel W, attention kernels W1/W2 and the batch-norm layer."""
        self.kernel = tf.get_variable(self.name + '_kernel', shape=(self.input_dim, self.output_dim),
                                      initializer=tf.glorot_uniform_initializer(),
                                      regularizer=tf.contrib.layers.l2_regularizer(scale=0.01),
                                      dtype=self.data_type)
        self.kernel1 = tf.get_variable(self.name + '_kernel_1', shape=(self.input_dim, self.input_dim),
                                       initializer=tf.glorot_uniform_initializer(),
                                       regularizer=tf.contrib.layers.l2_regularizer(scale=0.01),
                                       dtype=self.data_type)
        self.kernel2 = tf.get_variable(self.name + '_kernel_2', shape=(self.input_dim, self.input_dim),
                                       initializer=tf.glorot_uniform_initializer(),
                                       regularizer=tf.contrib.layers.l2_regularizer(scale=0.01),
                                       dtype=self.data_type)
        self.batch_normlization = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        """Return the attention-aggregated node representations."""
        inputs = self.batch_normlization(inputs)
        mapped_inputs = tf.matmul(inputs, self.kernel)
        attention_inputs1 = tf.matmul(inputs, self.kernel1)
        attention_inputs2 = tf.matmul(inputs, self.kernel2)
        con_sa_1 = tf.reduce_sum(tf.multiply(attention_inputs1, inputs), 1, keepdims=True)
        con_sa_2 = tf.reduce_sum(tf.multiply(attention_inputs2, inputs), 1, keepdims=True)
        con_sa_1 = tf.keras.activations.tanh(con_sa_1)
        con_sa_2 = tf.keras.activations.tanh(con_sa_2)
        if self.dropout_rate > 0.0:
            # TF1's tf.nn.dropout takes keep_prob (not the drop rate) as its second argument.
            keep_prob = 1 - self.dropout_rate
            con_sa_1 = tf.nn.dropout(con_sa_1, keep_prob=keep_prob)
            con_sa_2 = tf.nn.dropout(con_sa_2, keep_prob=keep_prob)
        # Broadcast column scores s1 over rows and (transposed) s2 over columns of the adjacency.
        con_sa_1 = tf.cast(self.adjs[0], dtype=tf.float32) * con_sa_1
        con_sa_2 = tf.cast(self.adjs[0], dtype=tf.float32) * tf.transpose(con_sa_2, [1, 0])
        weights = tf.sparse_add(con_sa_1, con_sa_2)
        weights = tf.SparseTensor(indices=weights.indices,
                                  values=tf.nn.leaky_relu(weights.values),
                                  dense_shape=weights.dense_shape)
        attention_adj = tf.sparse_softmax(weights)
        attention_adj = tf.sparse_reshape(attention_adj, shape=[self.nodes_num, self.nodes_num])
        value = tf.sparse_tensor_dense_matmul(attention_adj, mapped_inputs)
        # activation defaults to None; calling it unconditionally raised TypeError.
        if self.activation is None:
            return value
        return self.activation(value)


class AliNet(BasicModel):

    def set_kgs(self, kgs):
        """Store ``kgs`` and wrap each side's relation triples in an ``AKG``."""
        self.kgs = kgs
        self.kg1, self.kg2 = AKG(kgs.kg1.relation_triples_set), AKG(kgs.kg2.relation_triples_set)

    def set_args(self, args):
        """Store ``args`` and derive the output folder from them and the class name."""
        self.args = args
        model_name = type(self).__name__
        self.out_folder = generate_out_folder(args.output, args.training_data, args.dataset_division,
                                              model_name)

    def init(self):
        """Prepare seeds, adjacency data, the TF session and the computation graph.

        Adjacency data are cached with pickle at
        ``args.training_data + 'alinet_' + args.align_direction + 'saved_data.pkl'``.
        """
        self.ref_ent1 = self.kgs.test_entities1 + self.kgs.valid_entities1
        self.ref_ent2 = self.kgs.test_entities2 + self.kgs.valid_entities2
        self.sup_ent1 = self.kgs.train_entities1
        self.sup_ent2 = self.kgs.train_entities2
        self.linked_ents = set(self.kgs.train_entities1 +
                               self.kgs.train_entities2 +
                               self.kgs.valid_entities1 +
                               self.kgs.test_entities1 +
                               self.kgs.test_entities2 +
                               self.kgs.valid_entities2)
        # Copy triples across the two KGs through the training (seed) alignment.
        enhanced_triples1, enhanced_triples2 = enhance_triples(self.kg1,
                                                               self.kg2,
                                                               self.sup_ent1,
                                                               self.sup_ent2)
        ori_triples = self.kg1.triple_list + self.kg2.triple_list
        triples = ori_triples + list(enhanced_triples1) + list(enhanced_triples2)

        rel_ht_dict = generate_rel_ht(triples)
        saved_data_path = self.args.training_data + 'alinet_' + self.args.align_direction + 'saved_data.pkl'
        if os.path.exists(saved_data_path):
            print('load saved adj data from', saved_data_path)
            # NOTE(review): the cache key ignores self.is_two, so a cache written with another
            # is_two setting is reused as-is. Delete the file after changing it.
            # pickle.load can execute arbitrary code: only use trusted training_data locations.
            with open(saved_data_path, 'rb') as f:
                adj = pickle.load(f)
        else:
            one_adj, _ = no_weighted_adj(self.kgs.entities_num, triples, is_two_adj=False)
            adj = [one_adj]
            if self.is_two:
                two_hop_triples1 = generate_2hop_triples(self.kg1, linked_ents=self.linked_ents)
                two_hop_triples2 = generate_2hop_triples(self.kg2, linked_ents=self.linked_ents)
                triples = two_hop_triples1 | two_hop_triples2
                two_adj, _ = no_weighted_adj(self.kgs.entities_num, triples, is_two_adj=False)
                adj.append(two_adj)
            print('save adj data to', saved_data_path)
            with open(saved_data_path, 'wb') as f:
                pickle.dump(adj, f)

        self.adj = adj
        self.ori_adj = [adj[0]]
        self.rel_ht_dict = rel_ht_dict
        self.rel_win_size = self.args.min_rel_win

        # Seed links as rows [ent1, ent2, weight=1.0].
        sup_ent1 = np.array(self.sup_ent1).reshape((-1, 1))
        sup_ent2 = np.array(self.sup_ent2).reshape((-1, 1))
        # np.float was only an alias of the builtin float and was removed in NumPy 1.24.
        weight = np.ones((len(self.kgs.train_entities1), 1), dtype=float)
        self.sup_links = np.hstack((sup_ent1, sup_ent2, weight))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.session = tf.Session(config=config)
        self._get_variable()
        if self.args.rel_param > 0.0:
            self._generate_rel_graph()
        else:
            self._generate_graph()

        if self.args.detection_mode == "margin":
            self._define_distance_margin_graph()
        elif self.args.detection_mode == "open":
            self._define_open_margin_graph()

        if self.args.use_NCA_loss:
            self._define_NCA_loss_graph()

        if self.args.use_dangling_bin_cls:
            print(f'self.args.use_dangling_bin_cls: {self.args.use_dangling_bin_cls}')
            self._define_dangling_bin_cls()

        tf.global_variables_initializer().run(session=self.session)

    def __init__(self):
        super().__init__()
        # Graph structure and layer outputs (filled by init() / _define_model()).
        self.adj = None
        self.one_hop_layers = None
        self.two_hop_layers = None
        self.layers_outputs = None
        self.input_embeds = None
        self.output_embeds_list = None
        self.model = None
        self.optimizer = None
        self.session = None

        # Alignment seeds and reference entities (filled by init()).
        self.ref_ent1 = None
        self.ref_ent2 = None
        self.sup_ent1 = None
        self.sup_ent2 = None
        self.linked_ents = None
        self.sup_links = None

        # Link/edge bookkeeping.
        self.new_edges1 = set()
        self.new_edges2 = set()
        self.new_links = set()
        self.sup_links_set = set()
        self.new_sup_links_set = set()
        self.pos_link_batch = None
        self.neg_link_batch = None

        # Relation-window and augmentation settings.
        self.rel_ht_dict = None
        self.rel_win_size = None
        self.start_augment = None
        self.is_two = False

    def _get_variable(self):
        """Create the trainable input embedding table of shape (entities_num, layer_dims[0])."""
        table_shape = (self.kgs.entities_num, self.args.layer_dims[0])
        self.init_embedding = tf.get_variable('init_embedding',
                                              shape=table_shape,
                                              initializer=tf.glorot_uniform_initializer(),
                                              dtype=tf.float32)

    def _define_model(self):
        """Stack ``len(layer_dims) - 1`` GCN layers (first adjacency only) on ``init_embedding``.

        Each layer's output is recorded in ``output_embeds_list``. ``two_hop_layers`` is left empty.
        """
        print('Getting AliNet model...')
        dims = self.args.layer_dims
        embeds = self.init_embedding
        one_layers = list()
        layers_outputs = list()
        for depth, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            layer = GraphConvolution(input_dim=in_dim,
                                     output_dim=out_dim,
                                     adj=[self.adj[0]],
                                     num_features_nonzero=self.args.num_features_nonzero,
                                     dropout_rate=0.0,
                                     name='gcn_' + str(depth))
            one_layers.append(layer)
            embeds = layer.call(embeds)
            layers_outputs.append(embeds)

        self.one_hop_layers = one_layers
        self.two_hop_layers = list()
        self.output_embeds_list = layers_outputs

    def compute_loss(self, pos_links, neg_links, only_pos=False):
        """Margin-based alignment loss on concatenated, L2-normalised layer embeddings.

        Columns 0 and 1 of ``pos_links`` / ``neg_links`` hold the entity ids of each pair.
        Loss = sum of positive squared distances
               + neg_margin_balance * sum(relu(neg_margin - negative squared distance)).
        ``only_pos`` is accepted but currently ignored.
        """
        per_layer = [tf.nn.l2_normalize(e, 1) for e in self.output_embeds_list + [self.init_embedding]]
        output_embeds = tf.nn.l2_normalize(tf.concat(per_layer, axis=1), 1)

        def squared_distances(links):
            left = tf.nn.embedding_lookup(output_embeds, tf.cast(links[:, 0], tf.int32))
            right = tf.nn.embedding_lookup(output_embeds, tf.cast(links[:, 1], tf.int32))
            return tf.reduce_sum(tf.square(left - right), 1)

        pos_loss = tf.reduce_sum(squared_distances(pos_links))
        neg_distance = squared_distances(neg_links)
        neg_loss = tf.reduce_sum(tf.keras.activations.relu(self.args.neg_margin - neg_distance))
        return pos_loss + self.args.neg_margin_balance * neg_loss

    def compute_rel_loss(self, hs, ts):
        """Relation-consistency loss over windows of (head, tail) pairs.

        ``hs``/``ts`` are expected in consecutive groups of ``self.rel_win_size`` pairs
        sharing one relation, which is how ``generate_rel_batch`` lays them out.

        For each group, the mean of ``head - tail`` in the joint normalized space is
        taken as that relation's vector and then L2-normalized. The loss is the summed
        squared residual ``||head - tail - rel||^2``, scaled by ``args.rel_param``.
        """
        normalized_layers = [tf.nn.l2_normalize(layer, 1)
                             for layer in self.output_embeds_list + [self.init_embedding]]
        joint = tf.nn.l2_normalize(tf.concat(normalized_layers, axis=1), 1)
        width = joint.shape[-1]

        heads = tf.nn.embedding_lookup(joint, tf.cast(hs, tf.int32))
        tails = tf.nn.embedding_lookup(joint, tf.cast(ts, tf.int32))
        offsets = heads - tails

        # (num_rel * win, width) -> (num_rel, win, width); average within each window.
        grouped = tf.reshape(offsets, [-1, self.rel_win_size, width])
        rel_mean = tf.reduce_mean(grouped, axis=1, keepdims=True)
        # Broadcast each relation vector back to every pair of its window.
        rel = tf.reshape(tf.tile(rel_mean, [1, self.rel_win_size, 1]), [-1, width])
        rel = tf.nn.l2_normalize(rel, 1)

        residual = tf.reduce_sum(tf.square(offsets - rel), 1)
        return tf.reduce_sum(residual) * self.args.rel_param

    def _generate_graph(self):
        """Build the link placeholders, the GCN model, the alignment loss and its Adam step."""
        self.pos_links = tf.placeholder(tf.int32, shape=[None, 3], name="pos")
        self.neg_links = tf.placeholder(tf.int32, shape=[None, 2], name='neg')
        self._define_model()
        self.loss = self.compute_loss(self.pos_links, self.neg_links)
        adam = tf.train.AdamOptimizer(learning_rate=self.args.learning_rate)
        self.optimizer = adam.minimize(self.loss)

    def _generate_rel_graph(self):
        """Build the alignment graph plus the relation-window loss fed through ``hs``/``ts``.

        ``self.loss`` is the sum of ``compute_loss`` and ``compute_rel_loss``. It is
        minimised with Adam at ``args.learning_rate``.
        """
        print('Building relational embedding graph...')
        print("rel_win_size:", self.rel_win_size)
        self.pos_links = tf.placeholder(tf.int32, shape=[None, 3], name="pos")
        self.neg_links = tf.placeholder(tf.int32, shape=[None, 2], name='neg')
        self.hs = tf.placeholder(tf.int32, shape=[None], name="hs")
        self.ts = tf.placeholder(tf.int32, shape=[None], name='ts')
        self._define_model()
        alignment_loss = self.compute_loss(self.pos_links, self.neg_links)
        relation_loss = self.compute_rel_loss(self.hs, self.ts)
        self.loss = alignment_loss + relation_loss
        adam = tf.train.AdamOptimizer(learning_rate=self.args.learning_rate)
        self.optimizer = adam.minimize(self.loss)

    def _define_distance_margin_graph(self):
        """Build a hinge loss that pushes fed entity pairs at least ``args.distance_margin`` apart.

        Both sides are embedded with ``lookup_embeds`` (all layers concatenated) and
        L2-normalized. The loss is ``sum(relu(distance_margin - ||x1 - x2||^2))``.

        Sets ``self.input_ents1``/``self.input_ents2`` (placeholders), ``self.dis_loss``
        and the Adam step ``self.dis_optimizer``.
        """
        print("build distance margin graph...")
        with tf.name_scope('entity_placeholder'):
            self.input_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.input_ents2 = tf.placeholder(tf.int32, shape=[None])
        with tf.name_scope('entity_lookup'):
            src = self.lookup_embeds(self.input_ents1)
            tgt = self.lookup_embeds(self.input_ents2)
            src = tf.nn.l2_normalize(src, 1)
            tgt = tf.nn.l2_normalize(tgt, 1)
        with tf.name_scope('dis_margin_loss'):
            sq_dist = tf.reduce_sum(tf.square(src - tgt), axis=1)
        hinge = tf.nn.relu(self.args.distance_margin - sq_dist)
        self.dis_loss = tf.reduce_sum(hinge)
        adam = tf.train.AdamOptimizer(learning_rate=self.args.learning_rate)
        self.dis_optimizer = adam.minimize(self.dis_loss)
        print('finished')

    def _define_open_margin_graph(self):
        """Build the "open margin" loss on the raw input embeddings.

        For each source entity (``input_ents1``), squared Euclidean distances to
        ``args.num_random`` other entities (``input_ents2``) are computed on
        L2-normalized ``init_embedding`` rows. The loss is the mean absolute deviation
        of those distances from their per-row mean, so it pushes every source entity
        to be equidistant from its random partners.

        Sets ``self.open_loss`` and the Adam step ``self.open_optimizer``. Note that it
        (re)assigns ``self.input_ents1``/``self.input_ents2``, the same attributes
        ``_define_distance_margin_graph`` uses.
        """
        print("build open margin graph...")
        with tf.name_scope('entity_placeholder'):
            self.input_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.input_ents2 = tf.placeholder(tf.int32, shape=[None, self.args.num_random])
        with tf.name_scope('entity_lookup'):
            x1 = tf.nn.embedding_lookup(self.init_embedding, self.input_ents1)
            x1 = tf.nn.l2_normalize(x1, 1)
            # (batch, dim) -> (batch, 1, dim) so it broadcasts against x2 below.
            x1 = tf.reshape(x1, [-1, 1, self.args.layer_dims[0]])
            # x2 is (batch, num_random, dim). Normalize every embedding along its feature
            # axis (-1), not across the num_random samples (axis 1), so each vector is
            # unit length like x1.
            x2 = tf.nn.embedding_lookup(self.init_embedding, self.input_ents2)
            x2 = tf.nn.l2_normalize(x2, -1)
        with tf.name_scope('open_margin_loss'):
            dis_mat = tf.reduce_sum(tf.square(x1 - x2), axis=2)  # (batch, num_random)
            avg_dis = tf.reduce_mean(dis_mat, axis=1, keepdims=True)
            dis = dis_mat - avg_dis
        self.open_loss = tf.reduce_mean(tf.abs(dis))
        self.open_optimizer = tf.train.AdamOptimizer(learning_rate=self.args.learning_rate).minimize(self.open_loss)

    def _define_dangling_bin_cls(self):
        """Build the MLP that classifies source entities as matchable vs dangling.

        Inputs are ranked similarity scores:
          * ``topk_scores``: ``s2t_topk * t2s_topk`` values per source entity;
          * ``s2t_scores``: ``s2t_topk`` values, added only when
            ``args.concat_forward_reverse`` is set.

        A single hidden ReLU layer (``args.n_hid`` units, dropout 0.5 while
        ``is_training``) feeds a sigmoid output. The model is trained with binary
        cross-entropy against ``bin_cls_labels`` at twice ``args.learning_rate``.
        """
        print('build dangling entity binary classification graph...')
        use_s2t = self.args.concat_forward_reverse
        with tf.name_scope('is_training'):
            self.is_training = tf.placeholder(tf.bool, None)
        with tf.name_scope('topk_scores'):
            # (#source_ent, k_sim_scores)
            self.topk_scores = tf.placeholder(tf.float32,
                                              shape=[None, self.args.s2t_topk * self.args.t2s_topk])
        if use_s2t:
            print('consider s2t nearest topk scores')
            with tf.name_scope('s2t_topk_scores'):
                self.s2t_scores = tf.placeholder(tf.float32, shape=[None, self.args.s2t_topk])
        with tf.name_scope('bin_cls_labels'):
            # (#source_ent)
            self.bin_cls_labels = tf.placeholder(tf.float32, shape=[None])
        with tf.name_scope('bin_cls_mlp'):
            features = tf.concat([self.topk_scores, self.s2t_scores], 1) if use_s2t else self.topk_scores
            hidden = tf.contrib.layers.fully_connected(features, self.args.n_hid, activation_fn=tf.nn.relu)
            hidden = tf.layers.dropout(hidden, rate=0.5, training=self.is_training)
            logits = tf.contrib.layers.fully_connected(hidden, 1, activation_fn=None)
            probabilities = tf.sigmoid(logits)
            self.dangling_bin_cls_output = tf.reshape(probabilities, (-1,))

        self.bin_cls_loss = tf.keras.losses.binary_crossentropy(self.bin_cls_labels, self.dangling_bin_cls_output)
        self.bin_cls_optimizer = generate_optimizer(self.bin_cls_loss, 2 * self.args.learning_rate,
                                                    opt=self.args.optimizer)

    def _define_adversarial_graph(self):
        """Build a WGAN-style critic/generator pair between mapped KG1 and KG2 embeddings.

        "Fake" samples are ``init_embedding`` rows of ``disc_input_ent1`` multiplied by
        ``self.mapping_mat`` and L2-normalized. "Real" samples are ``init_embedding``
        rows of ``disc_input_ent2``; these are not normalized here.

        Attributes created:
          * placeholders ``disc_input_ent1``, ``disc_input_ent2`` and
            ``input_dangling_ent1``;
          * critic outputs ``disc_fake``, ``disc_real`` and ``dangling_scores``;
          * losses ``gen_loss`` and ``disc_loss``, plus ``max_dangling_loss`` when
            ``args.max_dangling_to_target`` is set;
          * RMSProp (lr 5e-5) ops ``gen_opt`` and ``disc_opt``, plus
            ``max_dangling_opt`` when ``args.max_dangling_to_target`` is set;
          * ``disc_clip``: assign ops clipping the critic weights to [-0.01, 0.01].

        NOTE(review): ``self.mapping_mat`` is not created in this method; it must
        already exist when this graph is built.
        """
        print('build graph for WGAN')

        def glorot_init(shape):
            # Normal init with stddev 1 / sqrt(fan_in / 2).
            return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))

        def discriminator(x, weights, biases):
            # Two-layer MLP critic: ReLU hidden layer, linear output (unbounded score,
            # the sigmoid is intentionally left commented out).
            hidden_layer = tf.matmul(x, weights['disc_hidden1'])
            hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
            hidden_layer = tf.nn.relu(hidden_layer)
            out_layer = tf.matmul(hidden_layer, weights['disc_out'])
            out_layer = tf.add(out_layer, biases['disc_out'])
            # out_layer = tf.nn.sigmoid(out_layer)
            return out_layer

        with tf.name_scope('discriminator'):
            self.disc_input_ent1 = tf.placeholder(tf.int32, shape=[None])
            self.disc_input_ent2 = tf.placeholder(tf.int32, shape=[None])
            self.input_dangling_ent1 = tf.placeholder(tf.int32, shape=[None])
            # One set of critic weights, shared by the fake, real and dangling inputs.
            weights1 = {
                'disc_hidden1': tf.Variable(glorot_init([self.args.dim, self.args.adver_n_hid])),
                'disc_out': tf.Variable(glorot_init([self.args.adver_n_hid, 1])),
            }
            biases1 = {
                'disc_hidden1': tf.Variable(tf.zeros([self.args.adver_n_hid])),
                'disc_out': tf.Variable(tf.zeros([1])),
            }
            input_embeds1 = tf.nn.embedding_lookup(self.init_embedding, self.disc_input_ent1)
            # NOTE(review): the real samples (embeds2) are not L2-normalized while the
            # fake ones are; confirm this asymmetry is intended.
            embeds2 = tf.nn.embedding_lookup(self.init_embedding, self.disc_input_ent2)
            embeds1 = tf.matmul(input_embeds1, self.mapping_mat)
            embeds_Mh = tf.nn.l2_normalize(embeds1, 1)

            # Dangling KG1 entities are mapped and normalized the same way as the fakes.
            input_dangling_ent1 = tf.nn.embedding_lookup(self.init_embedding, self.input_dangling_ent1)
            dangling_embeds1 = tf.matmul(input_dangling_ent1, self.mapping_mat)
            dangling_embeds1 = tf.nn.l2_normalize(dangling_embeds1, 1)

            self.disc_fake = discriminator(embeds_Mh, weights1, biases1)
            self.disc_real = discriminator(embeds2, weights1, biases1)
            self.dangling_scores = discriminator(dangling_embeds1, weights1, biases1)

            # Generator: raise the critic score of mapped KG1 embeddings.
            self.gen_loss = -tf.reduce_mean(self.disc_fake)
            if self.args.max_dangling_to_target:
                # Minimising this lowers the critic score of mapped dangling entities.
                # It is optimised by its own op below, not added to gen_loss.
                self.max_dangling_loss = self.args.lambda_loss * tf.reduce_mean(self.dangling_scores)
                # self.gen_loss += self.max_dangling_loss
            # Critic (Wasserstein estimate): score real high and fake low.
            self.disc_loss = - tf.reduce_mean(self.disc_real) + tf.reduce_mean(self.disc_fake)

            disc_vars = [weights1['disc_hidden1'], weights1['disc_out'],
                         biases1['disc_hidden1'], biases1['disc_out']]

            # Trainable variables created under the 'KGembeddings' scope. Whether
            # init_embedding is among them depends on where it was created; TODO confirm.
            tmp_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='KGembeddings')

            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                self.gen_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(self.gen_loss,
                                                                                      var_list=tmp_list + [
                                                                                          self.mapping_mat])
                self.disc_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(self.disc_loss,
                                                                                       var_list=disc_vars)
                if self.args.max_dangling_to_target:
                    self.max_dangling_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
                        self.max_dangling_loss, var_list=tmp_list + [self.mapping_mat])

            # WGAN weight clipping; the caller must run these ops after each critic step.
            self.disc_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in disc_vars]

    def _NCA_loss(self, embed1, embed2):
        """NCA-style alignment loss for a batch of aligned pairs.

        Row k of ``embed1`` and row k of ``embed2`` form a positive pair (the diagonal
        of the similarity matrix); every off-diagonal entry is treated as a negative.
        Reads the scalar placeholders ``self.alpha``, ``self.beta`` and ``self.ep``.

        Both inputs are presumably (batch, dim) with equal batch sizes, since a square
        similarity matrix is needed for the diagonal ops. TODO confirm at call sites.

        Let ``sim = embed1 @ embed2^T`` and ``S = exp(alpha * (sim - ep))`` with its
        diagonal zeroed. The returned scalar is the batch mean over k of::

            log(1 + sum_i S[i, k]) / alpha      (column term, ``pos_scores``)
          + log(1 + sum_j S[k, j]) / alpha      (row term, ``neg_scores``)
          - log(1 + beta * relu(sim[k, k]))     (``loss_diag``)

        Intermediate tensors are kept on ``self`` (``NCA_embed1``, ``S_``,
        ``pos_scores``, ``neg_scores``, ``in_pos_scores``) for inspection.
        """
        # embed1 = tf.matmul(embed1, self.mapping_mat)
        # embed1 = tf.nn.l2_normalize(embed1, 1)
        self.NCA_embed1 = embed1
        with tf.name_scope('NCA_loss'):
            sim_scores = tf.matmul(embed1, tf.transpose(embed2))
            # S_diags = tf.eye(bsize) * sim_scores
            # Diagonal matrix holding only the positive-pair similarities.
            sim_diags = tf.diag(tf.linalg.tensor_diag_part(sim_scores))
            S_ = tf.exp(self.alpha * (sim_scores - self.ep))
            S_ = S_ - tf.diag(tf.linalg.tensor_diag_part(S_)) # clear diagonal
            self.S_ = S_

            # loss_diag = -tf.log(1 + tf.nn.relu(tf.reduce_sum(sim_diags, 0)))
            # reduce_sum over axis 0 of a diagonal matrix recovers the diagonal vector.
            loss_diag = -tf.log(1 + self.beta * tf.nn.relu(tf.reduce_sum(sim_diags, 0)))
            # Column-wise (axis 0) and row-wise (axis 1) soft-max of negative similarities.
            pos_scores = tf.log(1 + tf.reduce_sum(S_, 0)) / self.alpha
            neg_scores = tf.log(1 + tf.reduce_sum(S_, 1)) / self.alpha
            self.pos_scores = pos_scores
            self.neg_scores = neg_scores
            self.in_pos_scores = 1 + tf.reduce_sum(S_, 0)
            # loss = tf.reduce_mean(pos_scores
            #                       + neg_scores
            #                       + loss_diag * self.beta)
            loss = tf.reduce_mean(pos_scores
                                    + neg_scores
                                    + loss_diag)
        return loss

    def _define_NCA_loss_graph(self):
        """Build placeholders, lookups, the NCA loss and its optimizer.

        Attributes created:
          * placeholders ``NCA_input_ents1``/``NCA_input_ents2`` for the entity ids of
            the two sides of each pair;
          * scalar placeholders ``alpha``, ``beta`` and ``ep``, consumed by ``_NCA_loss``;
          * ``NCA_embed2`` (the normalized KG2-side lookup);
          * ``NCA_loss``, and ``NCA_optimizer`` built by ``generate_optimizer`` with
            ``args.learning_rate`` / ``args.optimizer``.
        """
        print("build NCA loss graph...")
        with tf.name_scope('NCA_input_placeholder'):
            self.NCA_input_ents1 = tf.placeholder(tf.int32, shape=[None])
            self.NCA_input_ents2 = tf.placeholder(tf.int32, shape=[None])
            self.alpha = tf.placeholder(tf.float32, shape=[])
            self.beta = tf.placeholder(tf.float32, shape=[])
            self.ep = tf.placeholder(tf.float32, shape=[])

        with tf.name_scope('NCA_lookup'):
            src = tf.nn.embedding_lookup(self.init_embedding, self.NCA_input_ents1)
            tgt = tf.nn.embedding_lookup(self.init_embedding, self.NCA_input_ents2)
            src = tf.nn.l2_normalize(src, 1)
            tgt = tf.nn.l2_normalize(tgt, 1)
            self.NCA_embed2 = tgt
        with tf.name_scope('NCA_loss'):
            self.NCA_loss = self._NCA_loss(src, tgt)
            self.NCA_optimizer = generate_optimizer(self.NCA_loss, self.args.learning_rate,
                                                    opt=self.args.optimizer)

    def get_source_and_candidates(self, source_ents_and_labels, is_test):
        """Collect source embeddings/labels and the KG2 candidate pool.

        Embeddings come from the last GCN layer (``lookup_last_embeds``). Candidates
        are all KG2 entities minus the training targets; at test time the validation
        targets are excluded as well.

        Args:
            source_ents_and_labels: sequence of ``(entity_id, label)`` pairs.
            is_test: whether to also drop validation targets from the candidates.

        Returns:
            ``(source_embeds, source_ents, source_ent_y, target_embeds,
            target_candidates, None)``. The trailing ``None`` stands for "no mapping
            matrix".
        """
        all_embeds = self.lookup_last_embeds(None).eval(session=self.session)

        source_ents = [pair[0] for pair in source_ents_and_labels]
        source_ent_y = [pair[1] for pair in source_ents_and_labels]

        if not is_test:
            remaining = set(self.kgs.kg2.entities_list) - set(self.kgs.train_entities2)
        else:
            remaining = (set(self.kgs.kg2.entities_list) -
                         set(self.kgs.train_entities2) -
                         set(self.kgs.valid_entities2))
        target_candidates = list(remaining)

        return (all_embeds[np.array(source_ents),], source_ents, source_ent_y,
                all_embeds[np.array(target_candidates),], target_candidates, None)

    def evaluate_margin(self, source_ents_and_labels, margin, is_test=False):
        """Detect dangling source entities by their distance to the nearest KG2 candidate.

        The distance of each source entity is ``1 - sim`` to its single nearest
        candidate. ``eval_margin`` is called with the *mean* distance of this batch as
        the decision margin.

        NOTE(review): the ``margin`` argument is accepted but not used; the threshold
        is always the batch mean distance. Confirm whether ``margin`` was meant to be
        passed through.

        Returns:
            Whatever ``eval_margin`` returns.
        """
        print("dangling entity detection...")

        (source_embeds, source_ents, source_ent_y,
         target_embeds, target_candidates, mapping_mat) = self.get_source_and_candidates(source_ents_and_labels,
                                                                                          is_test)
        _, sims = search_kg1_to_kg2_1nn_neighbor(source_embeds, target_embeds, target_candidates, mapping_mat,
                                                 return_sim=True, soft_nn=1)
        distances = 1 - np.array(sims)
        batch_mean = np.mean(distances)
        print(batch_mean, distances)

        return eval_margin(source_ents, distances.tolist(), source_ent_y, margin=batch_mean)

    def eval_dangling_detection_cls(self, source_ents_and_labels, threshold=0.5, is_test=False):
        """Detect dangling source entities with the trained binary classifier.

        Steps:
          1. For every source entity, retrieve its ``args.s2t_topk`` nearest KG2
             candidates (forward pass).
          2. For each retrieved candidate, retrieve its ``args.t2s_topk`` nearest
             source entities (reverse pass).
          3. Feed the reverse similarity scores (plus the forward ones when
             ``args.concat_forward_reverse`` is set) to ``dangling_bin_cls_output``,
             in inference mode.

        Returns:
            ``eval_dangling_binary_cls(source_ents, y_predict, source_ent_y, threshold)``.
        """
        # TODO change ent_embeds
        print("using binary classification for determining whether it's dangling ")
        (source_embeds, source_ents, source_ent_y,
         target_embeds, target_candidates, mapping_mat) = self.get_source_and_candidates(source_ents_and_labels,
                                                                                          is_test)
        ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(source_embeds, target_embeds, target_candidates,
                                                           mapping_mat, return_all_sim=True,
                                                           soft_nn=self.args.s2t_topk)
        # ent12 holds global entity ids, so they index the full embedding table directly.
        all_embeds = self.lookup_last_embeds(None).eval(session=self.session)
        neighbour_embeds = all_embeds[np.array(ent12, dtype=np.int32),]  # (batch_size, s2t_topk, d_emb)
        reversed_ent21, t2s_sim = search_kg2_to_kg1_ordered_all_nns(neighbour_embeds, source_embeds, source_ents,
                                                                    mapping_mat, return_all_sim=True,
                                                                    soft_nn=self.args.t2s_topk)
        t2s_sim = np.array(t2s_sim)
        print(f'len(reversed_ent21): {len(reversed_ent21)}, t2s_sim.shape: {t2s_sim.shape}')

        feed = {self.topk_scores: t2s_sim, self.is_training: False}
        if self.args.concat_forward_reverse:
            s2t_sim = np.array(s2t_sim)
            print(f's2t_sim.shape: {s2t_sim.shape}')
            feed[self.s2t_scores] = s2t_sim
        y_predict = self.session.run(fetches=self.dangling_bin_cls_output, feed_dict=feed)

        print(f'type(y_predict): {type(y_predict)}, y_predict: {y_predict}')
        return eval_dangling_binary_cls(source_ents, y_predict, source_ent_y, threshold)

    def two_step_evaluation_margin(self, matchable_source_ents1, dangling_source_ents1, threshold=-1, is_test=False):
        """Two-step evaluation: detect dangling source entities, then align the rest.

        Step 1 runs the configured detector on matchable + dangling source entities:
          * the binary classifier when ``args.use_dangling_bin_cls`` is set;
          * otherwise the distance-margin detector.

        Step 2 scores the alignment of the entities kept as matchable via
        ``real_entity_alignment_evaluation``. At test time with
        ``args.test_batch_num > 1``, detection runs per batch; the confusion counts are
        summed and the totals are reported with ``get_results``.

        Returns:
            The value returned by ``real_entity_alignment_evaluation``.
        """
        print("evaluating two-step alignment (margin)...")

        def detect(candidates):
            # Both detectors return (label11_ents, label1_num, one_one, one_zero,
            # zero_zero, zero_one, total_examples, one_labels).
            if self.args.use_dangling_bin_cls:
                return self.eval_dangling_detection_cls(candidates, threshold=threshold, is_test=is_test)
            return self.evaluate_margin(candidates, self.args.distance_margin, is_test=is_test)

        if not (is_test and self.args.test_batch_num > 1):
            (final_label11_ents, final_label1_num,
             _, _, _, _, _, _) = detect(matchable_source_ents1 + dangling_source_ents1)
            return self.real_entity_alignment_evaluation(final_label11_ents, final_label1_num,
                                                         matchable_source_ents1)

        batch_num = self.args.test_batch_num
        print("test via batches...\n")
        matchable_batches = task_divide(matchable_source_ents1, batch_num)
        dangling_batches = task_divide(dangling_source_ents1, batch_num)
        final_label11_ents = []
        final_label1_num = 0
        # Running sums of: one_one, one_zero, zero_zero, zero_one, total_examples, one_labels.
        totals = [0, 0, 0, 0, 0, 0]
        for i in range(batch_num):
            (label11_ents, label1_num, one_one, one_zero,
             zero_zero, zero_one, total_examples, one_labels) = detect(matchable_batches[i] + dangling_batches[i])
            final_label11_ents += label11_ents
            final_label1_num += label1_num
            batch_counts = (one_one, one_zero, zero_zero, zero_one, total_examples, one_labels)
            totals = [acc + cnt for acc, cnt in zip(totals, batch_counts)]
            print()
        print("final test results:")
        get_results(*totals)
        return self.real_entity_alignment_evaluation(final_label11_ents, final_label1_num, matchable_source_ents1)

    def real_entity_alignment_evaluation(self, label11_ents, label1_num, matchable_source_ents1):
        """Score the alignment of the source entities the detector kept as matchable.

        Each entity in ``label11_ents`` is looked up in
        ``kgs.valid_entities1 + kgs.test_entities1`` to find its gold target. The gold
        targets are placed first in the candidate list, followed by the other KG2
        entities not used for training or validation. Hits are then computed with
        ``greedy_alignment`` on the concatenated embeddings (``lookup_embeds``).

        Args:
            label11_ents: source entity ids predicted matchable (and, presumably, truly
                matchable, since each must have a gold target).
            label1_num: number of source entities predicted matchable; this is the
                precision denominator.
            matchable_source_ents1: all truly matchable source entities; this is the
                recall denominator.

        Returns:
            The two-step F1 score. Returns 0.0 when nothing was predicted matchable or
            when precision and recall are both zero.

        Raises:
            ValueError: if an entity in ``label11_ents`` has no gold target.
        """
        if label11_ents is None or len(label11_ents) == 0:
            print("no predicated matchable entities")
            return 0.
        total_ent_embeds = self.lookup_embeds(None).eval(session=self.session)
        label11_source_embeds = total_ent_embeds[np.array(label11_ents),]

        matchable_ents1 = self.kgs.valid_entities1 + self.kgs.test_entities1
        matchable_ents2 = self.kgs.valid_entities2 + self.kgs.test_entities2
        # Build the gold source->target map once instead of a list.index() scan per
        # entity (O(n^2)). setdefault keeps the first occurrence, as list.index does.
        gold_target = {}
        for src, tgt in zip(matchable_ents1, matchable_ents2):
            gold_target.setdefault(src, tgt)
        true_targets = []
        for e in label11_ents:
            if e not in gold_target:
                raise ValueError('{} is not in list'.format(e))
            true_targets.append(gold_target[e])

        # Row i of the candidate list is the gold target of label11_ents[i]; the
        # remaining rows are distractors not used for training/validation.
        candidate_list = true_targets + list(self.kgs.kg2.entities_set
                                             - set(self.kgs.train_entities2)
                                             - set(self.kgs.valid_entities2)
                                             - set(true_targets))
        candidate_embeds = total_ent_embeds[np.array(candidate_list),]
        _, hits, _, _, _ = greedy_alignment(label11_source_embeds, candidate_embeds,
                                            self.args.top_k, self.args.test_threads_num,
                                            self.args.eval_metric, True, 10, False, False)
        # NOTE(review): hits[2] is treated as Hits@10, which assumes args.top_k is
        # something like [1, 5, 10].
        hits1 = hits[0]
        hits10 = hits[2]
        num_matchable = len(matchable_source_ents1)
        correct_at1 = hits1 * len(label11_ents)
        precision = correct_at1 / label1_num if label1_num else 0.
        recall = correct_at1 / num_matchable if num_matchable > 0 else 0.
        # hits1 == 0 makes precision + recall zero; avoid dividing by it.
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.
        recall10 = hits10 * len(label11_ents) / num_matchable if num_matchable > 0 else 0.
        print("two-step results, precision = {:.3f}, recall = {:.3f}, f1 = {:.3f}, recall@10 = {:.3f}\n".format(
            precision, recall, f1, recall10))
        return f1

    def _eval_valid_embeddings(self):
        ent1 = self.kgs.valid_entities1
        ent2 = self.kgs.valid_entities2 + list(self.kgs.kg2.entities_set
                                               - set(self.kgs.train_entities2)
                                               - set(self.kgs.valid_entities2))
        embeds1 = self.eval_embeddings(ent1)
        embeds2 = self.eval_embeddings(ent2)
        return embeds1, embeds2, None

    def _eval_test_embeddings(self):
        ent1 = self.kgs.test_entities1
        ent2 = self.kgs.test_entities2 + list(self.kgs.kg2.entities_set
                                              - set(self.kgs.train_entities2)
                                              - set(self.kgs.valid_entities2)
                                              - set(self.kgs.test_entities2))
        embeds1 = self.eval_embeddings(ent1)
        embeds2 = self.eval_embeddings(ent2)
        return embeds1, embeds2, None

    def save(self):
        """Persist entity embeddings via ``rd.save_embeddings``.

        Each layer (input table first, then every GCN output) is L2-normalized and
        evaluated. The results are concatenated column-wise and saved without relation
        embeddings or a mapping matrix.
        """
        layers = [self.init_embedding] + self.output_embeds_list
        normalized = [np.array(tf.nn.l2_normalize(layer, 1).eval(session=self.session)) for layer in layers]
        ent_embeds = np.concatenate(normalized, axis=1)
        rd.save_embeddings(self.out_folder, self.kgs, ent_embeds, None, None, mapping_mat=None)

    def generate_input_batch(self, batch_size, neighbors1=None, neighbors2=None):
        """Sample positive seed links and negative entity pairs for one training step.

        Args:
            batch_size: requested number of positive links, clamped to the number of
                seed links.
            neighbors1: optional mapping from a KG1 entity to KG2 candidates, used for
                hard negatives.
            neighbors2: optional mapping from a KG2 entity to KG1 candidates, used for
                hard negatives.

        Returns:
            ``(pos_links, neg_links)``:
              * ``pos_links`` are rows of ``self.sup_links``, sampled with replacement.
              * ``neg_links`` is an array of ``(kg1, kg2)`` pairs with every known seed
                link (``sup_links_set`` and ``new_sup_links_set``) removed.

        Without neighbour maps, ``args.neg_triple_num`` rounds of ``batch_size`` random
        pairs are drawn. With them, each positive link contributes up to
        ``neg_triple_num // 2`` neighbours per side, plus ``neg_triple_num`` rounds of
        ``batch_size // 2`` random pairs. Entities missing from a neighbour map, or with
        fewer neighbours than requested, contribute what they have.
        """
        batch_size = min(batch_size, len(self.sup_ent1))
        # np.random.choice samples with replacement by default.
        index = np.random.choice(len(self.sup_ent1), batch_size)
        pos_links = self.sup_links[index,]
        kg1_ents = self.kgs.kg1.entities_list
        kg2_ents = self.kgs.kg2.entities_list
        neg_links = list()

        def sample_neighbors(mapping, ent, k):
            # Tolerate a missing map, a missing entity, or too few neighbours.
            # list() keeps random.sample working when candidates are a set (3.11+).
            candidates = list((mapping or {}).get(ent) or ())
            return random.sample(candidates, min(k, len(candidates)))

        if neighbors1 is None:
            random_per_round = batch_size
        else:
            per_side = self.args.neg_triple_num // 2
            for i in range(batch_size):
                e1 = pos_links[i, 0]
                neg_links.extend((e1, candidate) for candidate in sample_neighbors(neighbors1, e1, per_side))
                e2 = pos_links[i, 1]
                neg_links.extend((candidate, e2) for candidate in sample_neighbors(neighbors2, e2, per_side))
            random_per_round = batch_size // 2

        neg_ent1 = list()
        neg_ent2 = list()
        for _ in range(self.args.neg_triple_num):
            neg_ent1.extend(random.sample(kg1_ents, random_per_round))
            neg_ent2.extend(random.sample(kg2_ents, random_per_round))
        neg_links.extend(zip(neg_ent1, neg_ent2))

        neg_links = set(neg_links) - self.sup_links_set
        neg_links = neg_links - self.new_sup_links_set
        neg_links = np.array(list(neg_links))
        return pos_links, neg_links

    def generate_rel_batch(self):
        """Sample ``self.rel_win_size`` (head, tail) pairs per relation, with replacement.

        Returns:
            ``(hs, rs, ts)``: parallel lists of heads, relation ids and tails. The pairs
            are grouped so each relation's window is contiguous.
        """
        heads, rels, tails = [], [], []
        for rel, pairs in self.rel_ht_dict.items():
            window = [random.choice(pairs) for _ in range(self.rel_win_size)]
            heads.extend(head for head, _ in window)
            tails.extend(tail for _, tail in window)
            rels.extend([rel] * len(window))
        return heads, rels, tails

    # def find_neighbors(self):
    #     if self.args.truncated_epsilon <= 0.0:
    #         return None, None
    #     start = time.time()
    #     output_embeds_list = self.output_embeds_list
    #     ents1 = self.sup_ent1 + self.ref_ent1
    #     ents2 = self.sup_ent2 + self.ref_ent2
    #     embeds1 = tf.nn.embedding_lookup(output_embeds_list[-1], ents1)
    #     embeds2 = tf.nn.embedding_lookup(output_embeds_list[-1], ents2)
    #     embeds1 = tf.nn.l2_normalize(embeds1, 1)
    #     embeds2 = tf.nn.l2_normalize(embeds2, 1)
    #     embeds1 = np.array(embeds1.eval(session=self.session))
    #     embeds2 = np.array(embeds2.eval(session=self.session))
    #     num = int((1 - self.args.truncated_epsilon) * len(ents1))
    #     print("neighbors num", num)
    #
    #     neighbors1 = generate_neighbours(embeds1, ents1, embeds2, ents2, num,
    #                                      threads_num=self.args.test_threads_num)
    #     neighbors2 = generate_neighbours(embeds2, ents2, embeds1, ents1, num,
    #                                      threads_num=self.args.test_threads_num)
    #     print('finding neighbors for sampling costs time: {:.4f}s'.format(time.time() - start))
    #     return neighbors1, neighbors2

    def lookup_3d_embeds(self, entities):
        """Look up ``entities`` in the jointly normalized embedding table.

        Each layer (input table first, then every GCN output) is L2-normalized and
        concatenated along the feature axis. The concatenation is L2-normalized again
        before the lookup.
        """
        normalized = [tf.nn.l2_normalize(layer, 1)
                      for layer in [self.init_embedding] + self.output_embeds_list]
        joint = tf.nn.l2_normalize(tf.concat(normalized, axis=1), 1)
        return tf.nn.embedding_lookup(joint, entities)

    def lookup_embeds(self, entities):
        """Concatenate the per-layer L2-normalized embeddings of ``entities``.

        Layers are ordered input table first, then each GCN output. With
        ``entities=None`` the whole table of every layer is used. Unlike
        ``lookup_3d_embeds``, the concatenation itself is not re-normalized.
        """
        tables = [self.init_embedding] + self.output_embeds_list
        if entities is not None:
            tables = [tf.nn.embedding_lookup(table, entities) for table in tables]
        return tf.concat([tf.nn.l2_normalize(table, 1) for table in tables], axis=1)

    def lookup_last_embeds(self, entities):
        """L2-normalized last-layer embeddings of ``entities`` (the whole table if ``None``)."""
        last_layer = self.output_embeds_list[-1]
        selected = last_layer if entities is None else tf.nn.embedding_lookup(last_layer, entities)
        return tf.nn.l2_normalize(selected, 1)

    def eval_embeddings(self, entity_list):
        """Evaluate the L2-normalized ``lookup_embeds`` of ``entity_list`` in ``self.session``.

        Note: each call adds new ops to the graph.
        """
        normalized = tf.nn.l2_normalize(self.lookup_embeds(entity_list), 1)
        return normalized.eval(session=self.session)

    def generate_neighbors(self):
        """Find truncated nearest-neighbour candidates of the seed entities (for hard negatives).

        Embeddings come from the last GCN layer, row-normalized with sklearn. Neighbours
        of each KG1 seed entity (``self.sup_ent1``) are searched among a random subset
        of KG2 entities, and vice versa for ``self.sup_ent2``. The subset is half of the
        KG's entities, or a third when the KG has more than 200,000 entities. The number
        of neighbours kept is ``(1 - args.truncated_epsilon)`` times the subset size.

        Returns:
            ``(neighbors1, neighbors2)`` as produced by ``generate_neighbours``.

        NOTE(review): the ``from nn_search import generate_neighbours`` line at the top
        of the file is commented out. Confirm ``generate_neighbours`` is defined
        elsewhere in this module, otherwise this method raises NameError.
        """
        t1 = time.time()
        # NOTE(review): assert-based validation is stripped under `python -O`.
        assert 0.0 < self.args.truncated_epsilon < 1.0
        output_embeds_list = self.output_embeds_list

        total_ent_embeds = output_embeds_list[-1].eval(session=self.session)
        total_ent_embeds = preprocessing.normalize(total_ent_embeds)
        ents1 = self.sup_ent1
        ents2 = self.sup_ent2
        embeds1 = total_ent_embeds[np.array(ents1),]
        embeds2 = total_ent_embeds[np.array(ents2),]

        # Random KG1 subset: half of the entities, a third for very large KGs.
        num1 = len(self.kgs.kg1.entities_list) // 2
        if len(self.kgs.kg1.entities_list) > 200000:
            num1 = len(self.kgs.kg1.entities_list) // 3
        kg1_random_ents = random.sample(self.kgs.kg1.entities_list, num1)
        random_embeds1 = total_ent_embeds[np.array(kg1_random_ents),]

        # Same subsampling rule for KG2.
        num2 = len(self.kgs.kg2.entities_list) // 2
        if len(self.kgs.kg2.entities_list) > 200000:
            num2 = len(self.kgs.kg2.entities_list) // 3
        kg2_random_ents = random.sample(self.kgs.kg2.entities_list, num2)
        random_embeds2 = total_ent_embeds[np.array(kg2_random_ents),]

        neighbors_num1 = int((1 - self.args.truncated_epsilon) * num1)
        neighbors_num2 = int((1 - self.args.truncated_epsilon) * num2)
        print("generating neighbors...", neighbors_num1, neighbors_num2)
        # KG1 seeds search among the KG2 subset, and KG2 seeds among the KG1 subset.
        neighbors1 = generate_neighbours(embeds1, ents1, random_embeds2, kg2_random_ents, neighbors_num2,
                                         threads_num=self.args.test_threads_num)
        neighbors2 = generate_neighbours(embeds2, ents2, random_embeds1, kg1_random_ents, neighbors_num1,
                                         threads_num=self.args.test_threads_num)
        print("generating neighbors ({}, {}) costs {:.3f} s.".format(num1, num2, time.time() - t1))
        return neighbors1, neighbors2

    def valid_alignment(self, stop_metric):
        """Evaluate alignment on the validation split and return the early-stopping metric.

        Args:
            stop_metric: ``'hits1'`` selects Hits@1; any other value selects the MRR.

        Returns:
            ``hits[0]`` when ``stop_metric == 'hits1'``, otherwise ``mrr_12``.
        """
        print("\nevaluating synthetic alignment...")
        embeds1, embeds2, mapping = self._eval_valid_embeddings()
        hits, mrr_12, _ = valid(embeds1, embeds2, mapping, self.args.top_k,
                                self.args.test_threads_num, metric=self.args.eval_metric,
                                normalize=self.args.eval_norm, csls_k=0, accurate=False)
        print()
        if stop_metric == 'hits1':
            return hits[0]
        return mrr_12

    def test(self, save=False):
        """Report synthetic-alignment test metrics, then run the two-step evaluation if enabled.

        The module-level ``test`` helper runs twice on the test embeddings: first with
        ``csls_k=0``, then with ``csls_k=args.csls``. When ``args.detection_mode`` is
        ``"margin"`` or ``"open"``, the dangling-aware two-step evaluation follows.
        ``save`` is accepted but not used here.
        """
        print("\ntesting synthetic alignment...")
        embeds1, embeds2, mapping = self._eval_test_embeddings()
        args = self.args
        shared = dict(metric=args.eval_metric, normalize=args.eval_norm, accurate=True)
        _, _, _, _ = test(embeds1, embeds2, mapping, args.top_k, args.test_threads_num, csls_k=0, **shared)
        test(embeds1, embeds2, mapping, args.top_k, args.test_threads_num, csls_k=args.csls, **shared)
        print()
        if args.detection_mode in ("margin", "open"):
            self.two_step_evaluation_margin(self.kgs.test_linked_entities1,
                                            self.kgs.test_unlinked_entities1, is_test=True)

    def launch_distance_margin_training_1epo(self, epoch, triple_steps):
        """Run one epoch of distance-margin training on the unlinked KG1 entities.

        Each step samples ``batch_size`` entries from ``self.kgs.train_unlinked_entities1``.
        It pairs every sampled KG1 entity with the KG2 entity returned by
        ``search_kg1_to_kg2_1nn_neighbor`` (presumably its nearest neighbour under the current
        embeddings), then runs ``self.dis_optimizer`` on those pairs. Prints the summed
        ``self.dis_loss`` of the epoch.

        Args:
            epoch: epoch number, used only for logging.
            triple_steps: nominal number of steps used to derive the batch size.
        """
        start = time.time()
        training_data = self.kgs.train_unlinked_entities1
        if len(training_data) == 0:
            print('epoch {}, no unlinked entities for margin training, skipped.'.format(epoch))
            return
        # NOTE(review): the 1.1 shrink factor is carried over unchanged; its intent is not visible here.
        # Clamp to >= 1: small training sets used to yield batch_size == 0 and a ZeroDivisionError below.
        batch_size = max(1, int(min(len(training_data) // triple_steps, len(training_data)) / 1.1))
        steps_num = max(1, len(training_data) // batch_size)

        embeds = self.lookup_last_embeds(None).eval(session=self.session)
        mapping_mat = None
        embeds2 = embeds[np.array(self.kgs.kg2.entities_list),]
        epoch_loss = 0
        for _ in range(steps_num):
            batch_data1 = random.sample(training_data, batch_size)
            ent1 = [x[0] for x in batch_data1]
            embeds1 = embeds[np.array(ent1),]
            ent12 = search_kg1_to_kg2_1nn_neighbor(embeds1, embeds2, self.kgs.kg2.entities_list, mapping_mat,
                                                   soft_nn=1)
            batch_loss, _ = self.session.run(fetches=[self.dis_loss, self.dis_optimizer],
                                             feed_dict={self.input_ents1: ent1,
                                                        self.input_ents2: ent12})
            epoch_loss += batch_loss
        print('epoch {}, margin loss: {:.8f}, cost time: {:.1f}s'.format(epoch, epoch_loss, time.time() - start))

    def launch_open_margin_training_1epo(self, epoch, triple_steps):
        """Run one epoch of open-margin training on the unlinked KG1 entities.

        Each step samples ``batch_size`` entries from ``self.kgs.train_unlinked_entities1``.
        For every sampled KG1 entity it draws ``self.args.num_random`` KG2 entities uniformly
        at random, then runs ``self.open_optimizer``. Prints the epoch's ``self.open_loss``
        summed over steps and divided by the number of sampled entities.

        Args:
            epoch: epoch number, used only for logging.
            triple_steps: nominal number of steps used to derive the batch size.
        """
        start = time.time()
        training_data = self.kgs.train_unlinked_entities1
        if len(training_data) == 0:
            print('epoch {}, no unlinked entities for open margin training, skipped.'.format(epoch))
            return
        # NOTE(review): the 1.1 shrink factor is carried over unchanged; its intent is not visible here.
        # Clamp to >= 1: small training sets used to yield batch_size == 0 and a ZeroDivisionError below.
        batch_size = max(1, int(min(len(training_data) // triple_steps, len(training_data)) / 1.1))
        steps_num = max(1, len(training_data) // batch_size)

        epoch_loss = 0
        trained_samples_num = 0
        kg2_entities = self.kgs.kg2.entities_list
        for _ in range(steps_num):
            batch_data1 = random.sample(training_data, batch_size)
            ent1 = [x[0] for x in batch_data1]
            ents = np.array([random.sample(kg2_entities, self.args.num_random) for _ in ent1])
            batch_loss, _ = self.session.run(fetches=[self.open_loss, self.open_optimizer],
                                             feed_dict={self.input_ents1: ent1,
                                                        self.input_ents2: ents})
            epoch_loss += batch_loss
            trained_samples_num += len(batch_data1)
        epoch_loss /= trained_samples_num
        print('epoch {}, open margin loss: {:.8f}, cost time: {:.1f}s'.format(epoch, epoch_loss,
                                                                              time.time() - start))

    def launch_dangling_detection_binary_cls_training_1epo(self, epoch, triple_steps):
        """Train the dangling-entity binary classifier for one epoch.

        Labels: matchable (``train_linked_entities1``) -> 0, dangling
        (``train_unlinked_entities1``) -> 1. The entities are shuffled and cut into
        consecutive batches of ``len(entities) // triple_steps``.

        For each batch, ``search_kg1_to_kg2_ordered_all_nns`` retrieves ``s2t_topk`` KG2
        candidates. The candidates are drawn from KG2 entities not in
        ``self.kgs.train_entities2``. ``search_kg2_to_kg1_ordered_all_nns`` then searches back
        from those candidates against the batch's KG1 embeddings (``t2s_topk``). The returned
        similarity scores are fed to the classifier. The forward scores are fed too when
        ``concat_forward_reverse`` is set.

        Entities beyond ``steps_num * batch_size`` are not visited this epoch. The reported
        loss and accuracy are averaged over the visited entities only.

        Args:
            epoch: epoch number, used only for logging.
            triple_steps: number of batches the training entities are split into.
        """
        start = time.time()
        unlinked_entities1 = self.kgs.train_unlinked_entities1
        linked_entities1 = self.kgs.train_linked_entities1
        all_entities = unlinked_entities1 + linked_entities1
        labels = np.hstack((np.ones(len(unlinked_entities1)), np.zeros(len(linked_entities1))))
        print(f'# linked_ent1: {len(linked_entities1)}, unlinked_ent1: {len(unlinked_entities1)}')
        if len(all_entities) == 0:
            print('epoch {}, no entities for dangling detection training, skipped.'.format(epoch))
            return

        # Clamp to >= 1: fewer entities than steps used to give batch_size == 0 and a
        # ZeroDivisionError in steps_num.
        batch_size = max(1, len(all_entities) // triple_steps)
        steps_num = max(1, len(all_entities) // batch_size)
        embeds = self.lookup_last_embeds(None).eval(session=self.session)
        mapping_mat = None
        ents2_candidates = list(self.kgs.kg2.entities_set - set(self.kgs.train_entities2))
        embeds2 = embeds[ents2_candidates,]

        # Shuffle once, then walk consecutive slices (no per-class balancing).
        perm = np.random.permutation(len(all_entities))
        all_entities = np.array(all_entities, dtype=np.int32)[perm]
        labels = labels[perm]

        epoch_loss = 0
        epoch_correct = 0
        trained_samples_num = 0
        for i in range(steps_num):
            lo, hi = i * batch_size, (i + 1) * batch_size
            batch_ent1 = all_entities[lo:hi]
            batch_labels = labels[lo:hi]
            ent1 = [x[0] for x in batch_ent1]
            ent_embeds1 = embeds[np.array(ent1),]
            ent12, s2t_sim = search_kg1_to_kg2_ordered_all_nns(ent_embeds1, embeds2, ents2_candidates, mapping_mat,
                                                               return_all_sim=True,
                                                               soft_nn=self.args.s2t_topk)
            retrieved_target_embeds = embeds[np.array(ent12, dtype=np.int32),]
            _, t2s_sim = search_kg2_to_kg1_ordered_all_nns(retrieved_target_embeds, ent_embeds1, ent1,
                                                          mapping_mat, return_all_sim=True,
                                                          soft_nn=self.args.t2s_topk)
            feed_dict = {self.topk_scores: np.array(t2s_sim),
                         self.bin_cls_labels: batch_labels,
                         self.is_training: True}
            if self.args.concat_forward_reverse:
                feed_dict[self.s2t_scores] = np.array(s2t_sim)
            batch_loss, dangling_bin_cls_output, _ = self.session.run(
                fetches=[self.bin_cls_loss, self.dangling_bin_cls_output, self.bin_cls_optimizer],
                feed_dict=feed_dict)
            # NOTE(review): assumes the classifier output has the same (1-D) shape as batch_labels;
            # a (batch, 1) output would broadcast this comparison — confirm.
            predictions = (dangling_bin_cls_output > 0.5).astype(np.int32)
            epoch_correct += (predictions == batch_labels).astype(np.int32).sum()
            epoch_loss += batch_loss
            trained_samples_num += len(batch_ent1)
        epoch_loss /= trained_samples_num
        # Accuracy is over the entities actually trained on, not len(labels), which may include
        # a never-visited remainder.
        print('epoch {}, dangling detection binary cls loss: {:.8f}, detection acc: {:.4f}, cost time: {:.1f}s'.format(
            epoch, epoch_loss,
            epoch_correct / trained_samples_num, time.time() - start))

    def launch_NCA_training_1epo(self, epoch, triple_steps):
        """Run one epoch (``triple_steps`` steps) of NCA-loss training on the seed links.

        Each step samples ``batch_size`` pairs from ``self.kgs.train_links`` and runs
        ``self.NCA_optimizer`` with ``alpha``/``beta`` taken from ``self.args`` and ``ep`` = 0.0.
        Prints the epoch's ``self.NCA_loss`` summed over steps and divided by the number of
        sampled links.

        Args:
            epoch: epoch number, used only for logging.
            triple_steps: number of optimisation steps in this epoch.
        """
        start = time.time()
        train_links = self.kgs.train_links
        if len(train_links) == 0:
            print('epoch {}, no training links for NCA, skipped.'.format(epoch))
            return
        # Clamp to >= 1: fewer links than steps used to give empty batches and a ZeroDivisionError
        # below. len // steps never exceeds len, so no upper clamp is needed.
        batch_size = max(1, len(train_links) // triple_steps)

        epoch_loss = 0
        trained_samples_num = 0
        for _ in range(triple_steps):
            links_batch = random.sample(train_links, batch_size)
            ent1 = [x[0] for x in links_batch]
            ent2 = [x[1] for x in links_batch]
            # Only the loss and the train op are needed; the earlier debug tensors were fetched but unused.
            batch_loss, _ = self.session.run(fetches=[self.NCA_loss, self.NCA_optimizer],
                                             feed_dict={self.NCA_input_ents1: ent1,
                                                        self.NCA_input_ents2: ent2,
                                                        self.alpha: self.args.NCA_alpha,
                                                        self.beta: self.args.NCA_beta,
                                                        self.ep: 0.0})
            epoch_loss += batch_loss
            trained_samples_num += len(links_batch)
        epoch_loss /= trained_samples_num
        # `{:.1f}` (fixed-point) — the former `{:.1}` printed one significant digit, e.g. "1e+01s".
        print('epoch {}, avg. NCA loss: {:.8f}, cost time: {:.1f}s'.format(epoch, epoch_loss, time.time() - start))

    def run(self):
        """Main training loop.

        For each epoch in ``1..self.args.max_epoch``:
          1. ``steps`` batches of the main alignment loss (``self.loss``), optionally with
             relation triples when ``self.args.rel_param > 0``.
          2. Optional NCA training (``use_NCA_loss``).
          3. Dangling-detection training for ``detection_mode`` ``margin`` or ``open``.
          4. Optional binary dangling classifier (``use_dangling_bin_cls``).
          5. Every ``eval_freq`` epochs (from ``start_valid`` on): validation, early-stopping
             check, and regeneration of the negative-sampling neighbours.
        """
        print('start training...')
        # max(1, ...) already guarantees at least one step per epoch.
        steps = max(1, len(self.sup_ent2) // self.args.batch_size)
        neighbors1, neighbors2 = None, None
        for epoch in range(1, self.args.max_epoch + 1):
            start = time.time()
            epoch_loss = 0.0
            # NOTE(review): in graph mode tf.device only places ops created inside the scope; no ops
            # appear to be created here, so this is likely a no-op — kept unchanged.
            with tf.device('/cpu:0'):
                for _ in range(steps):
                    self.pos_link_batch, self.neg_link_batch = self.generate_input_batch(
                        self.args.batch_size, neighbors1=neighbors1, neighbors2=neighbors2)
                    feed_dict = {self.pos_links: self.pos_link_batch,
                                 self.neg_links: self.neg_link_batch}
                    if self.args.rel_param > 0.0:
                        hs, _, ts = self.generate_rel_batch()
                        feed_dict[self.hs] = hs
                        feed_dict[self.ts] = ts
                    results = self.session.run(fetches={"loss": self.loss, "optimizer": self.optimizer},
                                               feed_dict=feed_dict)
                    epoch_loss += results["loss"]

                print('epoch {}, loss: {:.8f}, cost time: {:.4f}s'.format(epoch, epoch_loss, time.time() - start))

            if self.args.use_NCA_loss:
                self.launch_NCA_training_1epo(epoch, steps)

            if self.args.detection_mode == "margin":
                self.launch_distance_margin_training_1epo(epoch, steps)
            elif self.args.detection_mode == "open":
                self.launch_open_margin_training_1epo(epoch, steps)

            if self.args.use_dangling_bin_cls:
                self.launch_dangling_detection_binary_cls_training_1epo(epoch, steps)

            if epoch % self.args.eval_freq == 0 and epoch >= self.args.start_valid:
                # validation via synthetic alignment
                flag = self.valid_alignment(self.args.stop_metric)
                # NOTE(review): outside margin/open modes this `flag` is not fed to early_stop, so
                # self.early_stop keeps whatever value it had — confirm this is intended.
                # validation via two-step alignment
                if self.args.detection_mode in ("margin", "open"):
                    flag = self.two_step_evaluation_margin(self.kgs.valid_linked_entities1,
                                                           self.kgs.valid_unlinked_entities1)
                    self.flag1, self.flag2, self.early_stop = early_stop(self.flag1, self.flag2, flag)
                if self.early_stop:
                    print("\n == training stop == \n")
                    break
                if epoch < self.args.max_epoch:
                    neighbors1, neighbors2 = self.generate_neighbors()

    def save_model(self):
        """Write a checkpoint of the current session with a default-argument ``tf.train.Saver``.

        The file name is ``s{s2t_topk}_t{t2s_topk}_model.ckpt``, appended directly (string
        concatenation) to ``self.out_folder``.
        """
        print('saving the model...')
        checkpoint_saver = tf.train.Saver()
        ckpt_name = f's{self.args.s2t_topk}_t{self.args.t2s_topk}_model.ckpt'
        written_path = checkpoint_saver.save(self.session, self.out_folder + ckpt_name)
        print(f'Model saved in {written_path}')

        