
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


class cosine_similarity_calculator(tf.keras.Model):
    """Prune a sequence down to the positions that receive the most attention.

    NOTE: despite the class name, no cosine similarity is computed here.
    Positions are ranked purely by the attention they receive from the
    *other* positions, and only the top-ranked ones (plus position 0) are
    kept.
    """

    def __init__(self, name):
        """Create the model under the given Keras ``name``."""
        super().__init__(name=name)

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio, kcenters_param_s=-1):
        """Keep only the most significant positions of ``inputs``.

        Args:
            inputs: Tensor gathered along axis 1; presumably
                ``(batch, seq_len, hidden)`` -- TODO confirm with callers.
            attention_mask: Rank-3 tensor; only ``attention_mask[:, 0, :]`` is
                read, so every row of the mask is assumed to be identical.
            attention_scores: Rank-4 tensor reduced over axes 2 and 1;
                presumably ``(batch, heads, seq_len, seq_len)`` attention
                probabilities -- TODO confirm.
            reduce_ratio: A Python ``float`` is interpreted as the fraction of
                positions to *drop* (``int((1 - reduce_ratio) * seq_len)`` are
                kept). Anything else is interpreted as the absolute number of
                positions to keep and must be convertible to an int32 tensor.
            kcenters_param_s: Unused; retained for interface compatibility.

        Returns:
            A tuple ``(inputs_reduced, attention_mask_reduced)`` where
            ``inputs_reduced`` holds the kept positions in their original
            order and ``attention_mask_reduced`` has shape
            ``(batch, keep, keep)`` with dtype ``inputs.dtype``. When the
            static sequence length of ``inputs`` is unknown or not greater
            than 2, ``(inputs, attention_mask)`` is returned unchanged.
        """
        # Pruning is decided on the *static* shape, so this is a plain Python
        # branch resolved at trace time.
        static_seq_len = inputs.shape[1]
        if static_seq_len is None or static_seq_len <= 2:
            return inputs, attention_mask

        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if isinstance(reduce_ratio, float):
            keep_count = tf.cast(
                tf.scalar_mul(1 - reduce_ratio, tf.cast(seq_length, dtype=tf.float32)),
                dtype=tf.int32,
            )
        else:
            # convert_to_tensor (unlike tf.constant) also accepts a tensor.
            keep_count = tf.convert_to_tensor(reduce_ratio, dtype=tf.int32)

        # Position 0 is always kept and at most seq_length positions exist;
        # clamping keeps the k passed to top_k inside [0, seq_length - 1]
        # instead of failing at run time on an out-of-range ratio/count.
        keep_count = tf.clip_by_value(keep_count, 1, seq_length)
        top_k_count = keep_count - 1

        # Attention received by each position (column sums over axis 2),
        # excluding the attention a position pays to itself (the diagonal).
        received = tf.math.reduce_sum(attention_scores, axis=2)
        self_attention = tf.linalg.diag_part(attention_scores)
        significance = tf.math.subtract(received, self_attention)

        # Aggregate over axis 1 (heads). A mean would only rescale the scores
        # by a constant and therefore yield the same ranking.
        significance = tf.math.reduce_sum(significance, axis=1)

        # Rank every position except position 0 (the CLS token, per the
        # original author's note); the +1 undoes the offset of the [:, 1:]
        # slice.
        ranked_idx = tf.math.top_k(
            significance[:, 1:],
            k=top_k_count,
            sorted=True,
        ).indices + 1

        # Position 0 is kept unconditionally.
        first_idx = tf.zeros(shape=(batch_size, 1), dtype=tf.int32)
        kept_idx = tf.concat([first_idx, ranked_idx], axis=1)

        # Restore the original sequence order before slicing.
        kept_idx = tf.sort(kept_idx, axis=1)

        inputs_reduced = tf.gather(inputs, indices=kept_idx, batch_dims=1)

        # Prune the mask along the key axis, then broadcast it back to a
        # square (batch, keep, keep) mask.
        mask_2d = tf.gather(attention_mask[:, 0, :], indices=kept_idx, batch_dims=1)
        # Cast so the multiplication below cannot fail on a dtype mismatch.
        mask_2d = tf.cast(mask_2d, inputs.dtype)
        mask_row = tf.reshape(mask_2d, [batch_size, 1, keep_count])
        broadcast_ones = tf.ones(shape=[batch_size, keep_count, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * mask_row

        return inputs_reduced, attention_mask_reduced