import keras.backend as K
import os
from model.scaled_dot_attention import ScaledDotProductAttention

if 'TF_KERAS' in os.environ and os.environ['TF_KERAS'] != '0':
    from tensorflow.python import keras
    TF_KERAS = True
else:
    import keras
    TF_KERAS = False


class MultiHeadAttention(keras.layers.Layer):
    """Multi-head attention layer that also exposes its attention output.

    The layer accepts either a single tensor (used as query, key and value)
    or a list ``[q, k, v]``. Each of q/k/v is linearly projected (optional
    bias, then ``activation``), split into ``head_num`` heads, fed to
    ``ScaledDotProductAttention``, merged back, and projected once more
    through ``Wo``/``bo`` (again followed by ``activation``).

    Calling the layer returns ``[y, atten]``: ``y`` keeps the query's leading
    dimensions and the value's feature size; ``atten`` is the second output
    of ``ScaledDotProductAttention`` (presumably the attention weights --
    TODO confirm against ``model.scaled_dot_attention``).
    """

    def __init__(self,
                 head_num,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.

        :param head_num: Number of heads.
        :param activation: Activation applied after every linear mapping
            (q, k, v and output projections).
        :param use_bias: Whether to use bias terms in the linear mappings.
        :param kernel_initializer: Initializer for the linear-mapping kernels.
        :param bias_initializer: Initializer for the linear-mapping biases.
        :param kernel_regularizer: Regularizer for the linear-mapping kernels.
        :param bias_regularizer: Regularizer for the linear-mapping biases.
        :param kernel_constraint: Constraint for the linear-mapping kernels.
        :param bias_constraint: Constraint for the linear-mapping biases.
        :param history_only: Whether to only use history in the attention
            layer (forwarded to ``ScaledDotProductAttention``).
        :param kwargs: Forwarded to ``keras.layers.Layer``.
        """
        # Run the base constructor first: Keras 2's ``Layer.__init__`` sets
        # ``supports_masking = False``, which would silently undo an
        # assignment made before it.
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.head_num = head_num
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.history_only = history_only

        # Weights are created lazily in ``build``.
        self.Wq, self.Wk, self.Wv, self.Wo = None, None, None, None
        self.bq, self.bk, self.bv, self.bo = None, None, None, None

    def get_config(self):
        """Return the serializable configuration of the layer."""
        config = {
            'head_num': self.head_num,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'history_only': self.history_only,
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Return the shapes of both outputs, ``[y_shape, atten_shape]``.

        ``call`` always returns two tensors, so Keras expects one shape per
        output regardless of whether the input is a single tensor or a list.

        :param input_shape: Shape of a single input, or ``[q, k, v]`` shapes.
        :return: ``[q[:-1] + (v[-1],), (q[0], q[1], k[1])]``.
        """
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        output_shape = q[:-1] + (v[-1],)
        # NOTE(review): ``atten`` is computed on head-batched inputs, so its
        # real first dimension is presumably ``batch * head_num``; declared
        # here as ``batch`` as before -- confirm against
        # ScaledDotProductAttention.
        attention_shape = (q[0], q[1], k[1])
        return [output_shape, attention_shape]

    def compute_mask(self, inputs, input_mask=None):
        """Return one mask per output of ``call``.

        ``y`` inherits the query mask (the first mask when a list is given);
        ``atten`` gets no mask. Returns ``None`` when there is no query mask.
        """
        if isinstance(input_mask, list):
            input_mask = input_mask[0]
        if input_mask is None:
            return None
        # Keras pairs masks with outputs positionally; a single mask for two
        # outputs would leave ``atten`` without an entry in the graph.
        return [input_mask, None]

    def _add_projection_weights(self, suffix, input_dim, output_dim):
        """Create the kernel and (if ``use_bias``) bias of one projection.

        Weights are named ``<layer name>_W<suffix>`` and
        ``<layer name>_b<suffix>``.

        :return: ``(kernel, bias)``; ``bias`` is ``None`` without bias.
        """
        kernel = self.add_weight(
            shape=(input_dim, output_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_W%s' % (self.name, suffix),
        )
        bias = None
        if self.use_bias:
            bias = self.add_weight(
                shape=(output_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_b%s' % (self.name, suffix),
            )
        return kernel, bias

    def build(self, input_shape):
        """Create the q/k/v/output projection weights.

        :param input_shape: Shape of a single input, or ``[q, k, v]`` shapes.
        :raises IndexError: If the value feature size is not divisible by
            ``head_num``.
        """
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        feature_dim = int(v[-1])
        if feature_dim % self.head_num != 0:
            raise IndexError(
                'Invalid head number %d with the given input dim %d'
                % (self.head_num, feature_dim))
        # Creation order (Wq, bq, Wk, bk, Wv, bv, Wo, bo) defines the order of
        # ``get_weights``/``set_weights``; keep it stable.
        self.Wq, self.bq = self._add_projection_weights('q', int(q[-1]), feature_dim)
        self.Wk, self.bk = self._add_projection_weights('k', int(k[-1]), feature_dim)
        self.Wv, self.bv = self._add_projection_weights('v', int(v[-1]), feature_dim)
        self.Wo, self.bo = self._add_projection_weights('o', feature_dim, feature_dim)
        super(MultiHeadAttention, self).build(input_shape)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        """Fold heads into the batch axis.

        (batch, seq, dim) -> (batch * head_num, seq, dim // head_num).
        """
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))

    @staticmethod
    def _reshape_from_batches(x, head_num):
        """Inverse of ``_reshape_to_batches``.

        (batch * head_num, seq, head_dim) -> (batch, seq, head_dim * head_num).
        """
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))

    @staticmethod
    def _reshape_mask(mask, head_num):
        """Repeat each mask row ``head_num`` times: (batch, seq) -> (batch * head_num, seq).

        The repetition is batch-major, matching ``_reshape_to_batches``.
        ``None`` is passed through unchanged.
        """
        if mask is None:
            return mask
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))

    def call(self, inputs, mask=None):
        """Apply multi-head attention.

        :param inputs: A single tensor (used as q, k and v) or ``[q, k, v]``.
        :param mask: ``None``, a single mask, or ``[q_mask, k_mask, v_mask]``.
        :return: ``[y, atten]`` -- the projected attention output and the
            second output of ``ScaledDotProductAttention``.
        """
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs

        if isinstance(mask, list):
            q_mask, k_mask, v_mask = mask
        else:
            q_mask = k_mask = v_mask = mask

        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)

        y, atten = ScaledDotProductAttention(
            history_only=self.history_only,
            name='%s-Attention' % self.name,
        )(
            inputs=[
                self._reshape_to_batches(q, self.head_num),
                self._reshape_to_batches(k, self.head_num),
                self._reshape_to_batches(v, self.head_num),
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
        )
        y = self._reshape_from_batches(y, self.head_num)
        y = K.dot(y, self.Wo)
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)
        if TF_KERAS:
            # Add shape information to tensor when using `tf.keras`.
            # ``compute_output_shape`` returns [y_shape, atten_shape]; only
            # ``y`` is reshaped here.
            input_shape = [K.int_shape(q), K.int_shape(k), K.int_shape(v)]
            output_shape = self.compute_output_shape(input_shape)[0]
            if output_shape[1] is not None:
                output_shape = (-1,) + output_shape[1:]
                y = K.reshape(y, output_shape)
        return [y, atten]
