
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import dtypes
import tensorflow.keras.backend as K
from official.nlp.keras_nlp import layers

from tensorflow.python.ops import clustering_ops
from tensorflow.python.ops import gen_clustering_ops


class cosine_similarity_calculator(tf.keras.Model):
    """Shorten a batch of token sequences by keeping representative tokens.

    For every example, position 0 is always kept (the code treats it as the
    CLS token). The remaining positions are chosen by sampling k-means++
    centers from the other tokens and keeping, for each center, the token
    with the highest cosine similarity to it. The attention mask is reduced
    with the same indices.
    """

    def __init__(self, name):
        super().__init__(name=name)

    def fn_cluster(self, inputs, num_to_sample=None):
        """Select representative token indices for a single example.

        Args:
            inputs: 2-D tensor holding one example's tokens, one per row.
                Row 0 is skipped and never selected here.
            num_to_sample: Scalar number of k-means++ centers to draw from
                `inputs[1:]`. When None, `self.cut_off_int - 1` is used so
                callers that still set that attribute keep working.

        Returns:
            int32 tensor of shape `(num_to_sample, 1)`. Each entry is the
            index, relative to `inputs[1:]`, of the row most cosine-similar
            to one sampled center.
        """
        if num_to_sample is None:
            num_to_sample = tf.subtract(self.cut_off_int, 1)

        # Clustering only yields indices, so it is done in float32 whatever
        # the dtype of the incoming activations is.
        # NOTE(review): the k-means++ op is expected to accept float32 points
        # only -- confirm against the op documentation.
        tokens = tf.cast(inputs[1:, :], tf.float32)

        centers = gen_clustering_ops.kmeans_plus_plus_initialization(
            points=tokens,
            num_to_sample=tf.cast(num_to_sample, dtype=tf.int64),
            seed=0,
            num_retries_per_sample=2,
        )

        # Cosine similarity == dot product of L2-normalised rows.
        tokens_unit, _ = tf.linalg.normalize(tokens, axis=1, ord=2)
        centers_unit, _ = tf.linalg.normalize(centers, axis=1, ord=2)
        similarity = tf.linalg.matmul(
            tokens_unit, centers_unit, transpose_b=True)

        # Transposed so that each row belongs to one center; top-1 along the
        # last axis is then the best-matching token for that center.
        return tf.math.top_k(
            tf.transpose(similarity), k=1, sorted=True).indices

    @tf.function
    def call(self, inputs, attention_mask, attention_scores, reduce_ratio):
        """Reduce `inputs` and `attention_mask` along the sequence axis.

        Args:
            inputs: 3-D tensor indexed as (batch, sequence, features).
            attention_mask: Mask with at least three axes; only
                `attention_mask[:, 0, :]` is read, and its last axis is
                gathered with sequence indices (presumably the mask is
                (batch, seq, seq) -- TODO confirm with the caller).
            attention_scores: Unused; kept for interface compatibility.
            reduce_ratio: A float (Python or NumPy) is the fraction of the
                sequence to drop; the kept count is
                `int((1 - reduce_ratio) * seq_len)` clamped to
                `[2, seq_len]`. Any other value is taken as the number of
                positions to keep (including position 0) and is used as
                given; it should lie in `[2, seq_len]`.

        Returns:
            `(inputs_reduced, attention_mask_reduced)`, where the first is
            (batch, kept, features) and the second is (batch, kept, kept)
            in `inputs.dtype`. When the static sequence length is unknown
            or not greater than 2, `(inputs, attention_mask)` is returned
            unchanged.
        """
        static_seq_len = inputs.shape[1]
        if static_seq_len is None or static_seq_len <= 2:
            return inputs, attention_mask

        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]

        if isinstance(reduce_ratio, (float, np.floating)):
            keep = tf.cast(
                tf.scalar_mul(
                    1.0 - float(reduce_ratio),
                    tf.cast(seq_len, dtype=tf.float32)),
                dtype=tf.int32)
            # A ratio close to 1 on a short sequence (or a negative ratio)
            # would ask the k-means++ op for fewer than one or more than
            # seq_len - 1 centers; keep the count in a usable range.
            # NOTE(review): bounds assume the op needs
            # 1 <= num_to_sample <= number of points -- confirm.
            keep = tf.maximum(tf.minimum(keep, seq_len), 2)
        else:
            keep = tf.constant(reduce_ratio, dtype=tf.int32)

        # The count is handed to fn_cluster explicitly instead of being
        # stored on `self`: assigning a traced tensor to an attribute inside
        # a tf.function leaks it out of the graph it belongs to.
        num_to_sample = keep - 1
        picked = tf.vectorized_map(
            lambda sequence: self.fn_cluster(sequence, num_to_sample),
            inputs,
        )[:, :, 0]

        # fn_cluster indexes into inputs[1:], so shift back by one to get
        # positions in the full sequence.
        picked = picked + 1

        # Position 0 (CLS) is always kept.
        cls_index = tf.zeros(shape=(batch_size, 1), dtype=tf.int32)

        # Sorted so the kept tokens stay in their original sequence order.
        kept_idx = tf.sort(tf.concat([cls_index, picked], axis=1), axis=1)

        inputs_reduced = tf.gather(inputs, indices=kept_idx, batch_dims=1)

        mask_2d = attention_mask[:, 0, :]
        mask_2d_reduced = tf.gather(mask_2d, indices=kept_idx, batch_dims=1)

        # Cast before multiplying: the mask may not share inputs' dtype.
        mask_row = tf.cast(
            tf.reshape(mask_2d_reduced, [batch_size, 1, keep]),
            inputs.dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, keep, 1], dtype=inputs.dtype)
        attention_mask_reduced = broadcast_ones * mask_row

        return inputs_reduced, attention_mask_reduced