
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers

from tensorflow.python.ops import clustering_ops
from tensorflow.python.ops import gen_clustering_ops


class cosine_similarity_calculator(tf.keras.Model):
    """Drop one token per example, chosen via a greedy farthest-point ordering.

    For every example the token vectors are L2-normalised and ordered with a
    k-center style greedy search under a scaled cosine distance (see
    ``fn_cluster_3``). The token sitting at position ``reduce_ratio`` of that
    ordering is removed; the remaining tokens keep their original relative
    order.
    """

    def __init__(self, name):
        super(cosine_similarity_calculator, self).__init__(name=name)

    def cosine_dis_similarity(self, inputs_norm, current_center):
        """Return the scaled cosine distance from every row to one center.

        Args:
            inputs_norm: 2-D tensor of row vectors, expected to be unit-norm.
            current_center: vector with the same width as the rows.

        Returns:
            Tensor of shape ``(rows, 1)`` holding ``(1 - cos) / 2 * 10``,
            which lies in ``[0, 10]`` (up to rounding) for unit-norm inputs.
        """
        dtype = inputs_norm.dtype
        center_column = tf.reshape(current_center, (-1, 1))

        cosine_similarity_scores = tf.linalg.matmul(inputs_norm, center_column)

        cosine_dis_similarity_scores = tf.math.divide(
            tf.math.subtract(tf.constant(1.0, dtype=dtype), cosine_similarity_scores),
            tf.constant(2.0, dtype=dtype),
        )

        return tf.math.multiply(cosine_dis_similarity_scores, tf.constant(10, dtype=dtype))

    @tf.function
    def fn_cluster_3(self, arg):
        """Order one example's positions by greedy farthest-point selection.

        Position 0 is the seed center. Each iteration picks the not-yet-chosen
        position whose distance to its nearest chosen center is largest, then
        folds the new center into those nearest-center distances.

        Args:
            arg: pair ``(inputs_norm, attention_mask_vector)`` for a single
                example: a 2-D ``(seq_len, dim)`` tensor of (assumed unit-norm)
                token vectors and a 1-D length-``seq_len`` mask with the same
                dtype. Presumably 1 marks real tokens and 0 padding -- confirm
                against the caller.

        Returns:
            1-D int32 tensor with the first ``self.k_clusters`` chosen
            positions, in selection order. ``self.k_clusters`` must be set
            beforehand (``call`` sets it to the sequence length). When it does
            not exceed the sequence length, no position appears twice.
        """
        inputs_norm, attention_mask_vector = arg
        dtype = inputs_norm.dtype
        sequence_len = tf.shape(inputs_norm)[0]
        positions = tf.range(sequence_len, dtype=tf.int32)

        # Every reachable distance is >= 0 (up to rounding), so this value
        # marks positions that are already centers: top_k can only return
        # them once no unchosen position is left.
        selected_sentinel = tf.constant(-1.0, dtype=dtype)

        center_idx = tf.constant([0], dtype=tf.int32)
        dist_to_current_center = tf.reshape(
            self.cosine_dis_similarity(inputs_norm=inputs_norm, current_center=inputs_norm[0, :]),
            (-1,),
        )
        nearest_center_distance = tf.math.multiply(dist_to_current_center, attention_mask_vector)

        # Where the mask is 0 the cosine distance is dropped and replaced by a
        # small, strictly decreasing positional term (seq_len - j) / 1000, so
        # such positions stay selectable but rank below well-separated tokens.
        scale = 1000
        tie_break = tf.range(start=sequence_len, limit=0, delta=-1)
        tie_break = tf.cast(tf.math.divide(tie_break, scale), dtype=dtype)
        nearest_center_distance += (1.0 - tf.cast(attention_mask_vector, dtype)) * tie_break

        # The seed is already a center. Without this it could be picked again
        # (e.g. when attention_mask_vector[0] == 0 it carries the largest
        # tie-break value), duplicating index 0 and dropping another position.
        nearest_center_distance = tf.where(
            tf.equal(positions, 0), selected_sentinel, nearest_center_distance
        )

        value_to_add_each_iteration = 1

        def fn(step, nearest_center_distance, center_idx):
            new_center_idxes = tf.math.top_k(
                nearest_center_distance,
                k=value_to_add_each_iteration,
                sorted=True,
            ).indices
            new_center_idxes = tf.reshape(new_center_idxes, (-1,))
            center_idx = tf.concat([center_idx, new_center_idxes], axis=0)

            current_center = inputs_norm[tf.squeeze(new_center_idxes[0]), :]
            dist_to_current_center = tf.reshape(
                self.cosine_dis_similarity(inputs_norm=inputs_norm, current_center=current_center),
                (-1,),
            )

            # Each position keeps the distance to its nearest chosen center;
            # sentinel entries stay at the sentinel because it is the smallest.
            nearest_center_distance = tf.math.minimum(nearest_center_distance, dist_to_current_center)

            # Mark the new centers so top_k never returns them again while an
            # unchosen position remains (even exact duplicates at distance 0).
            is_new_center = tf.math.reduce_any(
                tf.equal(tf.expand_dims(positions, 0), tf.expand_dims(new_center_idxes, 1)),
                axis=0,
            )
            nearest_center_distance = tf.where(
                is_new_center, selected_sentinel, nearest_center_distance
            )
            return [tf.math.add(step, value_to_add_each_iteration), nearest_center_distance, center_idx]

        def condition(step, _nearest_center_distance, _center_idx):
            # NOTE(review): reads self.k_clusters, which ``call`` assigns as a
            # side effect before mapping this method over the batch.
            return step < self.k_clusters

        step = tf.constant(1, dtype=tf.int32)

        _, _, updated_center_idxes = tf.while_loop(
            condition,
            fn,
            [step, nearest_center_distance, center_idx],
            shape_invariants=[step.get_shape(), nearest_center_distance.get_shape(), tf.TensorShape([None])],
        )

        return updated_center_idxes[: self.k_clusters]

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio, kcenters_param_s):
        """Remove one token per example.

        Args:
            inputs: rank-3 tensor; axis 0 is the batch, axis 1 the sequence
                and axis 2 the feature axis that is L2-normalised.
            attention_mask: rank-3 tensor of which only ``attention_mask[:, 0, :]``
                is read; presumably ``(batch, seq, seq)`` -- TODO confirm. Its
                dtype must match ``inputs`` (they are multiplied together).
            attention_scores: unused.
            reduce_ratio: Python int; position, within the farthest-point
                ordering, of the token to drop (not a ratio despite the name).
            kcenters_param_s: unused.

        Returns:
            ``(inputs_reduced, attention_mask_reduced)`` of shapes
            ``(batch, seq - 1, dim)`` and ``(batch, seq - 1, seq - 1)``. When
            the static sequence length is unknown or <= 2, the inputs are
            returned unchanged.
        """
        if inputs.shape[1] is None or inputs.shape[1] <= 2:
            return inputs, attention_mask

        from_shape = tf.shape(inputs)
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        reduced_seq_length = tf.math.subtract(from_seq_length, 1)

        assert isinstance(reduce_ratio, int)
        cut_off_int = tf.constant(reduce_ratio, dtype=tf.int32)
        # fn_cluster_3 reads the number of centers from this attribute; one
        # center per position makes it return a full ordering.
        self.k_clusters = from_seq_length

        # l2_normalize clamps the squared norm at a small epsilon, so all-zero
        # token vectors normalise to zero instead of NaN (which would poison
        # every cosine distance and the top_k selection).
        inputs_norm = tf.math.l2_normalize(inputs, axis=2)
        attention_mask_2D = attention_mask[:, 0, :]

        idx_sort_kmeans_plus = tf.vectorized_map(self.fn_cluster_3, (inputs_norm, attention_mask_2D))

        # Remove the single entry at position cut_off_int of each ordering.
        idx_sort_sig_score_removed_one = tf.concat(
            [
                idx_sort_kmeans_plus[:, :cut_off_int],
                idx_sort_kmeans_plus[:, tf.math.add(cut_off_int, 1):],
            ],
            axis=1,
        )

        # Restore the original token order of the kept positions.
        idx_sort_kmeans_plus = tf.sort(idx_sort_sig_score_removed_one, axis=1)

        inputs_reduced = tf.gather(inputs, indices=idx_sort_kmeans_plus, batch_dims=1)
        attention_mask_2D_reduced = tf.gather(attention_mask_2D, indices=idx_sort_kmeans_plus, batch_dims=1)

        # Broadcast the kept mask row to (batch, seq - 1, seq - 1).
        attention_mask_reduced = tf.reshape(attention_mask_2D_reduced, [batch_size, 1, reduced_seq_length])
        broadcast_ones = tf.ones(shape=[batch_size, reduced_seq_length, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * attention_mask_reduced

        tf.debugging.assert_equal(tf.shape(attention_mask_reduced)[2], reduced_seq_length)
        tf.debugging.assert_equal(tf.shape(inputs_reduced)[1], reduced_seq_length)
        return inputs_reduced, attention_mask_reduced