
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import layers
import tensorflow as tf


def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add a sinusoidal position signal to ``x``.

    For each position along axis 1 a vector of sines and cosines is
    generated at ``channels // 2`` geometrically spaced timescales between
    ``min_timescale`` and ``max_timescale``. The sine half and cosine half
    are concatenated along the channel axis, zero-padded by one channel when
    ``channels`` is odd, and added to ``x``.

    Args:
        x: Tensor whose axis 1 is used as the position axis and axis 2 as
            the channel axis (presumably ``[batch, length, channels]`` --
            confirm against callers).
        min_timescale: Smallest timescale of the geometric progression.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        A tensor equal to ``x`` plus the signal, which is reshaped to
        ``[1, length, channels]`` and cast to ``x.dtype`` before the add.
    """
    length = tf.shape(x)[1]
    channels = tf.shape(x)[2]
    position = tf.to_float(tf.range(length))
    num_timescales = channels // 2

    # Clamp the denominator to at least 1. With 2 or 3 channels there is a
    # single timescale, so ``num_timescales - 1`` is 0; dividing by it gives
    # an infinite increment and ``0 * -inf`` then turns the whole signal
    # into NaN. For 4 or more channels the clamp has no effect.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.to_float(num_timescales) - 1, 1.0)
    )
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment
    )

    # Outer product: one row per position, one column per timescale.
    scaled_time = (tf.expand_dims(position, 1) *
                   tf.expand_dims(inv_timescales, 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    # An odd channel count leaves one channel without a sin/cos partner;
    # fill it with zeros so the signal matches the channel size of x.
    signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
    signal = tf.reshape(signal, [1, length, channels])

    return x + tf.cast(signal, x.dtype)

def attention_bias(inputs, mode, inf=-1e9, dtype=None):
    """Build an additive bias tensor for attention logits.

    Args:
        inputs: For ``mode == "masking"``, a mask tensor; entries equal to
            1 yield a bias of 0 and entries equal to 0 yield ``inf``
            (presumably shaped ``[batch, length]`` -- confirm against
            callers). For ``mode == "causal"``, the sequence length.
        mode: Either ``"masking"`` or ``"causal"``.
        inf: Large negative value used for blocked positions. It is
            replaced by ``dtype.min`` whenever ``dtype`` is not
            ``tf.float32``.
        dtype: Dtype of the returned tensor; ``tf.float32`` when ``None``.

    Returns:
        For ``"masking"``, the bias with two size-1 axes inserted at
        index 1. For ``"causal"``, a ``[1, 1, length, length]`` tensor that
        is 0 on and below the diagonal and ``inf`` above it.

    Raises:
        ValueError: If ``mode`` is neither ``"masking"`` nor ``"causal"``.
    """
    dtype = tf.float32 if dtype is None else dtype

    # Any dtype other than float32 uses its own most negative value.
    if dtype != tf.float32:
        inf = dtype.min

    if mode == "masking":
        blocked = (1.0 - inputs) * inf
        # Insert two singleton axes right after the leading axis.
        once = tf.expand_dims(blocked, 1)
        twice = tf.expand_dims(once, 1)
        return tf.cast(twice, dtype)

    if mode == "causal":
        size = inputs
        # Ones on and below the main diagonal, zeros above it.
        allowed = tf.matrix_band_part(tf.ones([size, size]), -1, 0)
        blocked = inf * (1.0 - allowed)
        return tf.cast(tf.reshape(blocked, [1, 1, size, size]), dtype)

    raise ValueError("Unknown mode %s" % mode)


def split_heads(x, num_heads):
    """Split the last axis of ``x`` into heads and move the head axis to 1.

    The last axis is reshaped into ``[num_heads, -1]`` and the new head
    axis is then transposed to index 1, keeping the remaining axes in
    their original order.

    Args:
        x: Tensor with a statically known rank.
        num_heads: Number of heads to split the last axis into.

    Returns:
        The transposed tensor, with one more axis than ``x``.
    """
    rank = x.shape.ndims
    static_dims = x.get_shape().dims

    depth = static_dims[-1]
    if depth:
        depth_per_head = depth // num_heads
    else:
        depth_per_head = None
    static_result = static_dims[:-1] + [num_heads, depth_per_head]

    # Reshape with the dynamic shape, then restore the static shape info.
    dynamic_result = tf.concat([tf.shape(x)[:-1], [num_heads, -1]], 0)
    heads_last = tf.reshape(x, dynamic_result)
    heads_last.set_shape(static_result)

    # After the reshape the head axis sits at index rank - 1; move it to
    # index 1 and leave the per-head depth (index rank) last.
    order = [0, rank - 1] + list(range(1, rank - 1)) + [rank]
    return tf.transpose(heads_last, order)


def multiplicative_attention(q, k, v, bias, dropout_rate=0.0, name=None):
    """Compute dot-product attention of ``q`` over ``k`` and ``v``.

    Args:
        q: Query tensor.
        k: Key tensor; multiplied with ``q`` with its last two axes
            transposed.
        v: Value tensor.
        bias: Optional tensor added to the logits before the softmax;
            skipped when ``None``.
        dropout_rate: Drop probability; ``1.0 - dropout_rate`` is passed
            to ``tf.nn.dropout`` as the keep probability.
        name: Optional variable scope name; defaults to
            ``"multiplicative_attention"``.

    Returns:
        A pair ``(outputs, weights)``: the weighted sum of ``v`` and the
        attention weights after dropout.
    """
    with tf.variable_scope(name,
                           default_name="multiplicative_attention",
                           values=[q, k, v]):
        scores = tf.matmul(q, k, transpose_b=True)
        if bias is not None:
            scores = scores + bias

        probs = tf.nn.softmax(scores, name="attention_weights")
        keep_prob = 1.0 - dropout_rate
        probs = tf.nn.dropout(probs, keep_prob)

        context = tf.matmul(probs, v)
        return context, probs


def combine_heads(x):
    """Merge the head axis of a rank-4 tensor back into the last axis.

    Axes 1 and 2 are swapped and the last two axes of the result are then
    flattened into one.

    Args:
        x: Rank-4 tensor (presumably ``[batch, heads, length, depth]`` as
            produced by ``split_heads`` -- confirm against callers).

    Returns:
        A tensor with the last two axes of the transposed input merged.
    """
    swapped = tf.transpose(x, [0, 2, 1, 3])

    dims = swapped.get_shape().dims
    heads, per_head = dims[-2:]
    if heads and per_head:
        merged = heads * per_head
    else:
        merged = None
    static_result = dims[:-2] + [merged]

    # Flatten with the dynamic shape, then restore the static shape info.
    dynamic_result = tf.concat([tf.shape(swapped)[:-2], [-1]], 0)
    flat = tf.reshape(swapped, dynamic_result)
    flat.set_shape(static_result)

    return flat


def multihead_attention(queries,
                        memories,
                        bias,
                        key_depth,
                        value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        states_key=None,
                        states_val=None,
                        layer=0,
                        name=None):
    """Multi-head scaled dot-product attention.

    Projects the inputs to queries, keys and values, splits them into
    ``num_heads`` heads, applies ``multiplicative_attention`` and projects
    the merged heads to ``output_depth``.

    Args:
        queries: Tensor the queries are projected from. The projections
            are split along axis 2, so features are expected on that axis.
        memories: Tensor the keys and values are projected from. When
            ``None``, keys and values are projected from ``queries``
            (self-attention) with a single fused transform.
        bias: Optional tensor added to the attention logits, or ``None``.
        key_depth: Total size of the query/key projections; must be
            divisible by ``num_heads``.
        value_depth: Total size of the value projection; must be divisible
            by ``num_heads``.
        output_depth: Size of the final output projection.
        num_heads: Number of attention heads.
        dropout_rate: Drop probability applied to the attention weights.
        states_key: Optional mutable container indexed by ``layer``. When
            given, the new keys are concatenated onto ``states_key[layer]``
            along axis 1 and the entry is updated in place.
        states_val: Same as ``states_key`` but for the values.
        layer: Index into ``states_key`` / ``states_val``.
        name: Optional variable scope name; defaults to
            ``"multihead_attention"``.

    Returns:
        A pair ``(outputs, weights)``: the projected attention output and
        the attention weights averaged over axis 1 (the head axis).

    Raises:
        ValueError: If ``key_depth`` or ``value_depth`` is not divisible by
            ``num_heads``.
    """
    if key_depth % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_depth, num_heads))

    if value_depth % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_depth, num_heads))

    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[queries, memories]):
        # NOTE(review): the two positional True flags are passed straight
        # to layers.linear; their meaning is defined there -- confirm.
        if memories is None:
            # Self-attention: one fused projection, then split q/k/v.
            combined = layers.linear(queries,
                                     key_depth * 2 + value_depth,
                                     True,
                                     True,
                                     scope="qkv_transform")
            q, k, v = tf.split(combined,
                               [key_depth, key_depth, value_depth],
                               axis=2)
        else:
            # Queries come from `queries`, keys and values from `memories`.
            q = layers.linear(queries,
                              key_depth,
                              True,
                              True,
                              scope="q_transform")
            combined = layers.linear(memories,
                                     key_depth + value_depth,
                                     True,
                                     True,
                                     scope="kv_transform")
            k, v = tf.split(combined, [key_depth, value_depth], axis=2)

        # Append the new keys/values to the cached ones and write the
        # result back so the caller's container sees the extended state.
        if states_key is not None:
            k = states_key[layer] = tf.concat([states_key[layer], k], axis=1)
        if states_val is not None:
            v = states_val[layer] = tf.concat([states_val[layer], v], axis=1)

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # Scale the queries by 1 / sqrt(per-head key depth).
        key_depth_per_head = key_depth // num_heads
        q *= key_depth_per_head**-0.5

        x, w = multiplicative_attention(q, k, v, bias, dropout_rate)
        x = combine_heads(x)
        # Average the attention weights over the head axis.
        w = tf.reduce_mean(w, 1)
        x = layers.linear(x, output_depth, True, True, scope="output_transform")
        return x, w

