
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


def _large_compatible_negative(tensor_type):
    """Return a large negative number that fits the given dtype.

    The usual masking constant (-1e9) is outside the range of half
    precision, so for ``float16`` the smallest value of that dtype is
    returned instead.

    Args:
      tensor_type: the dtype the returned value has to be representable in.

    Returns:
      ``dtypes.float16.min`` when `tensor_type` equals ``dtypes.float16``,
      otherwise the Python float ``-1e9``.
    """
    is_half_precision = tensor_type == dtypes.float16
    return dtypes.float16.min if is_half_precision else -1e9


class cosine_similarity_calculator(tf.keras.Model):
    """Masks out tokens that are too similar to the first token of a sequence.

    For every sequence the cosine similarity between the first position
    (referred to as CLS below) and each later position is computed. Positions
    whose similarity is *not* strictly below `cos_sim_threshold` are zeroed
    along the last axis of the attention mask; the first position is always
    kept. The hidden states themselves are returned untouched.
    """

    def __init__(self, name, cos_sim_threshold):
        """Creates the layer.

        Args:
          name: name passed on to `tf.keras.Model`.
          cos_sim_threshold: positions with a cosine similarity to the first
            position greater than or equal to this value are masked out.
        """
        super().__init__(name=name)
        self._cos_sim_threshold = cos_sim_threshold

    @tf.function
    def call(self, inputs, attention_mask):
        """Applies the similarity-based mask to `attention_mask`.

        Args:
          inputs: rank-3 tensor indexed as [batch, sequence, hidden].
          attention_mask: rank-3 numeric tensor that broadcasts against
            [batch, sequence, sequence]; its last axis indexes the positions
            that may be masked out.

        Returns:
          A tuple `(inputs, attention_mask)`. `inputs` is returned unchanged.
          `attention_mask` is the incoming mask multiplied by the similarity
          mask (same dtype as the incoming mask), or the incoming mask itself
          when the static sequence length is unknown or not greater than 2.
        """
        static_seq_length = inputs.shape[1]
        # The masking is only applied when the sequence length is known at
        # trace time and leaves more than two positions.
        if static_seq_length is None or static_seq_length <= 2:
            return inputs, attention_mask

        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]
        from_seq_length = dynamic_shape[1]

        # L2-normalise every hidden vector once. Row 0 of the result is the
        # normalised CLS vector, so no separate gather / normalisation (and
        # no runtime equality assert, which would trip on NaN inputs) is
        # needed.
        inputs_norm, _ = tf.linalg.normalize(inputs, axis=2, ord=2)
        cls_norm = inputs_norm[:, :1, :]

        # [batch, seq - 1, hidden] x [batch, hidden, 1] -> [batch, seq - 1]
        cosine_similarity_scores = tf.linalg.matmul(
            inputs_norm[:, 1:, :], cls_norm, transpose_b=True)[:, :, 0]

        # NOTE: a Python print inside a tf.function only runs while the
        # function is being traced, not on every call.
        print("**** threshold", self._cos_sim_threshold)

        # Build the mask in the dtype of `attention_mask` so the product
        # below also works when that dtype differs from `inputs.dtype`.
        mask_dtype = attention_mask.dtype
        keep_mask = tf.cast(
            tf.math.less(cosine_similarity_scores, self._cos_sim_threshold),
            dtype=mask_dtype)
        # The CLS position itself is never masked out.
        cls_keep = tf.ones(shape=(batch_size, 1), dtype=mask_dtype)
        keep_mask = tf.concat([cls_keep, keep_mask], axis=1)

        keep_mask = tf.reshape(keep_mask, [batch_size, 1, from_seq_length])
        broadcast_ones = tf.ones(
            shape=[batch_size, from_seq_length, 1], dtype=mask_dtype)
        # Here we broadcast along two dimensions to create the
        # [batch, seq, seq] mask.
        keep_mask = broadcast_ones * keep_mask

        num_before = tf.math.reduce_sum(attention_mask[:, 0, :])
        attention_mask = tf.math.multiply(attention_mask, keep_mask)
        num_after = tf.math.reduce_sum(attention_mask[:, 0, :])
        tf.print("\nnumber of tokens to drop: ", tf.math.subtract(num_before, num_after))

        return inputs, attention_mask