import tensorflow as tf

def model_BiLSTM(x, lstmUnitNum, forget_bias, input_keep_prob, output_keep_prob, attn_length=-1):
    """Run a bidirectional LSTM over ``x`` and return both directions' outputs, concatenated.

    Args:
        x: input tensor laid out as [batch_size, n_steps, n_dimensions].
        lstmUnitNum: number of hidden units in each direction's BasicLSTMCell.
        forget_bias: forget-gate bias passed to each BasicLSTMCell.
        input_keep_prob: input keep probability for each cell's DropoutWrapper.
        output_keep_prob: output keep probability for each cell's DropoutWrapper.
        attn_length: when not -1, each cell is wrapped in AttentionCellWrapper
            with this attention length; -1 (the default) disables attention.

    Returns:
        Tensor [batch_size, n_steps, lstmUnitNum * 2]: forward outputs followed by
        backward outputs along the last axis. (Per-direction width assumed to stay
        lstmUnitNum when attention is enabled -- NOTE(review): confirm.)
    """
    # One identical cell per direction, in [forward, backward] order.
    # activation=None lets BasicLSTMCell fall back to its default activation
    # (tanh per the TF 1.x docs).
    cells = [
        tf.contrib.rnn.BasicLSTMCell(lstmUnitNum, forget_bias=forget_bias, state_is_tuple=True,
                                     activation=None, reuse=None)
        for _ in range(2)
    ]
    if attn_length != -1:
        cells = [tf.contrib.rnn.AttentionCellWrapper(c, attn_length, state_is_tuple=True) for c in cells]
    cells = [
        tf.contrib.rnn.DropoutWrapper(cell=c, input_keep_prob=input_keep_prob,
                                      output_keep_prob=output_keep_prob)
        for c in cells
    ]
    forward_cell, backward_cell = cells

    # Initial states are created as float32; final states are not needed.
    outputs, _final_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward_cell, cell_bw=backward_cell,
                                                             inputs=x, dtype=tf.float32)
    # outputs is an (output_fw, output_bw) pair; join them on the feature axis.
    return tf.concat(outputs, axis=2)

def model_CNN(input_r, windowSize, filterNum, keepratio):
    """Apply one conv -> bias -> ReLU -> dropout -> max-over-axis-2 layer.

    The filter spans the full size of ``input_r``'s axis 1 and slides
    ``windowSize`` positions along axis 2 with VALID padding, so the convolution
    output has size 1 on axis 1. The maximum is then taken over axis 2.

    Args:
        input_r: 4-D NHWC tensor with a single channel. Axis 1 must have a static
            size, because it is read via ``.shape[1].value``.
        windowSize: filter width along axis 2.
        filterNum: number of filters (output channels).
        keepratio: passed as the second positional argument of ``tf.nn.dropout``,
            which is the keep probability under TF 1.x.

    Returns:
        Tensor of shape [batch, 1, filterNum].
    """
    full_height = input_r.shape[1].value
    # Fresh trainable parameters on every call, both drawn from N(0, 0.1^2).
    kernel = tf.Variable(tf.random_normal([full_height, windowSize, 1, filterNum], stddev=0.1))
    bias = tf.Variable(tf.random_normal([filterNum], stddev=0.1))
    pre_activation = tf.nn.bias_add(
        tf.nn.conv2d(input_r, kernel, strides=[1, 1, 1, 1], padding='VALID'), bias)
    features = tf.nn.dropout(tf.nn.relu(pre_activation), keepratio)
    # Max-pool over every remaining position along axis 2.
    return tf.reduce_max(features, axis=[2])

def model_RCNN(x, lstmUnitNum, filterNum, forget_bias, keepratio_cnn, input_keep_prob, output_keep_prob,
               window_sizes=(3, 5)):
    """Encode ``x`` with a Bi-LSTM, then extract multi-width convolutional features over time.

    Args:
        x: input tensor, [batch_size, n_steps, n_dimensions] (forwarded to model_BiLSTM).
        lstmUnitNum: LSTM hidden units per direction.
        filterNum: number of convolution filters per window size.
        forget_bias: forget-gate bias for the LSTM cells.
        keepratio_cnn: dropout keep ratio forwarded to model_CNN.
        input_keep_prob: LSTM input dropout keep probability.
        output_keep_prob: LSTM output dropout keep probability.
        window_sizes: temporal widths of the convolution filters, one model_CNN
            branch per entry. Defaults to (3, 5), the previously hard-coded widths.

    Returns:
        Tensor [batch_size, len(window_sizes) * filterNum]. Branch features are
        concatenated in ``window_sizes`` order.

    Raises:
        ValueError: if ``window_sizes`` is empty.
    """
    if not window_sizes:
        raise ValueError("window_sizes must contain at least one window size")

    rnn_output = model_BiLSTM(x, lstmUnitNum, forget_bias, input_keep_prob, output_keep_prob)
    # rnn_output is [batch, n_steps, lstmUnitNum*2]. model_CNN's filters span the
    # whole of axis 1 and slide along axis 2, so features must sit on axis 1 and
    # time on axis 2. A tf.reshape to that shape would only reinterpret the
    # row-major buffer, mixing time steps with feature channels; the axes must be
    # swapped with tf.transpose. Neither op needs a static batch size.
    cnn_input = tf.expand_dims(tf.transpose(rnn_output, [0, 2, 1]), axis=3)

    # Each branch yields [batch, 1, filterNum]; flatten it to [batch, filterNum].
    # -1 keeps this working when the batch dimension is only known at run time.
    branches = [model_CNN(cnn_input, size, filterNum, keepratio_cnn) for size in window_sizes]
    flat = [tf.reshape(branch, [-1, filterNum]) for branch in branches]
    return tf.concat(flat, axis=1)
