import tensorflow as tf
import attention
import layers
import optimizers

import sys

# Module-level side effect: show INFO-level TensorFlow log messages (TF1 `tf.logging` API).
tf.logging.set_verbosity(tf.logging.INFO)

def _layer_process(x, mode):
    if not mode or mode == "none":
        return x
    elif mode == "layer_norm":
        return layers.layer_norm(x)
    else:
        raise ValueError("Unknown mode %s" % mode)

def _residual_fn(x, y, keep_prob=None):
    if keep_prob and keep_prob < 1.0:
        y = tf.nn.dropout(y, keep_prob)
    return x + y

def transformer_ffn_layer(x, params, name=None):
    """Position-wise feed-forward sub-layer: linear -> ReLU -> dropout -> linear.

    Args:
        x: input tensor.
        params: reads "filter_size" (inner width), "hidden_size" (output width)
            and "relu_dropout" (dropout rate applied to the ReLU activations).
        name: optional variable scope name; defaults to "ffn_layer".

    Returns:
        Output of the second linear projection (width hidden_size).
    """
    inner_width = params["filter_size"]
    out_width = params["hidden_size"]
    relu_keep = 1.0 - params["relu_dropout"]

    with tf.variable_scope(name, default_name="ffn_layer", values=[x]):
        with tf.variable_scope("input_layer"):
            activations = tf.nn.relu(layers.linear(x, inner_width, True, True))
        # Dropout is skipped when the keep probability is 0 or not below 1.
        if relu_keep and relu_keep < 1.0:
            activations = tf.nn.dropout(activations, relu_keep)
        with tf.variable_scope("output_layer"):
            projected = layers.linear(activations, out_width, True, True)
    return projected

def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        mask,
                        params=None,
                        name="encoder"):
    """Stack of Transformer encoder layers.

    Each layer runs self-attention and then the feed-forward sub-layer. Every
    sub-layer:
      * pre-processes its input with ``_layer_process(.., layer_preproc)``;
      * adds the result back through a dropout residual;
      * post-processes with ``_layer_process(.., layer_postproc)``.
    After each layer, the states are multiplied by ``mask``. All variables are
    created under a scope carrying an L1/L2 regularizer.

    Args:
        encoder_input: embedded source tensor.
        encoder_self_attention_bias: bias passed to self-attention.
        mask: float mask, presumably [batch, length] with 1 on valid positions
            (expanded on axis 2 before multiplying) -- TODO confirm.
        params: hyper-parameter dict (required despite the default). Reads
            "num_encoder_layers", "hidden_size", "num_heads",
            "residual_dropout", "attention_dropout", "layer_preproc",
            "layer_postproc", "scale_l1", "scale_l2", plus the keys used by
            ``transformer_ffn_layer``.
        name: variable scope name.

    Returns:
        ``_layer_process(x, layer_preproc)`` of the last layer's output.
    """
    # Avoid a shared mutable default; a missing params dict still fails on
    # the first lookup, as before.
    params = {} if params is None else params
    num_encoder_layers = params["num_encoder_layers"]
    hidden_size = params["hidden_size"]
    num_heads = params["num_heads"]
    residual_keep_prob = 1.0 - params["residual_dropout"]
    attention_dropout = params["attention_dropout"]
    layer_preproc = params["layer_preproc"]
    layer_postproc = params["layer_postproc"]

    x = encoder_input
    mask = tf.expand_dims(mask, 2)
    regularizer = tf.contrib.layers.l1_l2_regularizer(scale_l1=params["scale_l1"],
                                                      scale_l2=params["scale_l2"])
    with tf.variable_scope(name, regularizer=regularizer):
        for layer in range(num_encoder_layers):
            with tf.variable_scope("layer_%d" % layer):
                # Self-attention sub-layer (attention weights are not needed).
                o, _ = attention.multihead_attention(
                        _layer_process(x, layer_preproc),
                        None,
                        encoder_self_attention_bias,
                        hidden_size,
                        hidden_size,
                        hidden_size,
                        num_heads,
                        attention_dropout,
                        name="encoder_self_attention")
                x = _residual_fn(x, o, residual_keep_prob)
                x = _layer_process(x, layer_postproc)

                # Feed-forward sub-layer.
                o = transformer_ffn_layer(_layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, residual_keep_prob)
                x = _layer_process(x, layer_postproc)

                # Zero out masked positions after every layer.
                x = tf.multiply(x, mask)
        return _layer_process(x, layer_preproc)


def reparameterize(mu, log_sigma):
    """Draw ``mu + eps * exp(0.5 * log_sigma)`` with ``eps ~ N(0, 1)``.

    ``log_sigma`` is treated as a log-variance (the standard deviation is
    ``exp(log_sigma * 0.5)``). When ``log_sigma`` is None, ``mu`` is returned
    unchanged (deterministic path).
    """
    if log_sigma is None:
        return mu
    std = tf.exp(log_sigma * 0.5)
    noise = tf.random_normal(shape=tf.shape(std), mean=0.0, stddev=1.0)
    return mu + noise * std


def semantic_extractor(encoder_output,
                       mask,
                       params=None,
                       name="extractor"):
    """Summarise encoder states into one latent vector (CNN encoder + VAE head).

    Processing steps:
      * Run one 2-D convolution per entry of ``params["conv_filter_sizes"]``.
        Each filter spans the whole hidden axis (VALID padding).
      * Apply ReLU, then max-pool over the time axis.
      * Concatenate the pooled features along the last axis and optionally
        apply dropout.
      * Project the result to ``mu`` and ``log_sigma``.
      * Draw a sample with ``reparameterize``.

    Args:
        encoder_output: encoder states, presumably [batch, length, hidden_size]
            (the filter shape requires the last axis to equal hidden_size).
        mask: source mask. NOTE(review): currently unused -- padded positions
            are not excluded from the max-pool; confirm this is intended.
        params: hyper-parameter dict (required despite the default). Reads
            "hidden_size", "conv_filter_sizes", "num_conv_filters",
            "residual_dropout", "layer_preproc" and, optionally,
            "prepost_dropout".
        name: variable scope name (opened with AUTO_REUSE).

    Returns:
        ``(sample, mu, log_sigma)`` where ``sample`` has been passed through
        ``_layer_process(.., layer_preproc)``.
    """
    params = {} if params is None else params
    hidden_size = params["hidden_size"]
    conv_filter_sizes = params["conv_filter_sizes"]
    num_conv_filters = params["num_conv_filters"]
    # Dropout rate for the pooled features. The previous code referenced an
    # undefined ``prepost_dropout`` name (NameError). An explicit
    # "prepost_dropout" entry wins; otherwise use the residual dropout rate.
    dropout_rate = params.get("prepost_dropout", params["residual_dropout"])
    layer_preproc = params["layer_preproc"]

    # Add a channel axis so conv2d sees each sequence as a 1-channel image.
    x = tf.expand_dims(encoder_output, axis=-1)
    pooled_outputs = []
    initializer = tf.truncated_normal_initializer(stddev=0.01, seed=1234)
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        for i, conv_filter_size in enumerate(conv_filter_sizes):
            with tf.variable_scope("conv_maxpool_%d" % i):
                # Each filter covers conv_filter_size steps and the full hidden axis.
                filter_shape = [conv_filter_size, hidden_size, 1, num_conv_filters]
                W = tf.get_variable('W', shape=filter_shape, initializer=initializer)
                b = tf.get_variable('b', shape=[num_conv_filters], initializer=initializer)
                conv = tf.nn.conv2d(
                    x,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Max-pool over the time axis (axis 1).
                pooled_outputs.append(tf.reduce_max(h, axis=1, name="pool"))
        # Concatenate the filter banks along the channel axis.
        h_pool = tf.concat(pooled_outputs, 2)
        if dropout_rate != 0.0:
            h_pool = tf.nn.dropout(h_pool, 1.0 - dropout_rate)

        mu = layers.linear(h_pool, hidden_size, True, True, scope="mu")
        log_sigma = layers.linear(h_pool, hidden_size, True, True, scope="sigma")
        o = reparameterize(mu, log_sigma)
        return _layer_process(o, layer_preproc), mu, log_sigma


def gated_encoder_semantic(encoder_output,
                           semantic_output,
                           mask,
                           params=None,
                           name="gated_encoder_semantic"):
    """Fuse per-position encoder states with a sentence-level semantic vector.

    Processing steps:
      * Tile ``semantic_output`` along the time axis and mask it.
      * Compute one linear projection per input, P_enc and P_sem.
      * Combine them through a sigmoid gate::

            g = sigmoid(P_enc + P_sem)
            o = g * P_enc + (1 - g) * P_sem

      * Mask the result again and pass it through
        ``_layer_process(.., layer_preproc)``.

    NOTE(review): the gate mixes the projections, not the raw inputs --
    confirm that this is intended.

    Args:
        encoder_output: encoder states, presumably [batch, length, hidden_size].
        semantic_output: presumably [batch, 1, hidden_size]; tiled along axis 1.
        mask: float mask, presumably [batch, length]; expanded on axis 2.
        params: hyper-parameter dict (required despite the default); reads
            "hidden_size" and "layer_preproc".
        name: variable scope name (opened with AUTO_REUSE).

    Returns:
        The gated, masked representation (last dimension hidden_size).
    """
    params = {} if params is None else params
    hidden_size = params["hidden_size"]
    layer_preproc = params["layer_preproc"]
    mask = tf.expand_dims(mask, 2)

    x = encoder_output
    # Broadcast the single semantic vector to every source position.
    seq_len = tf.shape(x)[1]
    y = tf.tile(semantic_output, [1, seq_len, 1])
    y = tf.multiply(y, mask)
    # Consistent 8-space indentation: the previous tab/space mix here raised
    # TabError under Python 3.
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        gate_for_encoder = layers.linear(x, hidden_size, True, True, scope="gate_for_encoder")
        gate_for_semantic = layers.linear(y, hidden_size, True, True, scope="gate_for_semantic")
        gate = tf.nn.sigmoid(gate_for_encoder + gate_for_semantic)
        o = gate * gate_for_encoder + (1.0 - gate) * gate_for_semantic
        o = tf.multiply(o, mask)
        return _layer_process(o, layer_preproc)


def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        states_key=None,
                        states_val=None,
                        params=None,
                        name="decoder"):
    """Stack of Transformer decoder layers.

    Each layer has three sub-layers, run in order:
      * self-attention, which also receives ``states_key``, ``states_val`` and
        the layer index;
      * encoder-decoder attention over ``encoder_output``;
      * the feed-forward sub-layer.
    Every sub-layer pre-processes its input with
    ``_layer_process(.., layer_preproc)``, adds a dropout residual, then
    post-processes with ``_layer_process(.., layer_postproc)``.

    Args:
        decoder_input: embedded (shifted) target tensor.
        encoder_output: memory attended to by encoder-decoder attention.
        decoder_self_attention_bias: bias for self-attention.
        encoder_decoder_attention_bias: bias for encoder-decoder attention.
        states_key, states_val: forwarded to self-attention; presumably
            cached keys/values for incremental decoding -- TODO confirm.
        params: hyper-parameter dict (required despite the default). Reads
            "num_decoder_layers", "hidden_size", "num_heads",
            "residual_dropout", "attention_dropout", "layer_preproc",
            "layer_postproc", plus the keys used by ``transformer_ffn_layer``.
        name: variable scope name.

    Returns:
        ``(output, attention_weights)``:
          * ``output`` is ``_layer_process(x, layer_preproc)`` of the last layer.
          * ``attention_weights`` are those returned by the last layer's
            encoder-decoder attention, or None when there are no layers.
    """
    params = {} if params is None else params
    num_decoder_layers = params["num_decoder_layers"]
    hidden_size = params["hidden_size"]
    num_heads = params["num_heads"]
    residual_keep_prob = 1.0 - params["residual_dropout"]
    attention_dropout = params["attention_dropout"]
    layer_preproc = params["layer_preproc"]
    layer_postproc = params["layer_postproc"]

    x = decoder_input
    # Bound up front so an empty stack returns None instead of raising
    # UnboundLocalError.
    attention_weights = None
    with tf.variable_scope(name):
        for layer in range(num_decoder_layers):
            with tf.variable_scope("layer_%d" % layer):
                # Self-attention sub-layer (its weights are discarded).
                o, _ = attention.multihead_attention(
                        _layer_process(x, layer_preproc),
                        None,
                        decoder_self_attention_bias,
                        hidden_size,
                        hidden_size,
                        hidden_size,
                        num_heads,
                        attention_dropout,
                        states_key=states_key,
                        states_val=states_val,
                        layer=layer,
                        name="decoder_self_attention")
                x = _residual_fn(x, o, residual_keep_prob)
                x = _layer_process(x, layer_postproc)

                # Encoder-decoder attention sub-layer.
                o, attention_weights = attention.multihead_attention(
                        _layer_process(x, layer_preproc),
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hidden_size,
                        hidden_size,
                        hidden_size,
                        num_heads,
                        attention_dropout,
                        name="encdec_attention")
                x = _residual_fn(x, o, residual_keep_prob)
                x = _layer_process(x, layer_postproc)

                # Feed-forward sub-layer.
                o = transformer_ffn_layer(_layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, residual_keep_prob)
                x = _layer_process(x, layer_postproc)

        return _layer_process(x, layer_preproc), attention_weights

def encoding_graph(features, params):
    """Embed the source ids and run them through ``transformer_encoder``.

    Mask construction:
      * An int64 zero column is appended to ``features``.
      * The float mask is the "id != 0" indicator shifted right by one
        position, with the first position forced to 1. The slot right after
        the last non-zero id therefore stays unmasked.

    Embedding steps, in order:
      * scale by sqrt(hidden_size);
      * add the timing signal;
      * apply the mask;
      * add a learned bias;
      * apply residual dropout.
    The result is then fed to the encoder.

    Args:
        features: int64 id tensor (it is concatenated with an int64 column),
            presumably [batch, length].
        params: hyper-parameter dict.

    Returns:
        ``(encoder_output, encoder_self_attention_bias)``.
    """
    src_vocab_size = params["src_vocab_size"]
    hidden_size = params["hidden_size"]
    initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5, dtype=tf.float32)

    # The shared table is opened with AUTO_REUSE so both sides can fetch it.
    if params["shared_source_target_embedding"]:
        table_scope, table_reuse = 'Shared_Embedding', tf.AUTO_REUSE
    else:
        table_scope, table_reuse = 'Source_Embedding', None
    with tf.variable_scope(table_scope, reuse=table_reuse):
        src_embedding = tf.get_variable('Weights',
                                        [src_vocab_size, hidden_size],
                                        initializer=initializer)
    src_bias = tf.get_variable("encoder_input_bias", [hidden_size])

    batch_size = tf.shape(features)[0]
    src_seq = tf.concat([features, tf.zeros([batch_size, 1], tf.int64)], 1)
    not_pad = tf.to_float(tf.not_equal(src_seq, 0))
    src_mask = tf.pad(not_pad[:, :-1], [[0, 0], [1, 0]], constant_values=1)

    embedded = tf.gather(src_embedding, tf.cast(src_seq, tf.int32))
    embedded = attention.add_timing_signal(embedded * (hidden_size ** 0.5))
    embedded = tf.multiply(embedded, tf.expand_dims(src_mask, 2))
    embedded = tf.nn.bias_add(embedded, src_bias)

    self_attention_bias = attention.attention_bias(src_mask, "masking")
    embedded = tf.nn.dropout(embedded, 1.0 - params['residual_dropout'])

    encoder_output = transformer_encoder(embedded, self_attention_bias, src_mask, params)
    return encoder_output, self_attention_bias

def decoding_graph(encoder_output, encoder_self_attention_bias, labels, params,
                   is_training=True, states_key=None, states_val=None):
    """Build the target-side (decoder) graph.

    Training (``is_training=True``):
      * Append an int64 zero column to the target ids.
      * Embed them, scale by sqrt(hidden_size) and shift right by one.
      * Add the timing signal and residual dropout.
      * Decode against ``encoder_output``.
    Returns the label-smoothed cross-entropy averaged over the unmasked target
    positions. The on-value is ``confidence``; the off-value is
    ``(1 - confidence) / (V - 1)``.

    Inference (``is_training=False``): only the last decoder position is fed,
    together with ``states_key``/``states_val``. Returns
    ``(log_prob, states_key, states_val)``, where ``log_prob`` is the
    log-softmax over the target vocabulary for that position.

    Args:
        encoder_output: encoder states attended to by every decoder layer.
        encoder_self_attention_bias: source bias, reused as the
            encoder-decoder attention bias.
        labels: int64 target ids, presumably [batch, length].
        params: hyper-parameter dict.
        is_training: selects the training or the inference branch.
        states_key, states_val: required when ``is_training`` is False.

    Returns:
        The scalar loss when training, otherwise
        ``(log_prob, states_key, states_val)``.

    Raises:
        ValueError: at inference time when ``states_key`` or ``states_val``
            is None.
    """
    trg_vocab_size = params["trg_vocab_size"]
    hidden_size = params["hidden_size"]

    initializer = tf.random_normal_initializer(0.0, hidden_size ** -0.5, dtype=tf.float32)

    if params["shared_source_target_embedding"]:
        with tf.variable_scope('Shared_Embedding', reuse=tf.AUTO_REUSE):
            trg_embedding = tf.get_variable('Weights',
                                            [trg_vocab_size, hidden_size],
                                            initializer=initializer)
    else:
        with tf.variable_scope('Target_Embedding'):
            trg_embedding = tf.get_variable('Weights',
                                            [trg_vocab_size, hidden_size],
                                            initializer=initializer)

    # Append a zero column; the mask is "id != 0" shifted right by one, so the
    # slot right after the last non-zero id is still counted.
    eos_padding = tf.zeros([tf.shape(labels)[0], 1], tf.int64)
    trg_seq = tf.concat([labels, eos_padding], 1)
    trg_mask = tf.to_float(tf.not_equal(trg_seq, 0))
    trg_mask = trg_mask[:, :-1]
    trg_mask = tf.pad(trg_mask, [[0, 0], [1, 0]], constant_values=1)

    decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))
    decoder_input *= hidden_size ** 0.5
    decoder_self_attention_bias = attention.attention_bias(tf.shape(decoder_input)[1], "causal")
    # Shift right by one position (teacher forcing): prepend zeros, drop the last step.
    decoder_input = tf.pad(decoder_input, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    decoder_input = attention.add_timing_signal(decoder_input)

    decoder_input = tf.nn.dropout(decoder_input, 1.0 - params['residual_dropout'])

    if is_training:
        decoder_output, attention_weights = transformer_decoder(
                decoder_input, encoder_output,
                decoder_self_attention_bias, encoder_self_attention_bias,
                states_key=None, states_val=None, params=params)
    else:
        if states_key is None or states_val is None:
            raise ValueError("NoneType Unsupported: Either states_key or states_val is a NoneType Tensor.")

        # Only the newest position is decoded at inference time.
        decoder_input = decoder_input[:, -1:, :]
        # NOTE(review): the integer index drops axis 2 of the bias; this
        # presumably relies on broadcasting inside multihead_attention -- confirm.
        decoder_self_attention_bias = decoder_self_attention_bias[:, :, -1, :]
        decoder_output, attention_weights = transformer_decoder(
                decoder_input, encoder_output,
                decoder_self_attention_bias, encoder_self_attention_bias,
                states_key=states_key, states_val=states_val, params=params)
        decoder_output_last = decoder_output[:, -1, :]

        if params["shared_embedding_and_softmax_weights"]:
            embedding_scope = ('Shared_Embedding' if params["shared_source_target_embedding"]
                               else 'Target_Embedding')
            with tf.variable_scope(embedding_scope, reuse=True):
                weights = tf.get_variable('Weights')
        else:
            # trg_vocab_size (was the undefined name tgt_vocab_size -> NameError).
            weights = tf.get_variable("Softmax", [trg_vocab_size, hidden_size])

        logits = tf.matmul(decoder_output_last, weights, transpose_b=True)
        log_prob = tf.nn.log_softmax(logits)

        return log_prob, states_key, states_val

    logits = prediction(decoder_output, params)

    # Label smoothing: "confidence" on the gold id, the rest spread uniformly.
    on_value = params["confidence"]
    off_value = (1.0 - params["confidence"]) / tf.to_float(trg_vocab_size - 1)
    soft_targets = tf.one_hot(tf.cast(trg_seq, tf.int32), depth=trg_vocab_size,
                              on_value=on_value, off_value=off_value)
    mask = tf.cast(trg_mask, logits.dtype)
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=soft_targets) * mask
    loss = tf.reduce_sum(xentropy) / tf.reduce_sum(mask)

    return loss


def _source_padding_mask(features):
    """Rebuild the float source mask exactly as ``encoding_graph`` computes it.

    An int64 zero column is appended to ``features``. The mask is the
    "id != 0" indicator shifted right by one, with the first position forced
    to 1. Its time length therefore matches ``encoding_graph``'s output.
    """
    batch_size = tf.shape(features)[0]
    src_seq = tf.concat([features, tf.zeros([batch_size, 1], tf.int64)], 1)
    src_mask = tf.to_float(tf.not_equal(src_seq, 0))[:, :-1]
    return tf.pad(src_mask, [[0, 0], [1, 0]], constant_values=1)


def build_training_graph(features, labels, params):
    """Build the training graph for one shard.

    The shard passes through four stages:
      * encoder;
      * semantic (CNN + VAE) extractor;
      * gated fusion of the encoder states with the semantic vector;
      * decoder loss.

    Returns:
        ``(loss, mu, log_sigma)``: the decoder loss and the latent-distribution
        parameters that feed the KL term.
    """
    # encode
    encoder_output, encoder_self_attention_bias = encoding_graph(features, params)

    # encoding_graph does not return its mask, so rebuild it here. The old
    # code referenced an undefined ``src_masks`` (NameError).
    src_mask = _source_padding_mask(features)

    # semantic_output: [batch_size, 1, hidden]
    semantic_output, mu, log_sigma = semantic_extractor(encoder_output, src_mask, params)

    # gated & norm
    gated_encoder_output = gated_encoder_semantic(encoder_output, semantic_output, src_mask, params)

    # decode
    loss = decoding_graph(gated_encoder_output, encoder_self_attention_bias, labels, params,
                          is_training=True)

    return loss, mu, log_sigma

def prediction(decoder_output, params):
    """Project decoder states onto the target vocabulary.

    Which output weight matrix is used:
      * When ``params["shared_embedding_and_softmax_weights"]`` is set, the
        existing embedding table. It is taken from ``Shared_Embedding`` or
        ``Target_Embedding``, depending on
        ``params["shared_source_target_embedding"]``.
      * Otherwise a ``Softmax`` variable of shape
        [trg_vocab_size, hidden_size].

    Args:
        decoder_output: tensor whose last dimension is hidden_size.
        params: hyper-parameter dict.

    Returns:
        Logits with the leading dimensions of ``decoder_output`` and a last
        dimension of trg_vocab_size.
    """
    hidden_size = params["hidden_size"]
    trg_vocab_size = params["trg_vocab_size"]

    if params["shared_embedding_and_softmax_weights"]:
        embedding_scope = ('Shared_Embedding' if params["shared_source_target_embedding"]
                           else 'Target_Embedding')
        with tf.variable_scope(embedding_scope, reuse=True):
            weights = tf.get_variable('Weights')
    else:
        # trg_vocab_size (was the undefined name tgt_vocab_size -> NameError).
        weights = tf.get_variable("Softmax", [trg_vocab_size, hidden_size])
    # Flatten all leading axes so a single matmul handles inputs of any rank.
    shape = tf.shape(decoder_output)[:-1]
    decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
    logits = tf.matmul(decoder_output, weights, transpose_b=True)
    logits = tf.reshape(logits, tf.concat([shape, [trg_vocab_size]], 0))
    return logits

def get_initializer(params):
    """Build the global variable initializer named by ``params["initializer"]``.

    Supported kinds, all using ``params["initializer_scale"]``:
        "uniform": U(-scale, scale)
        "normal": N(0, scale)
        "normal_unit_scaling" / "uniform_unit_scaling": variance scaling
            (fan_avg) with a normal / uniform distribution.

    Raises:
        ValueError: for an unrecognized initializer name.
    """
    kind = params["initializer"]
    if kind == "uniform":
        bound = params["initializer_scale"]
        return tf.random_uniform_initializer(-bound, bound)
    if kind == "normal":
        return tf.random_normal_initializer(0.0, params["initializer_scale"])
    for distribution in ("normal", "uniform"):
        if kind == distribution + "_unit_scaling":
            return tf.variance_scaling_initializer(params["initializer_scale"],
                                                   mode="fan_avg",
                                                   distribution=distribution)
    raise ValueError("Unrecognized initializer: %s" % kind)

def get_learning_rate_decay(learning_rate, global_step, params):
    """Apply the schedule named by ``params["learning_rate_decay"]``.

    Schedules:
        "linear_warmup_rsqrt_decay" / "noam":
            learning_rate * hidden_size**-0.5
            * min((step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
        "piecewise_constant": ``tf.train.piecewise_constant`` over
            "learning_rate_boundaries" / "learning_rate_values"
            (``learning_rate`` itself is ignored).
        "none": ``learning_rate`` unchanged.

    Raises:
        ValueError: for any other schedule name.
    """
    schedule = params["learning_rate_decay"]

    if schedule in ["linear_warmup_rsqrt_decay", "noam"]:
        step = tf.to_float(global_step)
        warmup_steps = tf.to_float(params["warmup_steps"])
        scale = params["hidden_size"] ** -0.5
        warmup_term = (step + 1) * (warmup_steps ** -1.5)
        rsqrt_term = (step + 1) ** -0.5
        return learning_rate * (scale * tf.minimum(warmup_term, rsqrt_term))

    if schedule == "piecewise_constant":
        return tf.train.piecewise_constant(tf.to_int32(global_step),
                                           params["learning_rate_boundaries"],
                                           params["learning_rate_values"])

    if schedule == "none":
        return learning_rate

    raise ValueError("Unknown learning_rate_decay")


def cyclical_anneal_schedule(cur_step, params):
    """Cyclical KL-annealing weight (Fu et al., NAACL 2019).

    Computing the cycle position:
      * Offset ``cur_step`` by ``params['start_step']``.
      * Wrap it with ``tf.mod`` into a cycle of ``params['per_cycle_steps']``
        steps.
      * Divide by the cycle length to get t.
    The weight is ``t / kl_proportion`` while t is below
    ``params['kl_proportion']``, and 1.0 afterwards.

    Args:
        cur_step: float step tensor.
        params: reads "start_step", "per_cycle_steps" and "kl_proportion".

    Returns:
        The scalar beta weight for the KL term.
    """
    offset = cur_step - params['start_step']
    cycle_pos = tf.mod(offset, params['per_cycle_steps'])
    t = cycle_pos / tf.to_float(params['per_cycle_steps'])
    ramp = params['kl_proportion']
    return tf.where(t < ramp, t / ramp, 1.0)


def compute_semantic_loss(mu, log_sigma, eps=1e-5):
    """KL divergence between N(mu, exp(log_sigma)) and N(0, I).

    The per-dimension term is ``exp(log_sigma) + mu**2 - 1 - log_sigma``,
    i.e. ``log_sigma`` acts as a log-variance. The terms are:
      * summed over the last axis;
      * averaged over all remaining elements;
      * halved.

    Args:
        mu: posterior mean.
        log_sigma: posterior log-variance (presumably the same shape as ``mu``).
        eps: unused; kept for interface compatibility.

    Returns:
        Scalar KL loss.
    """
    kl_terms = -log_sigma - 1.0 + tf.exp(log_sigma) + tf.pow(mu, 2.0)
    per_example = tf.reduce_sum(kl_terms, axis=-1)
    return 0.5 * tf.reduce_mean(per_example)


def transformer_model_train_fn(features, labels, mode, params):
    """``tf.estimator`` model_fn that builds the data-parallel training graph.

    Only ``tf.estimator.ModeKeys.TRAIN`` is handled; for any other mode the
    function falls through and returns None.

    Each GPU shard's loss is the decoder cross-entropy plus ``beta_kld * KL``.
    ``beta_kld`` comes from ``cyclical_anneal_schedule`` (Fu et al.,
    "Cyclical Annealing Schedule: A Simple Approach to Mitigating KL
    Vanishing", NAACL 2019). After the shard losses are built:
      * they are averaged;
      * the collected regularization loss is added;
      * gradients are optionally clipped by global norm;
      * gradients are applied through ``optimizers.MultiStepOptimizer``.

    Raises:
        ValueError: if ``params['optimizer']`` is neither 'sgd' nor 'adam'.
    """
    initializer = get_initializer(params)
    with tf.variable_scope('NmtModel', initializer=initializer):
        if mode == tf.estimator.ModeKeys.TRAIN:
            num_gpus = params['num_gpus']
            global_step = tf.train.get_global_step()
            learning_rate = get_learning_rate_decay(params["learning_rate"],
                                                    global_step, params)
            learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
            if params['optimizer'] == 'sgd':
                optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            elif params['optimizer'] == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                   beta1=params["adam_beta1"],
                                                   beta2=params["adam_beta2"],
                                                   epsilon=params["adam_epsilon"])
            else:
                # Fail loudly: sys.exit() ended the process with status 0,
                # making a misconfigured run look successful.
                raise ValueError("optimizer not supported: %s" % params['optimizer'])

            def fill_gpus(inputs, num_gpus):
                """Stack num_gpus + 1 copies of ``inputs`` and keep the first num_gpus rows."""
                return tf.concat([inputs] * (num_gpus + 1), axis=0)[:num_gpus]

            # Fu et al. Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing. NAACL-2019
            step = tf.to_float(global_step)
            beta_kld = cyclical_anneal_schedule(step, params)

            # Make sure every GPU receives at least one example before sharding.
            features = tf.cond(tf.shape(features)[0] < num_gpus,
                               lambda: fill_gpus(features, num_gpus),
                               lambda: features)
            labels = tf.cond(tf.shape(labels)[0] < num_gpus,
                             lambda: fill_gpus(labels, num_gpus),
                             lambda: labels)
            feature_shards = layers.shard_features(features, num_gpus)
            label_shards = layers.shard_features(labels, num_gpus)

            devices = ["gpu:%d" % d for d in range(num_gpus)]
            sharded_losses = []
            for i, device in enumerate(devices):
                # The first shard creates the variables; the others reuse them.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True if i > 0 else None):
                    with tf.device(device):
                        loss, mu, log_sigma = build_training_graph(feature_shards[i], label_shards[i], params)
                        kld_loss = compute_semantic_loss(mu, log_sigma)
                        loss += beta_kld * kld_loss
                        sharded_losses.append(loss)

            total_loss = tf.add_n(sharded_losses) / len(sharded_losses)
            total_loss = total_loss + tf.losses.get_regularization_loss()

            tf.summary.scalar("loss", total_loss)

            opt = optimizers.MultiStepOptimizer(optimizer, params["update_cycle"])
            # Optimization
            grads_and_vars = opt.compute_gradients(total_loss, colocate_gradients_with_ops=True)

            gradient_clip_norm = params['gradient_clip_norm']
            if gradient_clip_norm > 0.0:
                grads, var_list = list(zip(*grads_and_vars))
                grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
                # Materialise: a bare zip is a one-shot iterator in Python 3.
                grads_and_vars = list(zip(grads, var_list))

            train_op = opt.apply_gradients(grads_and_vars, global_step=tf.train.get_global_step())

            # Report the averaged, regularized loss that is actually optimized,
            # not just the last shard's loss.
            return tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op)
