# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import sys


def split_heads(x, heads):
    """Reshape the feature axis of x into separate attention heads.

    The head axis is moved in front of the length axis so that a batched
    matrix multiplication over the two innermost axes operates on one head
    at a time.

    Args:
      x: A tensor with shape [batch_size, length, num_units].
      heads: Number of heads. The size of the last axis of x is
        floor-divided by this value to obtain the per-head depth.

    Returns:
      A tensor with shape [batch_size, heads, length, num_units // heads].
    """
    with tf.name_scope("split_heads"):
        # Dynamic shape, so this also works when static dimensions are unknown.
        dims = tf.shape(x)
        per_head_depth = dims[-1] // heads

        # [batch, length, num_units] --> [batch, length, heads, depth]
        by_head = tf.reshape(x, [dims[0], dims[1], heads, per_head_depth])

        # [batch, length, heads, depth] --> [batch, heads, length, depth]
        return tf.transpose(by_head, perm=[0, 2, 1, 3])


def combine_heads(x, num_units):
    """Merge a head-split tensor back into a single feature axis.

    Args:
      x: A tensor with shape [batch_size, num_heads, length, depth].
      num_units: Size of the last axis of the result; the reshape requires
        it to equal num_heads * depth.

    Returns:
      A tensor with shape [batch_size, length, num_units].
    """
    with tf.name_scope("combine_heads"):
        dims = tf.shape(x)
        merged_shape = [dims[0], dims[2], num_units]

        # [batch, num_heads, length, depth] --> [batch, length, num_heads, depth]
        heads_last = tf.transpose(x, perm=[0, 2, 1, 3])

        # Collapse the two trailing axes into one of size num_units.
        return tf.reshape(heads_last, merged_shape)


def scaled_dot_product_attention(q, k, v, mask, dropout=0, scale=None):
    """Compute softmax(q @ k^T * scale + mask * -1e9) @ v.

    q, k and v must have matching leading dimensions, and k and v must
    share their penultimate dimension (seq_len_k == seq_len_v).

    Args:
      q: query, shape (..., seq_len_q, depth).
      k: key, shape (..., seq_len_k, depth).
      v: value, shape (..., seq_len_v, depth_v).
      mask: tensor broadcastable to (..., seq_len_q, seq_len_k), or None to
        skip masking. It is multiplied by -1e9 and added to the logits, so
        non-zero entries mark positions to suppress.
      dropout: dropout rate applied to the attention weights; the dropout
        op is skipped entirely when this compares equal to 0.
      scale: multiplier applied to the raw q.k logits. When None, the
        last dimension of k raised to the power -0.5 is used.

    Returns:
      A pair (output, attention_weights) where output has shape
      (..., seq_len_q, depth_v) and attention_weights has shape
      (..., seq_len_q, seq_len_k).
    """
    # Raw similarity scores: (..., seq_len_q, seq_len_k)
    logits = tf.matmul(q, k, transpose_b=True)
    logits_dtype = logits.dtype

    if scale is None:
        key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
        scale = key_depth ** -0.5
    logits = logits * tf.cast(scale, logits_dtype)

    # Masked positions receive a large negative value so that softmax
    # assigns them (numerically) zero weight.
    if mask is not None:
        logits = logits + tf.cast(mask * -1e9, logits_dtype)

    # Normalise over the key axis so each query's weights sum to 1.
    attention_weights = tf.nn.softmax(logits, axis=-1)
    if dropout != 0:
        attention_weights = tf.nn.dropout(attention_weights, rate=dropout)

    # Weighted sum of the values: (..., seq_len_q, depth_v)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights


class Attention(tf.keras.layers.Layer):
    """Multi-headed attention layer."""

    def __init__(self, num_units, num_heads, dropout):
        """Initialize Attention.

        Args:
          num_units: int, output dim of hidden layer.
          num_heads: int, number of heads to repeat the same attention structure.
          dropout: float, dropout rate applied to the attention weights
            while training.

        Raises:
          ValueError: if num_units is not divisible by num_heads.
        """
        if num_units % num_heads:
            raise ValueError(
                "Hidden size ({}) must be divisible by the number of heads ({}).".format(num_units, num_heads)
            )

        super(Attention, self).__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout = dropout
        # Placeholder until the first call(); afterwards holds the attention
        # weights produced by the most recent call.
        self.attention_weights = 0

    def build(self, input_shape):
        """Builds the layer."""
        # Layers for linearly projecting the queries, keys, and values.
        self.q_dense_layer = tf.keras.layers.Dense(self.num_units, use_bias=False, name="q")
        self.k_dense_layer = tf.keras.layers.Dense(self.num_units, use_bias=False, name="k")
        self.v_dense_layer = tf.keras.layers.Dense(self.num_units, use_bias=False, name="v")
        self.output_dense_layer = tf.keras.layers.Dense(self.num_units, use_bias=False, name="output_transform")

        super(Attention, self).build(input_shape)

    def get_config(self):
        """Return the constructor arguments of this layer as a dict."""
        return {
            "num_units": self.num_units,
            "num_heads": self.num_heads,
            "dropout": self.dropout,
        }

    def call(self, x, y, bias, training, cache=None, scale=None, **kwargs):
        """Apply attention mechanism to x and y.

        Args:
          x: a tensor with shape [batch_size, length_x, num_units]; source
            of the queries.
          y: a tensor with shape [batch_size, length_y, num_units]; source
            of the keys and values.
          bias: mask passed to scaled_dot_product_attention; it is
            multiplied by -1e9 there and added to the dot-product logits.
            May be None.
          training: whether in training mode or not. It is evaluated with a
            plain Python `if`, so it must be usable as a Python truth value.
          cache: (Used during prediction) dictionary holding the keys and
            values of previous steps, already split into heads:
                {"k": tensor with shape [batch_size, num_heads, i, depth],
                 "v": tensor with shape [batch_size, num_heads, i, depth]}
            where i is the current decoded length. The new keys/values are
            appended along axis 2 and the dictionary is updated in place.
          scale: optional multiplier for the dot-product logits; when None
            scaled_dot_product_attention derives it from the key depth.
          **kwargs: accepted for signature compatibility and ignored.

        Returns:
          Attention layer output with shape [batch_size, length_x, num_units]
        """
        # Linearly project the query (q), key (k) and value (v) using different
        # learned projections. This is in preparation of splitting them into
        # multiple heads. Multi-head attention uses multiple queries, keys, and
        # values rather than regular attention (which uses a single q, k, v).
        q = self.q_dense_layer(x)
        k = self.k_dense_layer(y)
        v = self.v_dense_layer(y)

        q = split_heads(q, self.num_heads)
        k = split_heads(k, self.num_heads)
        v = split_heads(v, self.num_heads)

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            k = tf.concat((cache["k"], k), axis=2)
            v = tf.concat((cache["v"], v), axis=2)
            # Update cache
            cache["k"] = k
            cache["v"] = v

        # Dropout is only active while training. Both optional arguments are
        # passed by keyword so that `scale` can never land in the `dropout`
        # slot of scaled_dot_product_attention.
        dropout_rate = self.dropout if training else 0
        attention_output, self.attention_weights = scaled_dot_product_attention(
            q, k, v, bias, dropout=dropout_rate, scale=scale
        )

        attention_output = combine_heads(attention_output, self.num_units)
        attention_output = self.output_dense_layer(attention_output)
        return attention_output

    def get_attention_weights(self):
        """Return the attention weights stored by the most recent call()."""
        return self.attention_weights


class SelfAttention(Attention):
    """Multiheaded self-attention layer."""

    def call(self, x, bias, training, cache=None, **kwargs):
        """Attend from x to itself.

        Delegates to the parent call() with x supplied as both the query
        source and the key/value source; every other argument is forwarded
        unchanged.
        """
        attend = super(SelfAttention, self).call
        return attend(x, x, bias, training, cache, **kwargs)
