
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


def _large_compatible_positive(tensor_type):
    """Return a large positive number that `tensor_type` can represent.

    The default large value used here (1e9) is outside the range of
    float16, so for float16 the dtype's own maximum is returned instead.

    Args:
      tensor_type: a dtype; it is compared against `dtypes.float16`.

    Returns:
      `dtypes.float16.max` when `tensor_type` is float16, otherwise `1e9`.
    """
    is_half_precision = tensor_type == dtypes.float16
    return dtypes.float16.max if is_half_precision else 1e9


class cosine_similarity_calculator(tf.keras.Model):
    """Remove one token per sequence, picked by cosine similarity to token 0.

    Tokens at positions 1.. are ranked by the cosine similarity between
    their vector and the vector at position 0 (the code treats position 0
    as the CLS token). Masked-out positions are pushed to the end of the
    ranking. The token at rank `reduce_ratio` is dropped, the remaining
    tokens keep their original order, and the first remaining token is
    scaled by `1 + a`, where `a` is the head-averaged attention score from
    query position 0 to the dropped token.
    """

    def __init__(self, name):
        super().__init__(name=name)

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio, kcenters_param_s=-1):
        """Drop one token from `inputs` and shrink the mask accordingly.

        Args:
          inputs: rank-3 tensor indexed as (batch, seq, hidden).
          attention_mask: rank-3 tensor; only row 0 of axis 1 is read and
            it is indexed along the last axis by sequence position.
            NOTE(review): assumed to be (batch, seq, seq) with 1 marking
            real tokens and 0 marking padding -- confirm against callers.
          attention_scores: rank-4 tensor; `[:, :, 0, :]` is averaged over
            axis 1. NOTE(review): assumed to be
            (batch, heads, from_seq, to_seq) -- confirm against callers.
          reduce_ratio: Python int. Despite the name it is used as a rank
            position, not a ratio: the token at this position in
            `[0, <positions 1.. by ascending similarity>]` is removed.
            Must satisfy `0 <= reduce_ratio < seq`; 0 selects position 0
            itself.
          kcenters_param_s: unused; kept so existing call sites still work.

        Returns:
          A tuple `(inputs_reduced, attention_mask_reduced)` with shapes
          (batch, seq - 1, hidden) and (batch, seq - 1, seq - 1), the mask
          being in `inputs.dtype`. When the static sequence length is
          unknown or not greater than 2, `inputs` and `attention_mask` are
          returned unchanged.

        Raises:
          TypeError: if `reduce_ratio` is not an int.
          ValueError: if `reduce_ratio` is outside `[0, seq)`.
        """
        static_seq_len = inputs.shape[1]
        if static_seq_len is None or static_seq_len <= 2:
            return inputs, attention_mask

        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if not isinstance(reduce_ratio, int):
            raise TypeError(
                f"reduce_ratio must be an int, got {type(reduce_ratio).__name__}"
            )
        # An out-of-range position removes zero or too many tokens and would
        # otherwise only surface later as an opaque reshape failure.
        if not 0 <= reduce_ratio < static_seq_len:
            raise ValueError(
                f"reduce_ratio must be in [0, {static_seq_len}), got {reduce_ratio}"
            )

        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]
        reduced_len = tf.math.subtract(dynamic_shape[1], 1)
        cut_off = tf.constant(reduce_ratio, dtype=tf.int32)

        # Cosine similarity of every token at positions 1.. with token 0.
        inputs_norm, _ = tf.linalg.normalize(inputs, axis=2, ord=2)
        similarity = tf.linalg.matmul(
            inputs_norm[:, 1:, :],
            tf.transpose(inputs_norm[:, :1, :], perm=[0, 2, 1]),
        )[:, :, 0]

        # Cast the mask once so every later arithmetic op sees inputs.dtype;
        # the original multiplied the uncast mask by ones of inputs.dtype,
        # which fails when the two dtypes differ.
        mask_2d = math_ops.cast(attention_mask[:, 0, :], inputs.dtype)

        # Masked-out positions get a huge score so they rank last below.
        padding_penalty = (1.0 - mask_2d[:, 1:]) * (
            _large_compatible_positive(inputs.dtype))
        similarity = similarity + padding_penalty

        # top_k on the negated scores yields indices in ascending score order.
        ascending = tf.math.top_k(
            tf.math.multiply(-1.0, similarity),
            k=tf.shape(similarity)[1],
            sorted=True,
        ).indices

        # Shift by one to undo the `1:` slice, then put position 0 in front.
        ranked = tf.concat(
            [tf.zeros(shape=(batch_size, 1), dtype=tf.int32), ascending + 1],
            axis=1,
        )

        dropped_idx = ranked[:, cut_off:tf.math.add(cut_off, 1)]
        kept_idx = tf.concat(
            [ranked[:, :cut_off], ranked[:, tf.math.add(cut_off, 1):]],
            axis=1,
        )
        # Sorting restores the original token order among the kept positions.
        kept_idx = tf.sort(kept_idx, axis=1)

        inputs_reduced = tf.gather(inputs, indices=kept_idx, batch_dims=1)
        mask_2d_reduced = tf.gather(mask_2d, indices=kept_idx, batch_dims=1)

        # Broadcast the (batch, 1, seq-1) mask row to (batch, seq-1, seq-1).
        attention_mask_reduced = tf.reshape(
            mask_2d_reduced, [batch_size, 1, reduced_len])
        broadcast_ones = tf.ones(
            shape=[batch_size, reduced_len, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * attention_mask_reduced

        # Head-averaged attention from query position 0 to the dropped token.
        head_mean_scores = tf.math.reduce_mean(attention_scores[:, :, 0, :], axis=1)
        first_token_scale = tf.gather(
            head_mean_scores, indices=dropped_idx, batch_dims=1)
        # Cast before concatenating with ones of inputs.dtype (no-op when the
        # dtypes already agree).
        first_token_scale = tf.math.add(
            math_ops.cast(first_token_scale, inputs.dtype), 1)

        # Only the first kept token is rescaled; all others are multiplied by 1.
        unit_scale = tf.ones(
            shape=(batch_size, tf.math.subtract(reduced_len, 1)),
            dtype=inputs.dtype,
        )
        scale = tf.concat([first_token_scale, unit_scale], axis=1)
        scale = tf.reshape(scale, [batch_size, reduced_len, 1])
        inputs_reduced = scale * inputs_reduced

        tf.debugging.assert_equal(tf.shape(attention_mask_reduced)[2], reduced_len)
        tf.debugging.assert_equal(tf.shape(inputs_reduced)[1], reduced_len)
        return inputs_reduced, attention_mask_reduced