
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers

class mask_permutation(tf.keras.Model):
    """Randomly drop attention to sequence positions by editing an attention mask.

    For each example, one keep/drop decision is drawn per sequence position and
    multiplied into ``attention_mask`` along its last (key) axis. Position 0 is
    always kept (the original code calls it ``mask_cls``, presumably the [CLS]
    token). ``inputs`` is returned unchanged.

    Args:
      name: Keras model name.
      keep_prob: Probability that a position other than 0 is kept. The default
        of 1.0 keeps every position. This matches the original behaviour, where
        the test ``uniform[0, 1) <= 1`` was always true and the mask was
        therefore a no-op.
      verbose: If True, print ``inputs`` on every masked call (the original
        debug behaviour).
      force_eager: If True, call ``tf.config.run_functions_eagerly(True)`` on
        every call, as the original did. NOTE: this flips a process-wide
        setting that makes every ``tf.function`` run eagerly. Set it to False
        unless eager debugging is actually needed.

    Raises:
      ValueError: if ``keep_prob`` is outside [0, 1].
    """

    def __init__(self, name, keep_prob=1.0, verbose=True, force_eager=True):
        super(mask_permutation, self).__init__(name=name)
        if not 0.0 <= keep_prob <= 1.0:
            raise ValueError(
                "keep_prob must be in [0, 1], got {}".format(keep_prob))
        self.keep_prob = float(keep_prob)
        self.verbose = verbose
        self.force_eager = force_eager

    def call(self, inputs, attention_mask):
        """Return ``(inputs, attention_mask)`` with random key positions masked.

        Args:
          inputs: Tensor whose axis 1 is the sequence axis. Only its shape is
            read; it is returned unchanged.
          attention_mask: Mask that a (batch, seq, seq) keep mask is multiplied
            into. Assumed to be (batch, from_seq, to_seq) with
            to_seq == inputs' sequence length. TODO confirm against the caller.

        Returns:
          A tuple ``(inputs, attention_mask)``. The mask is returned untouched
          when the static sequence length is unknown or <= 2.
        """
        if self.force_eager:
            tf.config.run_functions_eagerly(True)

        # Guard on the *static* sequence length: unknown or very short
        # sequences pass straight through.
        static_seq_length = inputs.shape[1]
        if static_seq_length is None or static_seq_length <= 2:
            return inputs, attention_mask

        if self.verbose:
            print("*****attention mask input", inputs)

        # Build the mask in attention_mask's dtype. The original used
        # inputs.dtype, which breaks the multiply below when the two differ
        # (e.g. int32 token ids with a float32 mask).
        mask_dtype = attention_mask.dtype
        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]
        seq_length = dynamic_shape[1]

        # One uniform draw per position except position 0. The draw is kept
        # even when keep_prob == 1.0 so the RNG stream matches the original.
        draws = tf.random.uniform(
            shape=(batch_size, seq_length - 1), minval=0, maxval=1)
        keep = tf.cast(tf.math.less_equal(draws, self.keep_prob),
                       dtype=mask_dtype)
        keep_first = tf.ones(shape=(batch_size, 1), dtype=mask_dtype)
        keep = tf.concat([keep_first, keep], axis=1)

        # Broadcast explicitly to (batch, seq, seq): entry [b, i, j] equals
        # keep[b, j], so a dropped position j is masked as a key for every
        # query i.
        keep = tf.reshape(keep, [batch_size, 1, seq_length])
        broadcast_ones = tf.ones(
            shape=[batch_size, seq_length, 1], dtype=mask_dtype)
        keep = broadcast_ones * keep

        return inputs, tf.math.multiply(attention_mask, keep)