
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


def _large_compatible_positive(tensor_type):
    """Return a large positive constant that is representable in `tensor_type`.

    The default magnitude used in this module (1e9) cannot be represented
    in tf.float16, so for that dtype the largest finite float16 value is
    returned instead.

    Args:
      tensor_type: a dtype used to pick a representable constant.

    Returns:
      `dtypes.float16.max` when `tensor_type` is float16, otherwise 1e9.
    """
    is_half_precision = tensor_type == dtypes.float16
    return dtypes.float16.max if is_half_precision else 1e9


class cosine_similarity_calculator(tf.keras.Model):
    """Shrinks a token sequence (and its attention mask) to a fixed budget.

    Position 0 of the sequence is always kept; the remaining positions are
    ranked by a per-token score and the lowest-scoring ones are kept until
    the budget is filled. Masked-out positions are pushed to the end of the
    ranking so they are only selected when nothing else is left.

    NOTE: despite the class name, the scores are currently drawn from a
    uniform distribution on [-1, 1) rather than computed as the cosine
    similarity against position 0. The disabled computation was:

        inputs_norm, _ = tf.linalg.normalize(inputs, axis=2, ord=2)
        scores = tf.linalg.matmul(
            inputs_norm[:, 1:, :],
            tf.transpose(inputs_norm[:, :1, :], perm=[0, 2, 1]),
        )[:, :, 0]
    """

    def __init__(self, name):
        super(cosine_similarity_calculator, self).__init__(name=name)

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio, kcenters_param_s=-1):
        """Select a subset of sequence positions and rebuild the attention mask.

        Args:
          inputs: tensor whose axis 0 is the batch and axis 1 the sequence
            (presumably [batch, seq, hidden] -- TODO confirm with callers).
          attention_mask: tensor indexed as [batch, from_seq, to_seq]; only
            row 0 of the second axis is read. Non-zero marks a position
            that may be attended to.
          attention_scores: unused; kept for interface compatibility.
          reduce_ratio: a Python float is treated as the fraction of
            positions to drop (the budget is floor((1 - ratio) * seq_len));
            any other value is treated as the integer number of positions
            to keep, position 0 included.
          kcenters_param_s: unused; kept for interface compatibility.

        Returns:
          A tuple `(inputs_reduced, attention_mask_reduced)`. The first
          keeps the selected positions of `inputs` in their original order;
          the second has shape [batch, kept, kept] and the dtype of
          `attention_mask`. When the static sequence length is unknown or
          at most 2, `inputs` and `attention_mask` are returned unchanged.
        """
        static_seq_length = inputs.shape[1]
        if static_seq_length is None or static_seq_length <= 2:
            # Nothing to reduce (or length not known while tracing).
            return inputs, attention_mask

        from_shape = tf.shape(inputs)
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # Placeholder scores for every position except position 0.
        cosine_similarity_scores = tf.random.uniform(
            shape=[batch_size, from_seq_length - 1],
            minval=-1, maxval=1, dtype=inputs.dtype,
        )

        if isinstance(reduce_ratio, float):
            cut_off_int = tf.cast(
                tf.scalar_mul(
                    1 - reduce_ratio,
                    tf.cast(from_seq_length, dtype=tf.float32),
                ),
                dtype=tf.int32,
            )
        else:
            cut_off_int = tf.constant(reduce_ratio, dtype=tf.int32)

        # Keep the budget within [1, seq_len]: position 0 is always kept,
        # and top_k cannot return more candidates than exist. Without this
        # a ratio of 1.0 or an oversized integer budget makes top_k fail.
        cut_off_int = tf.minimum(tf.maximum(cut_off_int, 1), from_seq_length)
        num_selected = cut_off_int - 1

        # Add a huge offset to masked-out positions so that, after the
        # negation below, they rank last.
        attention_mask_2D_NoCls = attention_mask[:, 0, 1:]
        adder = (1.0 - math_ops.cast(attention_mask_2D_NoCls, inputs.dtype)) * (
            _large_compatible_positive(inputs.dtype))
        cosine_similarity_scores = cosine_similarity_scores + adder

        # top_k returns the largest values, so negate to pick the
        # lowest-scoring positions.
        idx_sort_cos_similar = tf.math.top_k(
            tf.math.negative(cosine_similarity_scores),
            k=num_selected,
            sorted=True,
        ).indices

        # Scores were computed on positions 1..seq_len-1; shift the indices
        # back into the coordinates of the full sequence.
        idx_sort_cos_similar = idx_sort_cos_similar + 1

        idx_cls = tf.zeros(shape=(batch_size, 1), dtype=tf.int32)
        idx_sort_cos_similar = tf.concat([idx_cls, idx_sort_cos_similar], axis=1)
        # Restore the original token order among the kept positions.
        idx_sort_cos_similar = tf.sort(idx_sort_cos_similar, axis=1)

        inputs_reduced = tf.gather(inputs, indices=idx_sort_cos_similar, batch_dims=1)

        attention_mask_2D = attention_mask[:, 0, :]
        attention_mask_2D_reduced = tf.gather(
            attention_mask_2D, indices=idx_sort_cos_similar, batch_dims=1)

        # Repeat the reduced key mask for every kept query position. Tiling
        # (instead of multiplying by ones of inputs.dtype) works for any
        # mask dtype and keeps the mask's dtype, matching the pass-through
        # branch above.
        attention_mask_reduced = tf.reshape(
            attention_mask_2D_reduced, [batch_size, 1, cut_off_int])
        attention_mask_reduced = tf.tile(
            attention_mask_reduced, multiples=[1, cut_off_int, 1])

        return inputs_reduced, attention_mask_reduced