#!/usr/bin/env python
# -*- coding:utf8 -*-

# ================================================================================
# Copyright XXXXXXXXXXXX. All Rights Reserved.
# ================================================================================


def prepare_encoder_input(src_wids, src_sids, src_masks, params):
    """Embed the source ids and build the encoder self-attention bias.

    Args:
        src_wids: ids looked up in the source embedding table
            (presumably an integer tensor of shape [batch, length] --
            TODO confirm against the caller).
        src_sids: not used by this function; accepted for interface
            compatibility with existing callers.
        src_masks: mask that is expanded with a new axis at index 2 and
            multiplied into the embeddings; ``1 - src_masks`` is handed to
            ``attention_bias_ignore_padding`` (presumably a float tensor
            with 1.0 on real positions and 0.0 on padding -- TODO confirm).
        params: dict of hyper-parameters. Only ``"src_vocab_size"`` and
            ``"hidden_size"`` are read here.

    Returns:
        A tuple ``(encoder_input, encoder_self_attention_bias)``.
    """
    src_vocab_size = params["src_vocab_size"]
    hidden_size = params["hidden_size"]

    with tf.variable_scope('Source_Side'):
        src_emb = common_layers.embedding(src_wids, src_vocab_size, hidden_size)
    # Scale the embeddings by sqrt(hidden_size) before the timing signal
    # is added.
    src_emb *= hidden_size**0.5

    # The complement of the mask is what the attention bias helper receives.
    encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(1 - src_masks)
    encoder_input = common_attention.add_timing_signal_1d(src_emb)
    # Multiply by the mask so positions where it is 0 become all-zero vectors.
    encoder_input = tf.multiply(encoder_input, tf.expand_dims(src_masks, 2))
    return encoder_input, encoder_self_attention_bias

def layer_process(x, y, flag, dropout):
    """Apply a sequence of one-letter layer operations to ``y``.

    The characters of ``flag`` are processed in order:
        'a' -- add ``x`` to the running value,
        'n' -- apply ``common_layers.layer_norm``,
        'd' -- apply ``tf.nn.dropout`` with keep probability
               ``1.0 - dropout``.
    Any other character is ignored.

    Args:
        x: value added by the 'a' action. Callers in this file pass
            ``None`` when pre-processing, so ``flag`` must not contain 'a'
            in that case.
        y: value the actions are applied to.
        flag: iterable of action characters, or ``None`` for "no actions".
        dropout: drop rate used by the 'd' action.

    Returns:
        ``y`` after all actions have been applied (``y`` itself when
        ``flag`` is ``None`` or empty).
    """
    # Identity check: ``== None`` would dispatch to a custom ``__eq__``.
    if flag is None:
        return y
    for action in flag:
        if action == 'a':
            y = x + y
        elif action == 'n':
            y = common_layers.layer_norm(y)
        elif action == 'd':
            y = tf.nn.dropout(y, 1.0 - dropout)
    return y

def transformer_ffn_layer(x, params):
    """Feed ``x`` through ``common_layers.conv_hidden_relu``.

    Args:
        x: input passed straight to ``conv_hidden_relu``.
        params: dict providing ``"filter_size"``, ``"hidden_size"`` and
            ``"relu_dropout"``; they are forwarded as the second and third
            positional arguments and the ``dropout`` keyword respectively.

    Returns:
        Whatever ``common_layers.conv_hidden_relu`` returns.
    """
    # Resolve the hyper-parameters first so a missing key is reported
    # before any layer is built.
    inner_width, output_width, drop_rate = [
        params[key] for key in ("filter_size", "hidden_size", "relu_dropout")
    ]
    return common_layers.conv_hidden_relu(
        x, inner_width, output_width, dropout=drop_rate)

def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        mask,
                        params=None,
                        name="encoder"):
    """Stack self-attention + feed-forward layers on ``encoder_input``.

    Every layer runs, in order: pre-process -> multi-head self-attention ->
    post-process -> pre-process -> feed-forward -> post-process, and then
    multiplies the result by ``mask``. The pre-/post-processing steps are
    the ``layer_process`` action strings taken from ``params``. After the
    last layer the pre-process actions are applied once more.

    Args:
        encoder_input: input to the first layer.
        encoder_self_attention_bias: bias forwarded to every
            ``common_attention.multihead_attention`` call.
        mask: expanded with a new axis at index 2 and multiplied into each
            layer's output (presumably a float [batch, length] tensor --
            TODO confirm).
        params: dict with the keys ``"num_hidden_layers"``,
            ``"hidden_size"``, ``"num_heads"``, ``"prepost_dropout"``,
            ``"attention_dropout"``, ``"preproc_actions"`` and
            ``"postproc_actions"`` (plus whatever
            ``transformer_ffn_layer`` reads). Omitting it behaves like an
            empty dict, i.e. a ``KeyError`` for the first missing key.
        name: variable scope wrapping all layers.

    Returns:
        The output of the last layer after the final pre-process step.

    Raises:
        KeyError: if a required key is missing from ``params``.
    """
    # ``None`` replaces the former shared mutable ``{}`` default.
    if params is None:
        params = {}
    num_hidden_layers = params["num_hidden_layers"]
    hidden_size = params["hidden_size"]
    num_heads = params["num_heads"]
    prepost_dropout = params["prepost_dropout"]
    attention_dropout = params["attention_dropout"]
    preproc_actions = params['preproc_actions']
    postproc_actions = params['postproc_actions']
    x = encoder_input
    mask = tf.expand_dims(mask, 2)
    with tf.variable_scope(name):
        # ``range`` (not ``xrange``) so the loop runs on Python 2 and 3.
        for layer in range(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                # Self-attention sub-layer; the second return value is
                # not needed here.
                o, _unused_weights = common_attention.multihead_attention(
                        layer_process(None, x, preproc_actions, prepost_dropout),
                        None,
                        encoder_self_attention_bias,
                        hidden_size,
                        hidden_size,
                        hidden_size,
                        num_heads,
                        attention_dropout,
                        summaries=False,
                        name="encoder_self_attention")
                x = layer_process(x, o, postproc_actions, prepost_dropout)

                # Feed-forward sub-layer.
                o = transformer_ffn_layer(
                        layer_process(None, x, preproc_actions, prepost_dropout),
                        params)
                x = layer_process(x, o, postproc_actions, prepost_dropout)

                # Zero the positions where the mask is 0.
                x = tf.multiply(x, mask)
        return layer_process(None, x, preproc_actions, prepost_dropout)


def output_layer(src_wids, src_sids, shift_src_masks, params):
    """Encode the input, mean-pool it and produce class logits.

    Args:
        src_wids: ids forwarded to ``prepare_encoder_input``.
        src_sids: per-example side information. In the ``'revision'`` mode
            it is multiplied element-wise into the logits; in the
            ``'representation'`` mode it is concatenated (axis 1) with the
            softmax of the logits and projected by a
            ``[number_of_classes, 2 * number_of_classes]`` matrix, so its
            last dimension has to be ``number_of_classes`` there.
        shift_src_masks: mask used for the encoder and for the masked mean
            over axis 1 (presumably float, [batch, length] -- TODO
            confirm).
        params: dict of hyper-parameters; this function itself reads
            ``"prepost_dropout"``, ``"hidden_size"`` and
            ``"number_of_classes"``.

    Returns:
        A tuple ``(logits, dist)`` where ``dist`` is ``softmax(logits)``
        clipped to ``[1e-8, 1 - 1e-8]``.
    """
    encoder_input, encoder_self_attention_bias = prepare_encoder_input(
        src_wids, src_sids, shift_src_masks, params)
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - params['prepost_dropout'])
    encoder_output = transformer_encoder(
        encoder_input, encoder_self_attention_bias, shift_src_masks, params)
    hidden_size = params["hidden_size"]
    number_of_classes = params["number_of_classes"]

    # Masked mean over axis 1: sum of the unmasked encoder states divided
    # by the number of unmasked positions.
    domain_classifier_output = tf.reduce_sum(encoder_output*tf.expand_dims(shift_src_masks, 2), 1) / tf.reduce_sum(shift_src_masks, 1, keep_dims=True)
    domain_classifier_output = tf.layers.dense(
        domain_classifier_output, hidden_size, activation=tf.tanh,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

    # NOTE(review): ``logits`` used to be read below without ever being
    # assigned, so every call failed with UnboundLocalError. This linear
    # projection to ``number_of_classes`` mirrors the parameterisation of
    # the 'UserDomainClassifier' block further down; confirm the scope name
    # and shape against the intended architecture / existing checkpoints.
    with tf.variable_scope('DomainClassifier', reuse=tf.AUTO_REUSE):
        class_embedding_tensor = tf.get_variable(
            'C', [number_of_classes, hidden_size],
            initializer=tf.random_normal_initializer(
                0.0, hidden_size**-0.5, dtype=tf.float32))
        class_bias = tf.get_variable(
            "C_bias", shape=[number_of_classes],
            initializer=tf.zeros_initializer())
    logits = tf.nn.bias_add(
        tf.matmul(domain_classifier_output, class_embedding_tensor,
                  transpose_b=True),
        class_bias)

    # NOTE(review): ``use_user_preference_type`` is not defined in this
    # function; it has to exist at module level -- confirm where it is set.
    if use_user_preference_type == 'revision':
        # Rescale the model logits element-wise by the side information.
        logits = logits * src_sids
    elif use_user_preference_type == 'representation':
        # Re-classify from [model distribution ; side information].
        model_logits = tf.nn.softmax(logits)
        domain_classifier_output = tf.concat([model_logits, src_sids], 1)

        with tf.variable_scope('UserDomainClassifier', reuse=tf.AUTO_REUSE):
            tag_vocab_embedding_tensor = tf.get_variable(
                'C', [number_of_classes, number_of_classes*2],
                initializer=tf.random_normal_initializer(
                    0.0, hidden_size**-0.5, dtype=tf.float32))
            tag_vocab_bias = tf.get_variable(
                "C_bias", shape=[number_of_classes],
                initializer=tf.zeros_initializer())

        logits = tf.nn.bias_add(
            tf.matmul(domain_classifier_output, tag_vocab_embedding_tensor,
                      transpose_b=True),
            tag_vocab_bias)

    dist = tf.nn.softmax(logits)
    # Keep probabilities strictly inside (0, 1).
    dist = tf.clip_by_value(dist, 1e-8, 1.0-1e-8)
    return logits, dist


def get_loss(features, contexts, labels, params):
    """Build the classification loss and accuracy metric for one batch.

    Args:
        features: ids of the input sequence; a column of zeros is appended
            along axis 1 before the model is applied (presumably an int64
            tensor of shape [batch, length] -- TODO confirm).
        contexts: passed to ``output_layer`` as ``src_sids``.
        labels: class ids; axis 1 is squeezed away before one-hot
            encoding, so a size-1 second axis is expected.
        params: dict of hyper-parameters; ``"number_of_classes"`` is read
            here, the rest is consumed by ``output_layer``.

    Returns:
        A tuple ``(loss, accuracy)``: the softmax cross-entropy summed over
        the batch and divided by the batch size, and the value returned by
        ``tf.metrics.accuracy``.
    """
    # Append one zero id to every row.
    zero_column = tf.zeros([tf.shape(features)[0], 1], tf.int64)
    src_wids = tf.concat([features, zero_column], 1)

    # 1.0 where the id is non-zero; then shift right by one position,
    # filling the new first column with 1.
    src_masks = tf.to_float(tf.not_equal(src_wids, 0))
    shift_src_masks = tf.pad(
        src_masks[:, :-1], [[0, 0], [1, 0]], constant_values=1)

    logits, dist = output_layer(src_wids, contexts, shift_src_masks, params)

    class_ids = tf.cast(tf.squeeze(labels, 1), tf.int32)
    targets = tf.one_hot(class_ids, depth=params["number_of_classes"])

    per_example_xent = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=targets)
    xent_total = tf.reduce_sum(per_example_xent)
    batch_size = tf.cast(tf.shape(features)[0], dtype=tf.float32)
    loss = xent_total / batch_size

    # Predicted class = arg-max over the last axis of the distribution.
    predicted = tf.argmax(dist, tf.rank(dist) - 1)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=predicted, name='acc_op')
    return loss, accuracy

def transformer_model_fn(features, labels, mode, params):
    """``tf.estimator`` model function for the classifier.

    Args:
        features: dict with the entries ``'src'`` (input ids) and ``'scr'``
            (the context passed on to ``get_loss``).
        labels: class ids, forwarded to ``get_loss``.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict of hyper-parameters forwarded to ``get_loss``.

    Returns:
        An ``EstimatorSpec`` for TRAIN (loss + train op) or EVAL (loss +
        ``'accuracy'`` metric). For any other mode the function falls
        through and returns ``None``.
    """
    with tf.variable_scope('LIDModel'):
        if mode == tf.estimator.ModeKeys.TRAIN:
            # NOTE(review): ``lr`` and ``gradient_clip_value`` are not
            # defined in this function; they have to exist at module
            # level -- confirm where they are set.
            optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.997, epsilon=1e-09)
            optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, gradient_clip_value)

            src_ids, contexts = features['src'], features['scr']
            # The accuracy metric is not reported during training.
            loss, _ = get_loss(src_ids, contexts, labels, params)
            grads = optimizer.compute_gradients(loss)
            train_op = optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())

            return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

        if mode == tf.estimator.ModeKeys.EVAL:
            src_ids, contexts = features['src'], features['scr']
            # Reuse the training graph construction so evaluation reports
            # the real cross-entropy and actually exposes the accuracy
            # metric (it used to be built and then dropped, with a
            # constant 0.0 loss). Assumes eval labels have the same layout
            # as training labels.
            loss, accuracy = get_loss(src_ids, contexts, labels, params)

            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops={'accuracy': accuracy})
