import tensorflow as tf
import tensorflow_hub as hub
from keras import backend as K
from keras.engine.topology import Layer
from keras import initializers


class Att_Hawkes(Layer):
    """Attention pooling over axis 1 with a learned exponential-decay boost.

    The layer is called on a list ``[x1, x2]``:

    * ``x1`` -- rank-3 tensor ``(batch, steps, features)``; it is scored
      against a learned ``(features, 1)`` vector and softmax-normalised
      over ``steps``.
    * ``x2`` -- tensor that broadcasts against ``(batch, steps)``; it feeds
      the decay term ``epsilon * exp(-beta * x2)``.
      NOTE(review): presumably per-step elapsed times -- confirm with the
      caller.

    The output is ``(batch, features)``: the attention-weighted ``x1`` with
    its strictly positive entries amplified by ``(1 + decay)``, summed over
    ``steps``.

    Trainable weights, in order: ``W``, ``epsilon`` (scalar, starts at
    0.01), ``beta`` (scalar, starts at 0.001).
    """

    def __init__(self, **kwargs):
        self.init = initializers.get('normal')
        self.epsilon_init = 0.01
        self.beta_init = 0.001
        super(Att_Hawkes, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create ``W``, ``epsilon`` and ``beta``.

        ``input_shape`` is a list of shapes; only the last dimension of the
        first entry (the feature size of ``x1``) is used.
        """
        self.W = K.variable(self.init((input_shape[0][-1],1)))
        self.epsilon = K.variable(self.epsilon_init, name='{}_epsilon'.format(self.name))  # How important is the decay
        self.beta = K.variable(self.beta_init, name='{}_beta'.format(self.name)) # How fast is the decay
        self.trainable_weights = [self.W, self.epsilon, self.beta]
        super(Att_Hawkes, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the decay-boosted attention summary of ``x[0]``.

        ``mask`` is accepted for API compatibility but is not used.
        """
        x1 = x[0]
        x2 = x[1]
        # (batch, steps, features) . (1, features, 1) -> (batch, steps, 1, 1);
        # the two squeezes drop the trailing singleton axes.
        w1 = K.squeeze(K.squeeze(K.dot(x1, K.expand_dims(self.W, axis=0)), axis=2), axis=2)
        # Softmax over steps. Subtracting the per-row maximum leaves the
        # result unchanged mathematically but keeps exp() from overflowing
        # to inf (and the weights from turning into NaN) for large scores.
        ai = K.exp(w1 - K.max(w1, axis=1, keepdims=True))
        w = ai / K.sum(ai, axis=1, keepdims=True)
        decay = self.epsilon * K.exp(-self.beta * x2)
        attention_out = x1 * K.expand_dims(w, axis=2)
        # Only strictly positive activations receive the decay boost. The
        # mask is cast to the activations' own dtype instead of a hard-coded
        # 'float32' so the product also works under other float types.
        pos = K.cast(K.greater(attention_out, 0), dtype=K.dtype(attention_out))
        hawkes_out = attention_out + attention_out * pos * K.expand_dims(decay, axis=2)

        return K.sum(hawkes_out, axis=1)

    def compute_output_shape(self, input_shape):
        """Output keeps the batch and feature dimensions of the first input."""
        return (input_shape[0][0], input_shape[0][-1])


class Att_Hawkes_user(Layer):
    """Attention pooling with a decay boost whose parameters depend on ``x3``.

    The layer is called on a list ``[x1, x2, x3]``:

    * ``x1`` -- rank-3 tensor ``(batch, steps, features)``; scored against a
      learned ``(features, 1)`` vector and softmax-normalised over ``steps``.
    * ``x2`` -- tensor that broadcasts against ``(batch, steps)``; feeds the
      decay term ``eps * exp(-bet * x2)``.
      NOTE(review): presumably per-step elapsed times -- confirm.
    * ``x3`` -- tensor whose last dimension ``U`` sizes the ``(1, U)``
      ``epsilon`` and ``beta`` weights; ``eps`` and ``bet`` are the dot
      products of those weights with ``x3``.
      NOTE(review): presumably a ``(batch, U)`` per-sample user encoding
      (e.g. one-hot) -- confirm with the caller.

    The output is ``(batch, features)``.

    Trainable weights, in order: ``W``, ``epsilon`` (constant 0.1 init),
    ``beta`` (constant 0.01 init).
    """

    def __init__(self, **kwargs):
        self.init = initializers.get('normal')
        self.epsilon_init = initializers.Constant(value=0.1)
        self.beta_init = initializers.Constant(value=0.01)
        super(Att_Hawkes_user, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create ``W`` plus the ``(1, U)`` ``epsilon`` and ``beta`` weights.

        ``input_shape`` is a list of three shapes; the last dimension of the
        first entry sizes ``W`` and the last dimension of the third entry
        gives ``U``.
        """
        self.W = K.variable(self.init((input_shape[0][-1],1)))
        self.epsilon = K.variable(self.epsilon_init((1,input_shape[2][-1])), name='{}_epsilon'.format(self.name))  # How important is the decay
        self.beta = K.variable(self.beta_init((1,input_shape[2][-1])), name='{}_beta'.format(self.name)) # How fast is the decay
        self.trainable_weights = [self.W, self.epsilon, self.beta]
        super(Att_Hawkes_user, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the decay-boosted attention summary of ``x[0]``.

        ``mask`` is accepted for API compatibility but is not used.
        """
        x1 = x[0]
        x2 = x[1]
        x3 = x[2]
        # (batch, steps, features) . (1, features, 1) -> (batch, steps, 1, 1);
        # the two squeezes drop the trailing singleton axes.
        w1 = K.squeeze(K.squeeze(K.dot(x1, K.expand_dims(self.W, axis=0)), axis=2), axis=2)
        # Softmax over steps. Subtracting the per-row maximum leaves the
        # result unchanged mathematically but keeps exp() from overflowing
        # to inf (and the weights from turning into NaN) for large scores.
        ai = K.exp(w1 - K.max(w1, axis=1, keepdims=True))
        w = ai / K.sum(ai, axis=1, keepdims=True)
        # Project the (1, U) parameter rows onto x3 to obtain per-sample
        # decay parameters; the squeezes remove the two leading singleton
        # axes introduced by the (1, 1, U) operand.
        # assumes x3 is (batch, U), giving eps/bet of shape (batch, 1) -- TODO confirm
        eps = K.squeeze(K.squeeze(K.dot(K.expand_dims(self.epsilon,axis=0), K.expand_dims(x3,axis=2)), axis=0), axis=0)
        bet = K.squeeze(K.squeeze(K.dot(K.expand_dims(self.beta,axis=0), K.expand_dims(x3,axis=2)), axis=0), axis=0)
        decay = eps * K.exp(-bet * x2)
        attention_out = x1 * K.expand_dims(w, axis=2)
        # Only strictly positive activations receive the decay boost. The
        # mask is cast to the activations' own dtype instead of a hard-coded
        # 'float32' so the product also works under other float types.
        pos = K.cast(K.greater(attention_out, 0), dtype=K.dtype(attention_out))
        hawkes_out = attention_out + attention_out * pos * K.expand_dims(decay, axis=2)

        return K.sum(hawkes_out, axis=1)

    def compute_output_shape(self, input_shape):
        """Output keeps the batch and feature dimensions of the first input."""
        return (input_shape[0][0], input_shape[0][-1])


class Elmo_Layer(Layer):
    """Keras layer wrapping the TF-Hub ELMo v2 module.

    The input is cast to ``tf.string`` and its axis 1 is squeezed away, so
    it must have a size-1 second axis.
    NOTE(review): presumably a ``(batch, 1)`` tensor holding one raw text
    string per sample -- confirm with the caller.

    The layer returns the module's ``'default'`` output and reports an
    output shape of ``(batch, 1024)``.
    """

    def __init__(self, **kwargs):
        self.dimensions = 1024
        # NOTE(review): the base Layer constructor may overwrite this from a
        # ``trainable`` keyword argument -- confirm for the Keras version in use.
        self.trainable = True
        super(Elmo_Layer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Instantiate the hub module and register its trainable variables."""
        self.elmo = hub.Module('https://tfhub.dev/google/elmo/2', trainable=self.trainable,
                               name="{}_module".format(self.name))

        # Collect the module's variables by scope so Keras will train them.
        # The imported ``tf`` module is used directly: ``K.tf`` is only an
        # accidental re-export of the backend module and is missing from
        # newer Keras releases, where it raises AttributeError.
        self.trainable_weights += tf.trainable_variables(scope="^{}_module/.*".format(self.name))
        super(Elmo_Layer, self).build(input_shape)

    def call(self, x, mask=None):
        """Run the ELMo module on ``x`` and return its ``'default'`` output.

        ``mask`` is accepted for API compatibility but is not used.
        """
        return self.elmo(K.squeeze(K.cast(x, tf.string), axis=1),
                         as_dict=True,
                         signature='default',
                         )['default']

    def compute_output_shape(self, input_shape):
        """Output is ``(batch, self.dimensions)``."""
        return (input_shape[0], self.dimensions)