import tensorflow as tf
import sys
import math
import numpy as np


class PrivNet(object):
    """Privacy-aware cross-domain recommender built as a TensorFlow 1.x graph.

    The constructor builds every op the training loop needs, sharing the
    variables created in ``__init__``:

    * a source-domain neural-CF recommender (``loss_src`` / ``train_op_src``);
    * an attacker network (shared sigmoid trunk + gender / age / occupation
      softmax heads) on top of the source user representation; its own
      update ``train_op_attack_real`` touches attacker weights only;
    * an adversarial source update ``train_op_attack_fake_src`` on
      ``loss_src + adversary_weight * loss_attack_fake`` that touches source
      network weights only;
    * a target-domain neural-CF recommender (``loss_tgt`` / ``train_op_tgt``);
    * two translators (``train_op_tgt_trans`` / ``train_op_src_trans``), each
      updating its own domain's network plus its transfer matrices;
    * a cold-start scoring path (``score_j_test``) that maps a user's SOURCE
      representation into the target domain via ``transfer_matrix_s2t``.

    Each tower is built a single time and reused by every loss that needs it,
    so e.g. ``loss_gender``, ``loss_gender_fake`` and ``loss_gender_real`` are
    the same tensor.
    """

    def __init__(self, user_count, item_count_tgt, item_count_src, config):
        """Create placeholders, variables and all training / inference ops.

        Args:
            user_count: number of users. Not used by the graph: users are
                represented only through their item histories.
            item_count_tgt: size of the target-domain item vocabulary.
            item_count_src: size of the source-domain item vocabulary.
            config: dict read for 'hidden_units', 'clip_norm',
                'clip_norm_attack', 'fc_layer', 'num_classes_gender',
                'num_classes_age', 'num_classes_occupation',
                'adversary_weight' and 'lr'.
        """
        # Hyperparameters
        self.config = config
        self.hidden_units = config['hidden_units']
        self.clip_norm = config['clip_norm']
        self.clip_norm_attack = config['clip_norm_attack']
        self.fc_layer = config['fc_layer']
        self.num_classes_gender = config['num_classes_gender']
        self.num_classes_age = config['num_classes_age']
        self.num_classes_occupation = config['num_classes_occupation']
        self.adversary_weight = config['adversary_weight']
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.global_epoch = tf.Variable(0, trainable=False, name='global_epoch')
        # NOTE: the creation order of placeholders and variables below is kept
        # as-is; with a graph-level seed, TF1 derives each initializer's op
        # seed from the number of ops created so far.
        # Target domain
        self.u_tgt = tf.placeholder(tf.int32, [None,], name="input_u_tgt")  # B
        self.hist_i_tgt = tf.placeholder(tf.int32, [None, None], name="input_hist_i_tgt")  # T
        self.i_tgt = tf.placeholder(tf.int32, [None,], name="input_i_tgt")
        self.y_tgt = tf.placeholder(tf.float32, [None,], name="input_y_tgt")
        self.sl_tgt = tf.placeholder(tf.int32, [None,], name="input_sl_tgt")  # len(history_i)
        # testing & evaluation
        self.u_test = tf.placeholder(tf.int32, [None,], name="input_u_test")
        self.hist_i_test = tf.placeholder(tf.int32, [None, None], name="input_hist_i_test")
        self.i_test = tf.placeholder(tf.int32, [None,], name="input_i_test")
        self.sl_test = tf.placeholder(tf.int32, [None,], name="input_sl_test")
        self.j_test = tf.placeholder(tf.int32, [None,], name="input_j_test")  # used when testing, only target domain
        self.j_parent = tf.placeholder(tf.int32, [None,], name="input_j_parent")  # parent of items
        # Source domain
        self.u_src = tf.placeholder(tf.int32, [None,], name="input_u_src")
        self.hist_i_src = tf.placeholder(tf.int32, [None, None], name="input_hist_i_src")
        self.i_src = tf.placeholder(tf.int32, [None,], name="input_i_src")
        self.y_src = tf.placeholder(tf.float32, [None,], name="input_y_src")
        self.sl_src = tf.placeholder(tf.int32, [None,], name="input_sl_src")
        # Translator
        self.u_trans = tf.placeholder(tf.int32, [None,], name="input_u_trans")
        # src rep
        self.hist_i_src_trans = tf.placeholder(tf.int32, [None, None], name="input_hist_i_src_trans")
        self.i_src_trans = tf.placeholder(tf.int32, [None,], name="input_i_src_trans")
        self.sl_src_trans = tf.placeholder(tf.int32, [None,], name="input_sl_src_trans")
        self.y_src_trans = tf.placeholder(tf.float32, [None,], name="input_y_src_trans")
        # tgt rep
        self.hist_i_tgt_trans = tf.placeholder(tf.int32, [None, None], name="input_hist_i_tgt_trans")
        self.i_tgt_trans = tf.placeholder(tf.int32, [None,], name="input_i_tgt_trans")
        self.sl_tgt_trans = tf.placeholder(tf.int32, [None,], name="input_sl_tgt_trans")
        self.y_tgt_trans = tf.placeholder(tf.float32, [None,], name="input_y_tgt_trans")
        # Attacker models: gender & age & occupation
        self.y_gender = tf.placeholder(tf.float32, [None, self.num_classes_gender], name="input_y_gender")
        self.y_age = tf.placeholder(tf.float32, [None, self.num_classes_age], name="input_y_age")
        self.y_occupation = tf.placeholder(tf.float32, [None, self.num_classes_occupation], name="input_y_occupation")
        # --- Parameters ---
        self.item_emb_weight_tgt = tf.get_variable('item_emb_weight_tgt', [item_count_tgt, self.hidden_units])
        self.item_emb_weight_src = tf.get_variable('item_emb_weight_src', [item_count_src, self.hidden_units])
        # Neural CF: src
        self.f1_src_weight = tf.get_variable('f1_src_weight', [int(2*self.hidden_units), self.fc_layer])
        self.f1_src_bias = tf.get_variable('f1_src_bias', [self.fc_layer])
        self.f3_src_weight = tf.get_variable('f3_src_weight', [self.fc_layer, 1])
        self.f3_src_bias = tf.get_variable('f3_src_bias', [1])
        self.item_bias_src = tf.get_variable('item_bias_src', [item_count_src], initializer=tf.constant_initializer(0.0))
        # Neural CF: tgt
        self.f1_tgt_weight = tf.get_variable('f1_tgt_weight', [int(2*self.hidden_units), self.fc_layer])
        self.f1_tgt_bias = tf.get_variable('f1_tgt_bias', [self.fc_layer])
        self.f3_tgt_weight = tf.get_variable('f3_tgt_weight', [self.fc_layer, 1])
        self.f3_tgt_bias = tf.get_variable('f3_tgt_bias', [1])
        self.item_bias_tgt = tf.get_variable('item_bias_tgt', [item_count_tgt], initializer=tf.constant_initializer(0.0))
        # Transfer model
        self.transfer_matrix_s2t = tf.get_variable('transfer_matrix_s2t', [self.hidden_units, self.hidden_units])
        self.transfer_matrix_t2s = tf.get_variable('transfer_matrix_t2s', [self.hidden_units, self.hidden_units])
        self.transfer_matrix_t_cf = tf.get_variable('transfer_matrix_t_cf', [int(3*self.hidden_units), int(2*self.hidden_units)])
        self.transfer_matrix_s_cf = tf.get_variable('transfer_matrix_s_cf', [int(3*self.hidden_units), int(2*self.hidden_units)])
        # Privacy Attacker Models: gender & age & occupation
        # Multitask learning: sharing the lower layers, only the softmax layers are different
        self.f1_attack_weight = tf.get_variable('f1_attack_weight', [self.hidden_units, self.fc_layer])
        self.f1_attack_bias = tf.get_variable('f1_attack_bias', [self.fc_layer])
        # softmax output layers: specific to their own number of classes
        self.f3_gender_weight = tf.get_variable('f3_gender_weight', [self.fc_layer, self.num_classes_gender])
        self.f3_gender_bias = tf.get_variable('f3_gender_bias', [self.num_classes_gender])
        self.f3_age_weight = tf.get_variable('f3_age_weight', [self.fc_layer, self.num_classes_age])
        self.f3_age_bias = tf.get_variable('f3_age_bias', [self.num_classes_age])
        self.f3_occupation_weight = tf.get_variable('f3_occupation_weight', [self.fc_layer, self.num_classes_occupation])
        self.f3_occupation_bias = tf.get_variable('f3_occupation_bias', [self.num_classes_occupation])

        # Parameter groups: each optimizer updates exactly one group (plus, for
        # the translators, that domain's transfer matrices), so training one
        # network never modifies another.
        src_params = [self.item_emb_weight_src,
                      self.f1_src_weight, self.f1_src_bias,
                      self.f3_src_weight, self.f3_src_bias, self.item_bias_src]
        tgt_params = [self.item_emb_weight_tgt,
                      self.f1_tgt_weight, self.f1_tgt_bias,
                      self.f3_tgt_weight, self.f3_tgt_bias, self.item_bias_tgt]
        attack_params = [self.f1_attack_weight, self.f1_attack_bias,
                         self.f3_gender_weight, self.f3_gender_bias,  # gender softmax output
                         self.f3_age_weight, self.f3_age_bias,  # age softmax output
                         self.f3_occupation_weight, self.f3_occupation_bias]  # occupation softmax output

        # NOTE: optimizers below are created and applied in a fixed order
        # (src, attack_fake_src, attack_real, tgt, tgt_trans, src_trans);
        # Adam slot-variable names depend on this order, so keep it stable for
        # checkpoint compatibility.

        # ---Source network (built once, shared by the attacker towers)---
        u_rep_src, i_rep_src = self._user_rep(self.item_emb_weight_src, self.hist_i_src, self.i_src, self.sl_src)
        din_src = tf.concat([u_rep_src, i_rep_src], axis=1)
        self.logits_src = self._ncf_logits(din_src, self.f1_src_weight, self.f1_src_bias,
                                           self.f3_src_weight, self.f3_src_bias,
                                           self.item_bias_src, self.i_src)
        self.score_i_src = tf.reshape(tf.sigmoid(self.logits_src), [-1, 1])
        # optimization: update source base network only
        self.loss_src = self._sigmoid_xent(self.logits_src, self.y_src)
        self.opt_src, self.train_op_src = self._clipped_adam(self.loss_src, src_params, self.clip_norm)

        # ---Attackers: multitask heads on the source user representation---
        self.logits_gender, self.logits_age, self.logits_occupation = self._attack_logits(u_rep_src)
        self.score_gender = tf.nn.softmax(self.logits_gender)
        self.score_age = tf.nn.softmax(self.logits_age)
        self.score_occupation = tf.nn.softmax(self.logits_occupation)
        self.loss_gender = self._softmax_xent(self.logits_gender, self.y_gender)
        self.loss_age = self._softmax_xent(self.logits_age, self.y_age)
        self.loss_occupation = self._softmax_xent(self.logits_occupation, self.y_occupation)
        # The simulated ("fake") and real attacker objectives are evaluated on
        # the same tower; these names are kept as aliases for callers.
        self.loss_gender_fake = self.loss_gender_real = self.loss_gender
        self.loss_age_fake = self.loss_age_real = self.loss_age
        self.loss_occupation_fake = self.loss_occupation_real = self.loss_occupation
        self.loss_attack_fake = tf.reduce_mean([self.loss_gender, self.loss_age, self.loss_occupation])
        self.loss_attack_real = self.loss_attack_fake

        # ---Adversarial source update: src + lambda * privacy---
        # Only source-network parameters are updated; attacker weights stay fixed.
        # NOTE(review): gradient descent on "+ adversary_weight * loss_attack_fake"
        # makes the attributes fed in y_* EASIER to predict. This protects privacy
        # only if the caller feeds adversarial targets or a negative weight —
        # confirm against the training loop.
        self.loss_attack_src = self.loss_src
        self.loss_attack_fake_src = self.loss_attack_src + self.adversary_weight * self.loss_attack_fake
        self.opt_attack_fake_src, self.train_op_attack_fake_src = self._clipped_adam(
                self.loss_attack_fake_src, src_params, self.clip_norm)

        # ---Attacker update: attacker parameters only---
        # NOTE(review): config['clip_norm_attack'] is read into
        # self.clip_norm_attack but never used; the attacker is clipped with
        # clip_norm. Confirm which was intended.
        self.opt_attack_real, self.train_op_attack_real = self._clipped_adam(
                self.loss_attack_real, attack_params, self.clip_norm)

        # ---Target model: train---
        u_rep_tgt, i_rep_tgt = self._user_rep(self.item_emb_weight_tgt, self.hist_i_tgt, self.i_tgt, self.sl_tgt)
        din_tgt = tf.concat([u_rep_tgt, i_rep_tgt], axis=1)
        self.logits_tgt = self._ncf_logits(din_tgt, self.f1_tgt_weight, self.f1_tgt_bias,
                                           self.f3_tgt_weight, self.f3_tgt_bias,
                                           self.item_bias_tgt, self.i_tgt)
        self.score_i_tgt = tf.reshape(tf.sigmoid(self.logits_tgt), [-1, 1])
        # optimization: update target base network only
        self.loss_tgt = self._sigmoid_xent(self.logits_tgt, self.y_tgt)
        self.opt_tgt, self.train_op_tgt = self._clipped_adam(self.loss_tgt, tgt_params, self.clip_norm)

        # ---Target model: test. Cold-start users, so their SOURCE representation is used---
        u_rep_test, _ = self._user_rep(self.item_emb_weight_src, self.hist_i_test, self.i_test, self.sl_test)
        j_rep_test = tf.nn.embedding_lookup(self.item_emb_weight_tgt, self.j_test)  # target items to score
        # Translator: source -> target
        rep_src2tgt_test = tf.matmul(u_rep_test, self.transfer_matrix_s2t)
        din_j_test = tf.concat([rep_src2tgt_test, j_rep_test], axis=-1)
        self.score_j_test = tf.sigmoid(self._ncf_logits(din_j_test, self.f1_tgt_weight, self.f1_tgt_bias,
                                                        self.f3_tgt_weight, self.f3_tgt_bias,
                                                        self.item_bias_tgt, self.j_test))

        # ---Translator: NOT jointly trained; source -> target and target -> source---
        u_rep_src_trans, i_rep_src_trans = self._user_rep(
                self.item_emb_weight_src, self.hist_i_src_trans, self.i_src_trans, self.sl_src_trans)
        u_rep_tgt_trans, i_rep_tgt_trans = self._user_rep(
                self.item_emb_weight_tgt, self.hist_i_tgt_trans, self.i_tgt_trans, self.sl_tgt_trans)
        # 1. Train the target network with knowledge translated from the source.
        rep_src2tgt_trans = tf.matmul(u_rep_src_trans, self.transfer_matrix_s2t)
        din_tgt_trans = tf.concat([u_rep_tgt_trans, i_rep_tgt_trans, rep_src2tgt_trans], axis=1)
        din_tgt_trans = tf.matmul(din_tgt_trans, self.transfer_matrix_t_cf)  # 3H -> 2H to fit f1_tgt_weight
        self.logits_tgt_trans = self._ncf_logits(din_tgt_trans, self.f1_tgt_weight, self.f1_tgt_bias,
                                                 self.f3_tgt_weight, self.f3_tgt_bias,
                                                 self.item_bias_tgt, self.i_tgt_trans)
        self.score_i_tgt_trans = tf.reshape(tf.sigmoid(self.logits_tgt_trans), [-1, 1])
        # 2. Train the source network with knowledge translated from the target.
        rep_tgt2src_trans = tf.matmul(u_rep_tgt_trans, self.transfer_matrix_t2s)
        din_src_trans = tf.concat([u_rep_src_trans, i_rep_src_trans, rep_tgt2src_trans], axis=1)
        din_src_trans = tf.matmul(din_src_trans, self.transfer_matrix_s_cf)  # 3H -> 2H to fit f1_src_weight
        self.logits_src_trans = self._ncf_logits(din_src_trans, self.f1_src_weight, self.f1_src_bias,
                                                 self.f3_src_weight, self.f3_src_bias,
                                                 self.item_bias_src, self.i_src_trans)
        self.score_i_src_trans = tf.reshape(tf.sigmoid(self.logits_src_trans), [-1, 1])
        # Optimization: each side updates only its own network and transfer matrices.
        self.loss_tgt_trans = self._sigmoid_xent(self.logits_tgt_trans, self.y_tgt_trans)
        self.opt_tgt_trans, self.train_op_tgt_trans = self._clipped_adam(
                self.loss_tgt_trans,
                tgt_params + [self.transfer_matrix_s2t, self.transfer_matrix_t_cf],
                self.clip_norm)
        self.loss_src_trans = self._sigmoid_xent(self.logits_src_trans, self.y_src_trans)
        self.opt_src_trans, self.train_op_src_trans = self._clipped_adam(
                self.loss_src_trans,
                src_params + [self.transfer_matrix_t2s, self.transfer_matrix_s_cf],
                self.clip_norm)

        # Single epoch counter op (creates no variables).
        self.global_epoch_increment = tf.assign(self.global_epoch, self.global_epoch+1)

    # ---graph-building helpers---
    def _user_rep(self, item_emb_weight, hist_i, items, seq_len):
        """Encode users by attending over their item history, with the candidate item as query.

        Returns (u_rep, i_rep): the pooled history reshaped to
        [-1, hidden_units], and the candidate items' embeddings.
        """
        h_emb = tf.nn.embedding_lookup(item_emb_weight, hist_i)
        i_rep = tf.nn.embedding_lookup(item_emb_weight, items)
        u_rep = attention_dot(i_rep, h_emb, seq_len)  # query, keys, len
        u_rep = tf.reshape(u_rep, [-1, self.hidden_units])
        return u_rep, i_rep

    @staticmethod
    def _ncf_logits(din, f1_weight, f1_bias, f3_weight, f3_bias, item_bias, items):
        """Neural-CF head: sigmoid hidden layer -> scalar output, plus a per-item bias.

        Returns a rank-1 tensor of logits, one per row of `din`.
        """
        hidden = tf.nn.sigmoid(tf.add(tf.matmul(din, f1_weight), f1_bias))
        out = tf.reshape(tf.add(tf.matmul(hidden, f3_weight), f3_bias), [-1])
        return tf.gather(item_bias, items) + out

    def _attack_logits(self, u_rep):
        """Apply the shared attacker trunk, then the gender / age / occupation heads.

        Returns a tuple of logits reshaped to [-1, num_classes_gender],
        [-1, num_classes_age] and [-1, num_classes_occupation].
        """
        hidden = tf.nn.sigmoid(tf.add(tf.matmul(u_rep, self.f1_attack_weight), self.f1_attack_bias))
        heads = ((self.f3_gender_weight, self.f3_gender_bias, self.num_classes_gender),
                 (self.f3_age_weight, self.f3_age_bias, self.num_classes_age),
                 (self.f3_occupation_weight, self.f3_occupation_bias, self.num_classes_occupation))
        return tuple(tf.reshape(tf.add(tf.matmul(hidden, w), b), [-1, n]) for w, b, n in heads)

    def _clipped_adam(self, loss, params, clip_norm):
        """Build an Adam step on `params` only, with gradients clipped by global norm.

        Returns (optimizer, train_op); running train_op also increments global_step.
        """
        opt = tf.train.AdamOptimizer(learning_rate=self.config['lr'])
        grads = tf.gradients(loss, params)
        clipped, _ = tf.clip_by_global_norm(grads, clip_norm)
        train_op = opt.apply_gradients(zip(clipped, params), global_step=self.global_step)
        return opt, train_op

    @staticmethod
    def _sigmoid_xent(logits, labels):
        """Mean binary cross-entropy of `labels` given `logits`."""
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    @staticmethod
    def _softmax_xent(logits, labels):
        """Mean softmax cross-entropy of per-row label distributions given `logits`."""
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))

    # ---session runners---
    # self.iter, (users, hist_i,items,ys,sl)
    def train_tgt(self, sess, uij):
        """Run one target-network step; uij = (users, hist_i, items, ys, sl). Returns loss_tgt."""
        loss_tgt, _ = sess.run([self.loss_tgt, self.train_op_tgt], feed_dict={
            self.u_tgt: uij[0], self.hist_i_tgt: uij[1], self.i_tgt: uij[2], self.y_tgt: uij[3], self.sl_tgt: uij[4],
        })
        return loss_tgt

    # self.iter, (users, hist_i,items,ys,sl)
    def train_src(self, sess, uij):
        """Run one source-network step; uij = (users, hist_i, items, ys, sl). Returns loss_src."""
        loss_src, _ = sess.run([self.loss_src, self.train_op_src], feed_dict={
            self.u_src: uij[0], self.hist_i_src: uij[1], self.i_src: uij[2], self.y_src: uij[3], self.sl_src: uij[4],
        })
        return loss_src

    # self.iter, (users, hist_i,items,ys,sl, y_gender,y_age,y_occupation)
    def train_attack_fake_source(self, sess, uij):  # adversarial source network
        """Run one adversarial source step (source params only).

        uij = (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation).
        Returns (loss_attack_fake_src, loss_attack_src).
        """
        loss_attack_fake_src, loss_attack_src, _ = sess.run(
                [self.loss_attack_fake_src, self.loss_attack_src, self.train_op_attack_fake_src],
                feed_dict={self.u_src: uij[0], self.hist_i_src: uij[1], self.i_src: uij[2], self.y_src: uij[3],
                           self.sl_src: uij[4], self.y_gender: uij[5], self.y_age: uij[6], self.y_occupation: uij[7]})
        return loss_attack_fake_src, loss_attack_src

    # self.iter, (users, hist_i,items,ys,sl, y_gender,y_age,y_occupation)
    def train_set_attack_real_source(self, sess, uij):  # attacker network
        """Run one attacker step (attacker params only).

        uij = (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation).
        Returns (loss_attack_real, loss_gender_real, loss_age_real, loss_occupation_real).
        """
        loss_attack_real, loss_gender_real, loss_age_real, loss_occupation_real, _ = sess.run(
                [self.loss_attack_real, self.loss_gender_real, self.loss_age_real, self.loss_occupation_real,
                 self.train_op_attack_real],
                feed_dict={self.u_src: uij[0], self.hist_i_src: uij[1], self.i_src: uij[2], self.y_src: uij[3],
                           self.sl_src: uij[4], self.y_gender: uij[5], self.y_age: uij[6], self.y_occupation: uij[7]})
        return loss_attack_real, loss_gender_real, loss_age_real, loss_occupation_real

    # (users, hist_i_src,items_src,sl_src, ys_src, hist_i_tgt,items_tgt,sl_tgt, ys_tgt)
    def train_transfer_tgt_joint(self, sess, uij):
        """Run one source->target translator step (target network + s2t/t_cf). Returns loss_tgt_trans."""
        loss_tgt_trans, _ = sess.run([self.loss_tgt_trans, self.train_op_tgt_trans], feed_dict={
            self.u_trans: uij[0],
            self.hist_i_src_trans: uij[1], self.i_src_trans: uij[2], self.sl_src_trans: uij[3], self.y_src_trans: uij[4],
            self.hist_i_tgt_trans: uij[5], self.i_tgt_trans: uij[6], self.sl_tgt_trans: uij[7], self.y_tgt_trans: uij[8],
        })
        return loss_tgt_trans

    # (users, hist_i_src,items_src,sl_src, ys_src, hist_i_tgt,items_tgt,sl_tgt, ys_tgt)
    def train_transfer_src_joint(self, sess, uij):
        """Run one target->source translator step (source network + t2s/s_cf). Returns loss_src_trans."""
        loss_src_trans, _ = sess.run([self.loss_src_trans, self.train_op_src_trans], feed_dict={
            self.u_trans: uij[0],
            self.hist_i_src_trans: uij[1], self.i_src_trans: uij[2], self.sl_src_trans: uij[3], self.y_src_trans: uij[4],
            self.hist_i_tgt_trans: uij[5], self.i_tgt_trans: uij[6], self.sl_tgt_trans: uij[7], self.y_tgt_trans: uij[8],
        })
        return loss_src_trans

    #  (users,hist_i,items,sl, items_tgt, j_parent)
    #  compute the score only, NO need to compute the label
    def test(self, sess, uij):
        """Score target items for cold-start users from their source history.

        uij = (users, hist_i, items, sl, items_tgt, j_parent). Returns
        (u_test, j_test, score_j_test, j_parent); the placeholders are fetched
        so the fed ids are returned alongside the scores.
        """
        u_test, j_test, score_j_test, j_parent = sess.run(
                [self.u_test, self.j_test, self.score_j_test, self.j_parent],
                feed_dict={self.u_test: uij[0], self.hist_i_test: uij[1], self.i_test: uij[2], self.sl_test: uij[3],
                           self.j_test: uij[4], self.j_parent: uij[5]})
        return u_test, j_test, score_j_test, j_parent

    # self.iter, (users, hist_i,items,ys,sl, y_gender, y_age,y_occupation)
    def test_attack(self, sess, uij):
        """Return (u_src, score_gender, score_age, score_occupation) for a source batch."""
        u_src, score_gender, score_age, score_occupation = sess.run(
                [self.u_src, self.score_gender, self.score_age, self.score_occupation],
                feed_dict={self.u_src: uij[0], self.hist_i_src: uij[1], self.i_src: uij[2], self.y_src: uij[3],
                           self.sl_src: uij[4], self.y_gender: uij[5], self.y_age: uij[6], self.y_occupation: uij[7]})
        return u_src, score_gender, score_age, score_occupation


def attention_dot(query, keys, keys_length):
    """Masked scaled dot-product attention pooling of `keys` by `query`.

    Args:
        query: [B, H] tensor, one query vector per row. The last dimension
            must be statically known.
        keys: [B, T, H] tensor of history embeddings; last dimension must be
            statically known.
        keys_length: [B] int tensor; only the first keys_length[b] positions of
            row b receive attention weight.

    Returns:
        [B, 1, H] tensor: softmax-weighted sum of `keys` for each row.

    Note:
        If keys_length[b] == 0, every position of row b is masked equally, so
        the weights are uniform over all T positions.
    """
    query_3d = tf.reshape(query, [-1, 1, query.get_shape().as_list()[-1]])  # [B,1,H]
    scores = tf.matmul(query_3d, keys, adjoint_b=True)  # [B,1,T]
    # Scale by 1/sqrt(H), as in scaled dot-product attention, so that the
    # softmax does not saturate as the embedding size grows.
    scores = scores / math.sqrt(keys.get_shape().as_list()[-1])
    # Mask positions beyond each row's history length with a large negative
    # value so they get ~zero weight after the softmax.
    keys_masks = tf.expand_dims(tf.sequence_mask(keys_length, tf.shape(keys)[1]), axis=1)  # [B,1,T]
    paddings = tf.ones_like(scores) * (-2 ** 32 + 1)  # -4294967295
    scores = tf.where(keys_masks, scores, paddings)  # True -> score, False -> padding
    weights = tf.nn.softmax(scores)  # [B,1,T]
    return tf.matmul(weights, keys)  # [B,1,T] x [B,T,H] -> [B,1,H]
