
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers

class token_reducer(tf.keras.Model):
    """Shorten a token sequence by max-pooling windows of adjacent tokens.

    ``call`` takes a rank-3 ``inputs`` tensor and a rank-3 ``attention_mask``
    (both are indexed with three indices below), merges every window of two
    consecutive tokens with an element-wise max, and builds a new attention
    mask for the shortened sequence with ``layers.SelfAttentionMask``.
    """

    def __init__(self, name):
        super().__init__(name=name)
        # Created once here instead of on every forward pass; the layer is
        # only used to expand the per-token mask computed in ``call``.
        self._self_attention_mask = layers.SelfAttentionMask()

    def get_indexes_mask_current_attention_mask(self, attention_mask, batch_size):
        """Return, per example, the last position where mask row 0 equals 1.

        :param attention_mask: rank-3 tensor, read as ``attention_mask[i, 0, :]``.
        :param batch_size: number of examples to inspect.
        :return: list of ``batch_size`` tensors, each the last row produced by
            ``tf.where`` for that example. Indexing ``[-1]`` fails for an
            example whose mask row contains no 1.
        """
        return [
            tf.where(attention_mask[i, 0, :] == 1)[-1]
            for i in range(batch_size)
        ]

    def get_attention_mask_by_example(self, example_idx, attention_mask, indexes_to_merge):
        """Flag which merge windows of one example start before the first 0 in its mask.

        :param example_idx: index of the example along the batch axis.
        :param attention_mask: rank-3 tensor; only ``attention_mask[example_idx, 0, :]``
            is read. Presumably 1 marks real tokens and 0 marks padding --
            TODO confirm against the caller.
        :param indexes_to_merge: int32 tensor ``[batch, num_windows]`` holding the
            start position of every merge window.
        :return: boolean tensor of shape ``[1, num_windows]``; an entry is True
            when the window start is strictly smaller than the position of the
            first 0 in the mask row, and always True when the row has no 0.
        """
        zero_positions = tf.where(attention_mask[example_idx, 0, :] == 0)
        zero_positions = tf.cast(tf.reshape(zero_positions, [-1]), dtype=tf.int32)

        # Appending the largest int32 guarantees a non-empty tensor, so a row
        # without zeros yields a bound no window start can reach. This replaces
        # the previous reliance on an eager-mode indexing error caught by a
        # bare ``except`` and its arbitrary 100000 fallback.
        no_zero_sentinel = tf.constant([tf.int32.max], dtype=tf.int32)
        idx_first_zero = tf.reduce_min(
            tf.concat([zero_positions, no_zero_sentinel], axis=0)
        )

        window_starts = indexes_to_merge[example_idx:(example_idx + 1), :]
        return tf.math.less(window_starts, idx_first_zero)

    def merge_token_by_example(self, example_idx, inputs, indexes_to_merge, slide_window_size, seq_length_after_merged):
        """Max-pool the tokens of one example window by window.

        :param example_idx: index of the example along the batch axis.
        :param inputs: rank-3 tensor, read as ``inputs[example_idx, start:end, :]``.
        :param indexes_to_merge: integer tensor ``[batch, num_windows]`` with the
            start position of every window.
        :param slide_window_size: number of consecutive tokens in one window.
        :param seq_length_after_merged: expected number of windows; must equal
            ``indexes_to_merge.shape[1]``.
        :return: tensor of shape ``[1, num_windows, hidden]`` with the dtype of
            ``inputs``, where ``hidden`` is the size of the last axis of
            ``inputs`` (no longer hard-coded to 768 / float32).
        """
        num_windows = indexes_to_merge.shape[1]
        assert num_windows == seq_length_after_merged

        pooled_windows = []
        for j in range(num_windows):
            start = indexes_to_merge[example_idx, j]
            # A window that runs past the end of the sequence is truncated by
            # the slice, so a trailing odd token is pooled on its own.
            window = inputs[example_idx, start:(start + slide_window_size), :]
            pooled_windows.append(tf.math.reduce_max(window, axis=0))

        # [num_windows, hidden] -> [1, num_windows, hidden]
        return tf.expand_dims(tf.stack(pooled_windows, axis=0), axis=0)

    def call(self, inputs, attention_mask):
        """Merge adjacent token pairs and rebuild the attention mask.

        :param inputs: rank-3 tensor; axis 1 is the sequence axis that gets
            shortened and axis 2 is kept as is.
        :param attention_mask: rank-3 tensor whose static ``shape[0]`` and
            ``shape[1]`` are used as batch size and sequence length.
        :return: ``(merged_inputs, new_attention_mask)``. When the static
            sequence length of ``inputs`` is unknown or at most 2, both
            arguments are returned unchanged.
        """
        # NOTE(review): this flips a process-wide TensorFlow switch on every
        # call and disables graph execution for all tf.functions. It is kept
        # only for backward compatibility; the code below no longer depends
        # on eager-only exception handling -- consider removing it.
        tf.config.run_functions_eagerly(True)

        if inputs.shape[1] is None or inputs.shape[1] <= 2:
            return inputs, attention_mask

        batch_size = attention_mask.shape[0]
        seq_length = attention_mask.shape[1]
        hidden_dim = inputs.shape[2]
        slide_window_size = 2

        # Every example shares the same window start positions: 0, 2, 4, ...
        window_starts = list(range(0, seq_length, slide_window_size))
        seq_length_after_merged = len(window_starts)
        indexes_to_merge = tf.constant(
            [window_starts] * batch_size, dtype=tf.int32
        )

        merged_per_example = []
        mask_per_example = []
        for i in range(batch_size):
            mask_per_example.append(
                self.get_attention_mask_by_example(
                    example_idx=i,
                    attention_mask=attention_mask,
                    indexes_to_merge=indexes_to_merge,
                )
            )
            merged_per_example.append(
                self.merge_token_by_example(
                    example_idx=i,
                    inputs=inputs,
                    indexes_to_merge=indexes_to_merge,
                    slide_window_size=slide_window_size,
                    seq_length_after_merged=seq_length_after_merged,
                )
            )

        merged_inputs = tf.concat(merged_per_example, axis=0)
        attention_mask_next_layer = tf.concat(mask_per_example, axis=0)

        assert merged_inputs.shape == (batch_size, seq_length_after_merged, hidden_dim)
        return merged_inputs, self._self_attention_mask(
            merged_inputs, attention_mask_next_layer
        )