
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers


def _large_compatible_negative(tensor_type):
    """Pick an additive masking constant that is representable in a dtype.

    The masking value used in this module (-1e9) lies outside the range of
    tf.float16, so for that dtype the most negative float16 value is
    returned instead.

    Args:
      tensor_type: dtype of the tensor the constant will be combined with.

    Returns:
      `dtypes.float16.min` when `tensor_type` is float16, otherwise -1e9.
    """
    is_half_precision = tensor_type == dtypes.float16
    return dtypes.float16.min if is_half_precision else -1e9


class cosine_similarity_calculator(tf.keras.Model):
    """Prune a token sequence down to its highest-scoring tokens.

    Every token except the first one is scored by the sum of two terms, each
    rescaled by its per-example maximum:

    * its cosine distance to the first token, ``(1 - cos) / 2``;
    * an attention-based significance score (see ``call``).

    The first token (index 0, presumably the [CLS] token) is always kept.
    """

    def __init__(self, name):
        super().__init__(name=name)

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio):
        """Select the tokens to keep and shrink the attention mask accordingly.

        Args:
          inputs: rank-3 tensor indexed as [batch, seq, hidden].
          attention_mask: rank-3 tensor indexed as [batch, seq, seq]. Only
            row 0 of each example is read, i.e. all rows are assumed to be
            identical. Entries equal to 1 mark tokens that may be selected.
          attention_scores: rank-4 tensor indexed as [batch, heads, seq, seq]
            with the same floating dtype as `inputs`. Presumably the
            attention probabilities of the preceding layer - confirm against
            the caller.
          reduce_ratio: if a Python float, the fraction of tokens to drop
            (``int((1 - reduce_ratio) * seq)`` tokens are kept); otherwise
            the absolute number of tokens to keep, first token included.
            The resulting count must lie in ``[1, seq]``, otherwise
            `tf.math.top_k` fails.

        Returns:
          A tuple ``(inputs_reduced, attention_mask_reduced)`` of shapes
          [batch, kept, hidden] and [batch, kept, kept], with the kept tokens
          in their original order. When the static sequence length is unknown
          or not greater than 2, `inputs` and `attention_mask` are returned
          unchanged.
        """
        static_seq_length = inputs.shape[1]
        if static_seq_length is None or static_seq_length <= 2:
            return inputs, attention_mask

        dtype = inputs.dtype
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        from_seq_length = input_shape[1]

        # Number of tokens to keep, including the first token.
        if isinstance(reduce_ratio, float):
            cut_off_int = tf.cast(
                tf.scalar_mul(
                    1 - reduce_ratio,
                    tf.cast(from_seq_length, dtype=tf.float32),
                ),
                dtype=tf.int32,
            )
        else:
            cut_off_int = tf.constant(reduce_ratio, dtype=tf.int32)
        # The first token is kept unconditionally, so one slot fewer is ranked.
        num_ranked = tf.math.subtract(cut_off_int, 1)

        ############ cosine distance to the first token ############
        inputs_norm, _ = tf.linalg.normalize(inputs, axis=2, ord=2)
        cosine_similarity_scores = tf.linalg.matmul(
            inputs_norm[:, 1:, :],
            tf.transpose(inputs_norm[:, :1, :], perm=[0, 2, 1]),
        )[:, :, 0]

        # Map the similarity from [-1, 1] to a distance in [0, 1].
        cosine_similarity_scores = tf.math.divide(
            tf.math.subtract(
                tf.constant(1.0, dtype=dtype), cosine_similarity_scores),
            tf.constant(2.0, dtype=dtype),
        )

        # Rescale by the per-example maximum. divide_no_nan yields 0 instead
        # of NaN when that maximum is 0 (every token parallel to the first).
        # NOTE(review): the maximum is taken over masked positions as well.
        cosine_similarity_scores = tf.math.divide_no_nan(
            cosine_similarity_scores,
            tf.reduce_max(cosine_similarity_scores, axis=1, keepdims=True),
        )

        ############ attention-based significance score ############
        # Sum the scores over axis 2 and remove each token's score with
        # itself (the diagonal), then average over the remaining seq - 1
        # entries.
        significance_score = tf.math.subtract(
            tf.math.reduce_sum(attention_scores, axis=2),
            tf.linalg.diag_part(attention_scores),
        )
        significance_score = tf.math.divide(
            significance_score,
            tf.cast(tf.math.subtract(from_seq_length, 1), dtype=dtype),
        )

        # Average over the heads (axis 1) and drop the first token's column.
        significance_score = tf.math.reduce_mean(
            significance_score, axis=1)[:, 1:]

        # Rescale by the per-example maximum, guarding against a zero maximum.
        significance_score = tf.math.divide_no_nan(
            significance_score,
            tf.math.reduce_max(significance_score, axis=1, keepdims=True),
        )

        ######################### ensemble #########################
        ensemble_scores = tf.math.add(
            cosine_similarity_scores, significance_score)

        # Push masked-out positions to the bottom of the ranking.
        attention_mask_2d = attention_mask[:, 0, :]
        adder = (1.0 - math_ops.cast(attention_mask_2d[:, 1:], dtype)) * (
            _large_compatible_negative(dtype))
        ensemble_scores = ensemble_scores + adder

        kept_indices = tf.math.top_k(
            ensemble_scores,
            k=num_ranked,
            sorted=True,
        ).indices
        # Scores were computed without the first token, so shift the indices
        # back into the coordinates of the full sequence.
        kept_indices = kept_indices + 1

        first_token_index = tf.zeros(shape=(batch_size, 1), dtype=tf.int32)
        kept_indices = tf.concat([first_token_index, kept_indices], axis=1)
        # Restore the original token order.
        kept_indices = tf.sort(kept_indices, axis=1)

        inputs_reduced = tf.gather(
            inputs, indices=kept_indices, batch_dims=1)

        mask_row_reduced = tf.gather(
            attention_mask_2d, indices=kept_indices, batch_dims=1)
        mask_row_reduced = tf.reshape(
            mask_row_reduced, [batch_size, 1, cut_off_int])

        # Expand the reduced row to [batch, kept, kept]. The ones are created
        # in the mask's own dtype so that the product is valid even when the
        # mask dtype differs from inputs.dtype, and so that the returned mask
        # has the same dtype as in the passthrough branch above.
        broadcast_ones = tf.ones(
            shape=[batch_size, cut_off_int, 1], dtype=mask_row_reduced.dtype)
        attention_mask_reduced = broadcast_ones * mask_row_reduced

        return inputs_reduced, attention_mask_reduced