import tensorflow as tf
import numpy as np
from dam import dam


class SharedNN_DAM(object):
    """Two-task text-pair classifier with a shared ``dam`` encoder and a critic.

    Pairs (``input_x_a``, ``input_x_b``) feed task 1 (labels ``input_y1``) and
    pairs (``input_x_c``, ``input_x_d``) feed task 2 (labels ``input_y2``).
    Both pairs are encoded by the same ``dam`` call (variables shared under the
    "shared" variable scope); each task then has its own 2-unit dense output.
    A small two-layer critic runs on the row-wise concatenation ``[o1; o2]``
    and yields ``wd_loss`` plus a gradient penalty (WGAN-GP style formula).

    Args:
        sequence_length: Length of every token-id input row.
        num_classes: Not used by this class; the output layers, the label
            placeholders and the AUC are all hard-coded to 2 classes.
        vocab_size: Number of rows of the embedding matrix.
        embedding_size: Width of the embedding matrix.
        num_units: Forwarded to ``dam``.
        critic_units: Hidden width of the critic.
        l2_param: Weight of the L2 term in ``total_loss`` (all trainable
            variables whose name does not contain "bias").
        wd_param: Weight of ``wd_loss`` in ``total_loss``.
        gp_param: Weight of the gradient penalty in ``all_wd_loss``.
        l2_reg_lambda: Weight of the per-task L2 terms in ``loss1``/``loss2``.

    ``dropout_rate`` must be fed whenever the scores are evaluated: the output
    dropout layers always run in training mode, so feed the drop probability
    while training and 0.0 at evaluation.
    """

    def __init__(
            self, sequence_length, num_classes, vocab_size, embedding_size, num_units,
            critic_units, l2_param, wd_param, gp_param, l2_reg_lambda=0.0):

        # Task-1 inputs: token-id pair and one-hot labels over 2 classes.
        self.input_x_a = tf.placeholder(tf.int32, [None, sequence_length], name="input_x_a")
        self.input_x_b = tf.placeholder(tf.int32, [None, sequence_length], name="input_x_b")
        self.input_y1 = tf.placeholder(tf.float32, [None, 2], name="input_y1")

        # Task-2 inputs, same layout as task 1.
        self.input_x_c = tf.placeholder(tf.int32, [None, sequence_length], name="input_x_c")
        self.input_x_d = tf.placeholder(tf.int32, [None, sequence_length], name="input_x_d")
        self.input_y2 = tf.placeholder(tf.float32, [None, 2], name="input_y2")

        # Shared by dam and both output dropout layers.
        self.dropout_rate = tf.placeholder(tf.float32, name="dropout_rate")

        # 1.0 where the token id is non-zero, 0.0 where it is 0
        # (presumably 0 is the padding id).
        input_x_a_mask = self._nonpad_mask(self.input_x_a)
        input_x_b_mask = self._nonpad_mask(self.input_x_b)
        input_x_c_mask = self._nonpad_mask(self.input_x_c)
        input_x_d_mask = self._nonpad_mask(self.input_x_d)

        # One embedding table shared by all four inputs, pinned to the CPU.
        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="embedding_W", trainable=True)
            input_x_a_embed = tf.nn.embedding_lookup(self.W, self.input_x_a)
            input_x_b_embed = tf.nn.embedding_lookup(self.W, self.input_x_b)
            input_x_c_embed = tf.nn.embedding_lookup(self.W, self.input_x_c)
            input_x_d_embed = tf.nn.embedding_lookup(self.W, self.input_x_d)

        l2_loss1 = tf.constant(0.0)
        l2_loss2 = tf.constant(0.0)

        # The second dam call reuses the variables created by the first, so
        # both tasks share one encoder.
        with tf.variable_scope("shared") as scope:
            self.o1 = dam(input_x_a_embed, input_x_b_embed, input_x_a_mask, input_x_b_mask, self.dropout_rate,
                          num_units)

            scope.reuse_variables()
            self.o2 = dam(input_x_c_embed, input_x_d_embed, input_x_c_mask, input_x_d_mask, self.dropout_rate,
                          num_units)

        with tf.name_scope("output1"):
            # training=True: tf.layers.dropout defaults to training=False, which
            # turns it into an identity op regardless of rate. The rate
            # placeholder is what disables dropout (feed 0.0 at evaluation).
            dropout1 = tf.layers.dropout(inputs=self.o1, rate=self.dropout_rate, training=True)
            self.scores1 = tf.layers.dense(inputs=dropout1, units=2, activation=None, use_bias=True,
                                           kernel_initializer=None, bias_initializer=None,
                                           name='scores1')
            self.predictions1 = tf.argmax(self.scores1, 1, name="predictions1")

        # Shared-encoder variables for the per-task L2 terms are picked by name.
        # NOTE(review): assumes dam's variable names contain F1/F2/G1/G2 — confirm.
        l2_loss_vars_shared = [v for v in tf.trainable_variables() if 'F1' in v.name or 'F2' in v.name or
                               'G1' in v.name or 'G2' in v.name]

        l2_loss_vars1 = [v for v in tf.trainable_variables() if 'scores1' in v.name]

        l2_loss1 += sum(map(tf.nn.l2_loss, l2_loss_vars_shared + l2_loss_vars1))

        with tf.name_scope("loss1"):
            self.prob1 = tf.nn.softmax(self.scores1)
            self.losses1 = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores1, labels=self.input_y1)
            self.loss1 = tf.reduce_mean(self.losses1) + l2_reg_lambda * l2_loss1

        with tf.name_scope("accuracy1"):
            correct_predictions1 = tf.equal(self.predictions1, tf.argmax(self.input_y1, 1))
            self.accuracy1 = tf.reduce_mean(tf.cast(correct_predictions1, "float"), name="accuracy1")

        with tf.name_scope("output2"):
            # See output1: training=True so the fed rate actually applies.
            dropout2 = tf.layers.dropout(inputs=self.o2, rate=self.dropout_rate, training=True)
            self.scores2 = tf.layers.dense(inputs=dropout2, units=2, activation=None, use_bias=True,
                                           kernel_initializer=None, bias_initializer=None,
                                           name='scores2')
            self.predictions2 = tf.argmax(self.scores2, 1, name="predictions2")

        l2_loss_vars2 = [v for v in tf.trainable_variables() if 'scores2' in v.name]
        l2_loss2 += sum(map(tf.nn.l2_loss, l2_loss_vars_shared + l2_loss_vars2))
        with tf.name_scope("loss2"):
            self.losses2 = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores2, labels=self.input_y2)
            self.loss2 = tf.reduce_mean(self.losses2) + l2_reg_lambda * l2_loss2

        with tf.name_scope("accuracy2"):
            self.correct_predictions2 = tf.equal(self.predictions2, tf.argmax(self.input_y2, 1))
            self.accuracy2 = tf.reduce_mean(tf.cast(self.correct_predictions2, "float"), name="accuracy2")

        # Streaming AUC of class-1 probability for task 2; update_auc_op must be
        # run to accumulate (tf.metrics state lives in local variables).
        with tf.name_scope('auc2'):
            self.prob2 = tf.nn.softmax(self.scores2)
            self.auc2, self.update_auc_op = tf.metrics.auc(tf.argmax(self.input_y2, 1), self.prob2[:, 1])
        with tf.name_scope("loss"):
            self.loss = self.loss1 + self.loss2

        # Critic input: task-1 encodings stacked above task-2 encodings.
        h1_whole = tf.concat([self.o1, self.o2], axis=0)

        with tf.variable_scope('Critic'):

            l1 = tf.layers.dense(
                inputs=h1_whole,
                units=critic_units,
                activation=tf.nn.relu,
                kernel_initializer=tf.random_normal_initializer(0., .1),
                bias_initializer=tf.constant_initializer(0.1),
                name='critic_l1',
                reuse=tf.AUTO_REUSE
            )

            # One scalar score per row of h1_whole.
            critic_out = tf.layers.dense(
                inputs=l1,
                units=1,
                activation=None,
                kernel_initializer=tf.random_normal_initializer(0., .1),
                bias_initializer=tf.constant_initializer(0.1),
                name='critic_l2',
                reuse=tf.AUTO_REUSE
            )

        # Per-row weights built from the first-token mask of input_x_a (for the
        # o1 rows) and of input_x_c (for the o2 rows). An o1 row with a zero
        # first token is counted on the "t" side, and an o2 row with a zero
        # first token on the "s" side.
        # NOTE(review): intent of this first-token convention is not visible here — confirm with the data pipeline.
        t_mask = tf.slice(input_x_c_mask, [0, 0], [-1, 1])
        s_mask = tf.slice(input_x_a_mask, [0, 0], [-1, 1])

        self.t_mask = tf.concat([1 - s_mask, t_mask], axis=0)
        self.s_mask = tf.concat([s_mask, 1 - t_mask], axis=0)

        critic_s = critic_out * self.s_mask
        critic_t = critic_out * self.t_mask

        self.critic_s = critic_s
        self.critic_t = critic_t
        self.critic_out = critic_out

        # NOTE(review): reduce_mean divides by all rows of h1_whole, not by the
        # number of rows each mask selects. With equal-sized halves and no
        # zero first tokens this is exactly half the per-side mean difference.
        # Kept as-is because wd_param/gp_param are presumably tuned to it.
        self.wd_loss = tf.reduce_mean(critic_s) - tf.reduce_mean(critic_t)

        # Gradient penalty pushing the per-row gradient norm of the critic
        # (w.r.t. its input) towards 1.
        gradients = tf.gradients(critic_out, [h1_whole])[0]
        # The epsilon keeps sqrt differentiable when a row's gradient is exactly
        # zero (e.g. every ReLU unit inactive); without it the penalty's
        # backward pass produces NaN.
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1) + 1e-10)
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)

        # Critic objective (to minimise): maximise wd_loss, penalise gradients.
        self.all_wd_loss = -self.wd_loss + gp_param * gradient_penalty

        # Critic variables (names contain "critic" via critic_l1/critic_l2) vs.
        # every other global variable, presumably for separate optimisers.
        self.theta_D = [v for v in tf.global_variables() if 'critic' in v.name]
        self.theta_G = [v for v in tf.global_variables() if v not in self.theta_D]

        all_variables = tf.trainable_variables()

        # NOTE(review): this L2 term also covers the embedding and critic
        # kernels; only the theta_G part matters if total_loss is optimised
        # over theta_G.
        self.l2_loss = l2_param * tf.add_n([tf.nn.l2_loss(v) for v in all_variables if 'bias' not in v.name])
        self.total_loss = tf.reduce_mean(self.losses1) + tf.reduce_mean(self.losses2) + self.l2_loss + wd_param * self.wd_loss

    @staticmethod
    def _nonpad_mask(ids):
        """Return a float32 tensor shaped like ``ids``: 1.0 where id != 0, else 0.0."""
        return tf.cast(tf.not_equal(ids, 0), tf.float32)
