
import tensorflow as tf
class cosine_similarity_calculator(tf.keras.Model):
    """Truncate a batch of sequences (and its attention mask) to a prefix.

    Despite the class name, nothing in this block computes a cosine
    similarity: ``call`` simply keeps the first ``reduce_ratio`` positions
    along axis 1 of ``inputs`` and rebuilds a square attention mask of the
    matching size.
    """

    def __init__(self, name):
        """Create the model.

        Args:
            name: Keras model name, forwarded to ``tf.keras.Model``.
        """
        super().__init__(name=name)

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio, kcenters_param_s=-1):
        """Keep the first ``reduce_ratio`` positions of ``inputs`` and the mask.

        Args:
            inputs: Tensor whose axis 1 is reduced. Presumably shaped
                (batch, seq_len, hidden) -- TODO confirm against callers.
            attention_mask: Rank-3 tensor; only ``attention_mask[:, 0, :]`` is
                read. Presumably (batch, seq_len, seq_len) -- TODO confirm.
                Its dtype has to be multiplicable with ``inputs.dtype``.
            attention_scores: Accepted for interface compatibility; not used.
            reduce_ratio: Number of leading positions to keep. Used as an
                absolute int32 count (not a fraction, despite the name). May
                be a Python int or an int32 tensor.
            kcenters_param_s: Accepted for interface compatibility; not used.

        Returns:
            A tuple ``(inputs_reduced, attention_mask_reduced)``. When the
            static size of axis 1 of ``inputs`` is unknown or not greater
            than 2, both arguments are returned unchanged. Otherwise
            ``inputs_reduced`` holds positions ``0 .. reduce_ratio - 1`` of
            axis 1 and ``attention_mask_reduced`` has shape
            ``[batch, reduce_ratio, reduce_ratio]``.
        """
        # The decision is made on the *static* shape, so it is resolved while
        # tracing rather than at run time.
        seq_len = inputs.shape[1]
        if seq_len is None or seq_len <= 2:
            return inputs, attention_mask

        batch_size = tf.shape(inputs)[0]

        # convert_to_tensor (unlike tf.constant) also accepts a symbolic
        # tensor, so reduce_ratio may be passed as a tensor into this
        # tf.function. Python ints still become an int32 constant as before.
        cut_off_int = tf.convert_to_tensor(reduce_ratio, dtype=tf.int32)

        # Per-example index rows [0, 1, ..., cut_off_int - 1], shape
        # [batch, cut_off_int], consumed by the batched gathers below.
        score = tf.reshape(tf.range(cut_off_int), [1, -1])
        score = tf.repeat(score, batch_size, axis=0)

        inputs_reduced = tf.gather(inputs, indices=score, batch_dims=1)

        # Only the first row of the mask is consulted; it is truncated the
        # same way as the inputs.
        attention_mask_2D = attention_mask[:, 0, :]
        attention_mask_2D_reduced = tf.gather(
            attention_mask_2D, indices=score, batch_dims=1)

        # Broadcast the truncated row to a square
        # [batch, cut_off_int, cut_off_int] mask: every query position shares
        # the same key mask.
        attention_mask_reduced = tf.reshape(
            attention_mask_2D_reduced, [batch_size, 1, cut_off_int])
        broadcast_ones = tf.ones(
            shape=[batch_size, cut_off_int, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * attention_mask_reduced

        return inputs_reduced, attention_mask_reduced