import tensorflow as tf
import numpy as np

def _tied_relu_dense(x, y, num_units, name):
    """Apply one ReLU dense layer called ``name`` to ``x`` and ``y`` with shared weights.

    The first call creates the layer's variables; the second passes
    ``reuse=True`` so both inputs go through the same kernel and bias.
    """
    # NOTE(review): the initializers are passed as an explicit None on purpose.
    # Omitting the arguments would select the layer's own defaults, which may
    # differ from what an explicit None resolves to -- confirm before simplifying.
    x_out = tf.layers.dense(inputs=x, units=num_units, activation=tf.nn.relu, use_bias=True,
                            kernel_initializer=None, bias_initializer=None,
                            name=name)
    y_out = tf.layers.dense(inputs=y, units=num_units, activation=tf.nn.relu, use_bias=True,
                            kernel_initializer=None, bias_initializer=None,
                            name=name, reuse=True)
    return x_out, y_out


def _masked_softmax(scores, mask):
    """Softmax over the last axis of ``scores``, restricted to where ``mask`` is non-zero.

    Rows whose mask is entirely zero come out as all zeros (the epsilon in the
    denominator prevents a division by zero).
    """
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = scores - tf.reduce_max(scores, axis=-1, keep_dims=True)
    weights = tf.exp(shifted) * mask
    return weights / (tf.reduce_sum(weights, axis=-1, keep_dims=True) + 1e-8)


def dam(a, b, a_mask, b_mask, dropout_rate, num_units, training=False):
    """Build an attend / compare / aggregate graph over two masked sequences.

    Both inputs go through the same (weight-tied) feed-forward layers. The
    layers are created with the fixed names ``F1``, ``F2``, ``G1`` and ``G2``
    in the current variable scope.
    NOTE(review): building this function twice in the same variable scope
    presumably requires the caller to arrange variable reuse -- confirm.

    Args:
        a: Rank-3 tensor for the first input; assumed to be
            (batch, len_a, dim) -- TODO confirm axis meaning with the caller.
        b: Rank-3 tensor for the second input; assumed (batch, len_b, dim).
        a_mask: Rank-2 tensor multiplied into ``a``; assumed to hold 1 for real
            positions and 0 for padding, in the same dtype as ``a``.
        b_mask: Same as ``a_mask`` but for ``b``.
        dropout_rate: Rate forwarded to every ``tf.layers.dropout`` call.
        num_units: Output width of each dense layer.
        training: Forwarded to ``tf.layers.dropout``. With the default
            ``False`` dropout is the identity and ``dropout_rate`` has no
            effect, which matches what this function did before the parameter
            existed. Pass ``True`` (or a boolean tensor) to apply dropout.

    Returns:
        Rank-2 tensor with ``4 * num_units`` columns: the sum-pooled and
        max-pooled comparison vectors of both inputs, concatenated as
        ``[sum_a, sum_b, max_a, max_b]``.
    """
    # [batch, len, 1] masks broadcast against the feature axis, so no tiling is needed.
    a_keep = tf.expand_dims(a_mask, 2)
    b_keep = tf.expand_dims(b_mask, 2)
    a = a * a_keep
    b = b * b_keep

    with tf.name_scope("attend"):
        a = tf.layers.dropout(inputs=a, rate=dropout_rate, training=training)
        b = tf.layers.dropout(inputs=b, rate=dropout_rate, training=training)

        F1a, F1b = _tied_relu_dense(a, b, num_units, 'F1')
        # Re-mask after each layer: the bias makes padded rows non-zero again.
        F1a = F1a * a_keep
        F1b = F1b * b_keep

        F1a = tf.layers.dropout(inputs=F1a, rate=dropout_rate, training=training)
        F1b = tf.layers.dropout(inputs=F1b, rate=dropout_rate, training=training)

        Fa, Fb = _tied_relu_dense(F1a, F1b, num_units, 'F2')
        Fa = Fa * a_keep
        Fb = Fb * b_keep

        # Unnormalised alignment scores between every position of a and of b.
        attention_weights = tf.matmul(Fa, Fb, transpose_b=True, name='attention_weights')
        attention_weights_transposed = tf.transpose(attention_weights, [0, 2, 1])

        # Outer product of the two masks: non-zero only where both positions are real.
        attention_mask = a_keep * tf.expand_dims(b_mask, 1)

        attention_soft1 = _masked_softmax(attention_weights, attention_mask)
        attention_soft2 = _masked_softmax(attention_weights_transposed,
                                          tf.transpose(attention_mask, [0, 2, 1]))

        # beta: for each position of a, the attention-weighted mix of b (and vice versa).
        beta = tf.matmul(attention_soft1, b)
        alpha = tf.matmul(attention_soft2, a)

    with tf.name_scope("compare"):
        a_beta = tf.concat([a, beta], 2)
        b_alpha = tf.concat([b, alpha], 2)

        a_beta = tf.layers.dropout(inputs=a_beta, rate=dropout_rate, training=training)
        b_alpha = tf.layers.dropout(inputs=b_alpha, rate=dropout_rate, training=training)

        v1i1, v2j1 = _tied_relu_dense(a_beta, b_alpha, num_units, 'G1')

        v1i1 = tf.layers.dropout(inputs=v1i1, rate=dropout_rate, training=training)
        v2j1 = tf.layers.dropout(inputs=v2j1, rate=dropout_rate, training=training)

        v1i, v2j = _tied_relu_dense(v1i1, v2j1, num_units, 'G2')

    with tf.name_scope("aggregate"):
        # Zero the padded positions so they do not contribute to the pooled sums.
        v1i = v1i * a_keep
        v2j = v2j * b_keep

        v1 = tf.reduce_sum(v1i, axis=1)
        v2 = tf.reduce_sum(v2j, axis=1)

        # ReLU outputs are non-negative, so the zeroed padding cannot exceed
        # the maximum over real positions.
        v1_max = tf.reduce_max(v1i, 1)
        v2_max = tf.reduce_max(v2j, 1)

        v1_v2 = tf.concat([v1, v2, v1_max, v2_max], 1)

    return v1_v2