import tensorflow as tf
import numpy as np
from textcnn import TextCNN,LinearLayer

def tf_cov(x, y):
    """Return the Pearson correlation coefficient of ``x`` and ``y``.

    Despite the name, the result is a correlation, not a covariance: the
    summed cross-product of the centred inputs is divided by the square root
    of the product of their summed squares.

    Each input is centred on its own mean along axis 0; the sums then run
    over every element, so the result is a scalar tensor.  There is no guard
    against a zero denominator (an input with no variation around its axis-0
    mean).

    Args:
        x: Numeric tensor.
        y: Numeric tensor that can be combined element-wise with ``x``.

    Returns:
        A scalar tensor holding the correlation.
    """
    centre_x = tf.reduce_mean(x, axis=0, keepdims=True)
    centre_y = tf.reduce_mean(y, axis=0, keepdims=True)

    # Numerator: summed cross-product of the deviations from the mean.
    cross_term = tf.reduce_sum((x - centre_x) * (y - centre_y))

    # Denominator: sqrt of the product of the two sums of squared deviations.
    sum_sq_x = tf.reduce_sum((x - centre_x) * (x - centre_x))
    sum_sq_y = tf.reduce_sum((y - centre_y) * (y - centre_y))
    scale = tf.sqrt(sum_sq_x * sum_sq_y)

    return tf.divide(cross_term, scale)


class Shared_CNN(object):
    """Two-input regression graph with a shared TextCNN encoder and a critic.

    Constructing an instance adds the following to the default TF 1.x graph:

    * two token-id inputs that share one embedding matrix and one ``TextCNN``
      encoder (variables under the ``"shared"`` scope are reused for the
      second input);
    * one ``LinearLayer`` head per input, each with a mean-squared-error loss
      and a correlation metric (``tf_cov``);
    * a two-layer dense critic applied to the row-wise concatenation of both
      encodings, from which a difference-of-means term (``wd_loss``) and a
      gradient-penalised critic objective (``all_wd_loss``) are derived.
      NOTE(review): this has the shape of a Wasserstein critic with gradient
      penalty, except that the penalty is evaluated at the encodings
      themselves rather than at interpolated points.

    Args:
        sequence_length: Second dimension of the ``input_x*`` placeholders;
            also forwarded to ``TextCNN``.
        num_classes: Second dimension of the ``input_y*`` placeholders.
            NOTE(review): both heads are built with ``LinearLayer(name, 1,
            True)`` regardless of this value — confirm the intended width.
        vocab_size: Number of rows of the embedding matrix.
        embedding_size: Number of columns of the embedding matrix; also
            forwarded to ``TextCNN``.
        filter_sizes: Forwarded to ``TextCNN``.
        num_filters: Forwarded to ``TextCNN``.
        critic_units: Number of units in the first critic layer.
        l2_param: Multiplier of the L2 penalty over trainable variables whose
            name does not contain ``'bias'``.
        wd_param: Weight of ``wd_loss`` inside ``total_loss``.
        gp_param: Weight of the gradient penalty inside ``all_wd_loss``.
        l2_reg_lambda: Weight of the L2 term returned by each ``LinearLayer``
            head inside ``loss1`` / ``loss2``.

    Attributes:
        input_x1, input_x2: int32 placeholders ``[None, sequence_length]``.
        input_y1, input_y2: float32 placeholders ``[None, num_classes]``.
        dropout_rate: float32 placeholder handed to ``tf.nn.dropout``.
        W: Embedding variable ``[vocab_size, embedding_size]``.
        o1, o2: Outputs of the shared ``TextCNN`` for each input.
        scores1, scores2: Head predictions.
        loss1, loss2: Per-head MSE plus ``l2_reg_lambda`` times the head's
            L2 term.
        loss: ``loss1 + loss2``.
        pcc1, pcc2: ``tf_cov`` of predictions and labels per head.
        s_mask, t_mask: Complementary per-row 0/1 masks over the concatenated
            batch.
        wd_loss: Mean of ``s``-masked critic outputs minus mean of
            ``t``-masked critic outputs.
        all_wd_loss: ``-wd_loss + gp_param * gradient_penalty``.
        theta_D: Global variables whose name contains ``'critic'``.
        theta_G: All remaining global variables.
        l2_loss: ``l2_param``-scaled L2 penalty described above.
        clf_loss: Sum of both heads' MSE, without any L2 term.
        total_loss: ``clf_loss + l2_loss + wd_param * wd_loss``.
    """

    # Added under the square root of the critic-gradient norm.  The derivative
    # of sqrt is infinite at 0, so a sample whose gradient is exactly zero
    # would otherwise yield NaN when the penalty is differentiated.
    _GRAD_NORM_EPS = 1e-10

    def __init__(
            self, sequence_length, num_classes, vocab_size,
            embedding_size, filter_sizes, num_filters, critic_units,
            l2_param, wd_param, gp_param, l2_reg_lambda=0.0):
        self.input_x1 = tf.placeholder(
            tf.int32, [None, sequence_length], name="input_x1")
        self.input_y1 = tf.placeholder(
            tf.float32, [None, num_classes], name="input_y1")
        self.input_x2 = tf.placeholder(
            tf.int32, [None, sequence_length], name="input_x2")
        self.input_y2 = tf.placeholder(
            tf.float32, [None, num_classes], name="input_y2")
        # NOTE(review): this is passed as the second positional argument of
        # tf.nn.dropout, which in TF 1.x is `keep_prob` — so the fed value
        # appears to be the probability of keeping a unit, despite the name.
        self.dropout_rate = tf.placeholder(tf.float32, name="dropout_rate")

        # Embedding layer: one matrix shared by both inputs.
        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            self.embedded_chars1 = tf.nn.embedding_lookup(self.W, self.input_x1)
            # Trailing axis of size 1 added before handing over to TextCNN.
            self.embedded_chars_expanded1 = tf.expand_dims(self.embedded_chars1, -1)
            self.embedded_chars2 = tf.nn.embedding_lookup(self.W, self.input_x2)
            self.embedded_chars_expanded2 = tf.expand_dims(self.embedded_chars2, -1)

        # Shared encoder: the second call reuses the variables of the first.
        with tf.variable_scope("shared") as scope:
            self.o1 = TextCNN(
                self.embedded_chars_expanded1, filter_sizes, embedding_size,
                num_filters, sequence_length)
            scope.reuse_variables()
            self.o2 = TextCNN(
                self.embedded_chars_expanded2, filter_sizes, embedding_size,
                num_filters, sequence_length)

        # ---- Head 1 -------------------------------------------------------
        with tf.name_scope("output1"):
            self.h_drop1 = tf.nn.dropout(self.o1, self.dropout_rate)
            linear1 = LinearLayer('linear1', 1, True)
            self.scores1, l2_loss1 = linear1(self.h_drop1)

        with tf.name_scope("loss1"):
            mse1 = tf.losses.mean_squared_error(
                labels=self.input_y1, predictions=self.scores1)
            self.loss1 = mse1 + l2_reg_lambda * l2_loss1

        with tf.name_scope("pcc1"):
            self.pcc1 = tf_cov(self.scores1, self.input_y1)

        # ---- Head 2 -------------------------------------------------------
        with tf.name_scope("output2"):
            self.h_drop2 = tf.nn.dropout(self.o2, self.dropout_rate)
            linear2 = LinearLayer('linear2', 1, True)
            self.scores2, l2_loss2 = linear2(self.h_drop2)

        with tf.name_scope("loss2"):
            mse2 = tf.losses.mean_squared_error(
                labels=self.input_y2, predictions=self.scores2)
            self.loss2 = mse2 + l2_reg_lambda * l2_loss2

        with tf.name_scope("pcc2"):
            self.pcc2 = tf_cov(self.scores2, self.input_y2)

        with tf.name_scope("loss"):
            self.loss = self.loss1 + self.loss2

        # ---- Critic -------------------------------------------------------
        # Rows of o1 first, then rows of o2; the masks below follow the same
        # row order.  The critic sees the encodings before dropout.
        h1_whole = tf.concat([self.o1, self.o2], axis=0)
        with tf.variable_scope('Critic'):
            l1 = tf.layers.dense(
                inputs=h1_whole,
                units=critic_units,
                activation=tf.nn.relu,
                kernel_initializer=tf.random_normal_initializer(0., .1),  # weights
                bias_initializer=tf.constant_initializer(0.1),  # biases
                name='critic_l1',
                reuse=tf.AUTO_REUSE
            )

            critic_out = tf.layers.dense(
                inputs=l1,
                units=1,
                activation=None,  # linear output
                kernel_initializer=tf.random_normal_initializer(0., .1),  # weights
                bias_initializer=tf.constant_initializer(0.1),  # biases
                name='critic_l2',
                reuse=tf.AUTO_REUSE
            )

        # 1.0 where a token id is non-zero, 0.0 where it is zero.
        # (id 0 is presumably the padding id — confirm with the data pipeline.)
        input_x1_mask = 1.0 - tf.cast(
            tf.equal(self.input_x1, tf.zeros_like(self.input_x1, dtype=tf.int32)),
            tf.float32)
        input_x2_mask = 1.0 - tf.cast(
            tf.equal(self.input_x2, tf.zeros_like(self.input_x2, dtype=tf.int32)),
            tf.float32)
        # Only the first token of each row decides the row's flag.
        s_mask = tf.slice(input_x1_mask, [0, 0], [-1, 1])
        t_mask = tf.slice(input_x2_mask, [0, 0], [-1, 1])
        # Complementary masks over the concatenated batch: an x1 row counts as
        # "s" when its first token is non-zero and as "t" otherwise; an x2 row
        # counts as "t" when its first token is non-zero and as "s" otherwise.
        self.t_mask = tf.concat([1 - s_mask, t_mask], axis=0)
        self.s_mask = tf.concat([s_mask, 1 - t_mask], axis=0)

        critic_s = critic_out * self.s_mask
        critic_t = critic_out * self.t_mask

        # NOTE: both means run over every row of the concatenated batch, with
        # masked-out rows contributing 0 — not over the selected rows only.
        self.wd_loss = tf.reduce_mean(critic_s) - tf.reduce_mean(critic_t)

        # Gradient penalty: push the per-row L2 norm of d(critic_out)/d(input)
        # towards 1.  tf.gradients differentiates the sum of critic_out and
        # returns a tensor shaped like h1_whole.
        gradients = tf.gradients(critic_out, [h1_whole])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=1) + self._GRAD_NORM_EPS)
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        self.all_wd_loss = -self.wd_loss + gp_param * gradient_penalty

        # Variable partition.  The substring test is case-sensitive: it matches
        # the dense layer names 'critic_l1' / 'critic_l2' (not the 'Critic'
        # scope), and any other global variable with 'critic' in its name.
        self.theta_D = [v for v in tf.global_variables() if 'critic' in v.name]
        self.theta_G = [v for v in tf.global_variables() if v not in self.theta_D]

        # L2 penalty over every trainable variable in the default graph whose
        # name does not contain 'bias'.
        all_variables = tf.trainable_variables()
        self.l2_loss = l2_param * tf.add_n(
            [tf.nn.l2_loss(v) for v in all_variables if 'bias' not in v.name])
        # Prediction loss without the heads' own L2 terms: summed directly
        # from the MSE tensors instead of subtracting the L2 terms back out of
        # `self.loss`, which avoids floating-point cancellation.
        self.clf_loss = mse1 + mse2
        self.total_loss = self.clf_loss + self.l2_loss + wd_param * self.wd_loss






    