
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


class cosine_similarity_calculator(tf.keras.Model):
    """Shorten a token sequence by keeping its highest-leverage positions.

    Despite the class name, no cosine similarity is computed here: tokens are
    ranked by their statistical leverage score, i.e. the diagonal of the
    projection ("hat") matrix ``X (X^T X)^-1 X^T`` built from each example's
    token representations ``X``. Position 0 is always kept (the code treats it
    as the CLS token) and the remaining slots go to the non-CLS positions with
    the largest leverage, restored to their original order.

    Args:
        name: Name forwarded to ``tf.keras.Model``.
        cos_sim_threshold: Fraction of the sequence to drop. The number of
            positions kept is ``int((1 - cos_sim_threshold) * seq_len)``,
            clamped to ``[1, seq_len]``.
    """

    def __init__(self, name, cos_sim_threshold):
        super().__init__(name=name)
        self._cos_sim_threshold = cos_sim_threshold

    @tf.function
    def call(self, inputs, attention_mask):
        """Return ``inputs`` and ``attention_mask`` restricted to kept tokens.

        Args:
            inputs: Rank-3 tensor; the code treats it as
                (batch, sequence length, hidden dimension).
            attention_mask: Rank-3 tensor indexed as ``[:, 0, :]`` and then
                gathered along the sequence axis; presumably
                (batch, sequence length, sequence length) with identical
                rows per example -- TODO confirm against the caller.

        Returns:
            A tuple ``(inputs_reduced, attention_mask_reduced)`` of shapes
            (batch, kept, hidden) and (batch, kept, kept), the mask being in
            ``inputs.dtype``. When the static sequence length is unknown or
            is 2 or less, both arguments are returned untouched (the mask
            then keeps its original dtype).
        """
        # The decision is made on the *static* shape, so it is resolved while
        # tracing rather than at run time.
        if inputs.shape[1] is None or inputs.shape[1] <= 2:
            return inputs, attention_mask

        from_shape = tf.shape(inputs)
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # NOTE: under @tf.function this Python print fires only while the
        # function is being traced, not on every call.
        print("\n**** ratio to mask", self._cos_sim_threshold)

        cut_off_int = tf.cast(
            tf.scalar_mul(
                1 - self._cos_sim_threshold,
                tf.cast(from_seq_length, dtype=tf.float32),
            ),
            dtype=tf.int32,
        )
        # Keep at least the CLS token and at most the whole sequence.
        # Without the clamp a high threshold on a short sequence yields
        # k = -1 for top_k below, and a negative threshold asks for more
        # columns than exist; both abort the op.
        cut_off_int = tf.clip_by_value(cut_off_int, 1, from_seq_length)
        # One of the kept slots is reserved for the CLS token.
        num_non_cls = cut_off_int - 1

        # Leverage scores.
        # First X^T X, where X has shape (batch, sequence length, hidden);
        # tf.linalg.matmul operates batch-wise. Result: (batch, hidden, hidden).
        inputs_XTX = tf.linalg.matmul(
            tf.transpose(inputs, perm=[0, 2, 1]),
            inputs
        )

        # Then its inverse.
        # NOTE(review): X^T X is rank-deficient whenever sequence length is
        # smaller than the hidden dimension, in which case this inverse is
        # ill-defined -- confirm the shapes this layer is used with.
        inverse_XTX = tf.linalg.inv(inputs_XTX)

        # Leverage is the diagonal of X (X^T X)^-1 X^T, see
        # https://en.wikipedia.org/wiki/Leverage_(statistics)
        # The i-th diagonal entry equals sum_j (X inv)[i, j] * X[i, j], so the
        # (batch, sequence length, sequence length) matrix is never built.
        # Result shape: (batch, sequence length).
        leverage_scores = tf.reduce_sum(
            tf.linalg.matmul(inputs, inverse_XTX) * inputs,
            axis=-1,
        )

        # Top-k leverage positions per example, ignoring the CLS token.
        idx_sort_leverage = tf.math.top_k(
            leverage_scores[:, 1:],
            k=num_non_cls,
            sorted=True,
        ).indices

        # Shift by one to undo the `[:, 1:]` slice above.
        idx_sort_leverage = idx_sort_leverage + 1

        # Index 0 for the CLS token, prepended to every example.
        idx_cls = tf.zeros(shape=(batch_size, 1), dtype=tf.int32)
        idx_sort_leverage = tf.concat([idx_cls, idx_sort_leverage], axis=1)

        # Sort so the kept tokens appear in their original sequence order.
        idx_sort_leverage = tf.sort(idx_sort_leverage, axis=1)

        # Slice the inputs tensor: (batch, kept, hidden).
        inputs_reduced = tf.gather(inputs, indices=idx_sort_leverage, batch_dims=1)

        # Slice the attention mask using its first row per example.
        attention_mask_2D = attention_mask[:, 0, :]
        attention_mask_2D_reduced = tf.gather(
            attention_mask_2D, indices=idx_sort_leverage, batch_dims=1)
        # The mask may arrive as an integer tensor; bring it to the dtype of
        # the ones tensor below so the multiplication does not fail on a
        # dtype mismatch. No-op when the dtypes already agree.
        attention_mask_2D_reduced = tf.cast(
            attention_mask_2D_reduced, dtype=inputs.dtype)

        # Broadcast (batch, 1, kept) against (batch, kept, 1) to rebuild a
        # (batch, kept, kept) mask.
        attention_mask_reduced = tf.reshape(
            attention_mask_2D_reduced, [batch_size, 1, cut_off_int])
        broadcast_ones = tf.ones(
            shape=[batch_size, cut_off_int, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * attention_mask_reduced

        # Sanity checks on the reduced sequence length.
        tf.debugging.assert_equal(
            tf.shape(attention_mask_reduced)[2], cut_off_int,
            message="reduced attention mask has unexpected sequence length",
        )
        tf.debugging.assert_equal(
            tf.shape(inputs_reduced)[1], cut_off_int,
            message="reduced inputs have unexpected sequence length",
        )
        return inputs_reduced, attention_mask_reduced