# -*- coding: utf-8 -*-
import tensorflow as tf
import re


class LazyAdam(tf.keras.optimizers.Adam):
    """Adam variant that applies sparse gradient updates lazily.

    Standard Adam keeps two moving-average accumulators per trainable
    variable and refreshes every entry of both on every step. For sparse
    gradients this class refreshes only the accumulator rows whose indices
    appear in the current batch, leaving all other rows untouched. This can
    give large training-throughput gains for some applications, at the cost
    of slightly different semantics (and possibly different empirical
    results) compared with the original Adam algorithm.

    Note: amsgrad is currently not supported; the argument can only be False.

    This class is borrowed from:
    https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/lazy_adam.py
    """

    def _resource_apply_sparse(self, grad, var, indices):
        """Run one lazy Adam step on the rows of `var` selected by `indices`.

        Args:
            grad: Gradient values for the selected rows.
            var: Resource variable being optimized.
            indices: Indices of the rows of `var` (and of its "m"/"v" slots)
                that `grad` refers to.

        Returns:
            A grouped op covering the variable update and both slot updates.
        """
        # Summary is written under the fixed tag "gradient" on every call.
        tf.summary.histogram("gradient", grad)

        dtype = var.dtype.base_dtype
        base_lr = self._decayed_lr(dtype)
        beta_1 = self._get_hyper("beta_1", dtype)
        beta_2 = self._get_hyper("beta_2", dtype)
        step = tf.cast(self.iterations + 1, dtype)
        beta_1_pow = tf.math.pow(beta_1, step)
        beta_2_pow = tf.math.pow(beta_2, step)
        eps = tf.convert_to_tensor(self.epsilon, dtype)
        # Learning rate with Adam's bias correction folded in.
        step_size = base_lr * tf.math.sqrt(1 - beta_2_pow) / (1 - beta_1_pow)

        # First moment, touched rows only: m := beta1 * m + (1 - beta1) * g_t
        first_moment = self.get_slot(var, "m")
        new_m_rows = beta_1 * tf.gather(first_moment, indices) + (1 - beta_1) * grad
        write_m = tf.raw_ops.ResourceScatterUpdate(
            resource=first_moment.handle, indices=indices, updates=new_m_rows
        )

        # Second moment, touched rows only: v := beta2 * v + (1 - beta2) * g_t^2
        second_moment = self.get_slot(var, "v")
        grad_squared = tf.math.square(grad)
        new_v_rows = (
            beta_2 * tf.gather(second_moment, indices) + (1 - beta_2) * grad_squared
        )
        write_v = tf.raw_ops.ResourceScatterUpdate(
            resource=second_moment.handle, indices=indices, updates=new_v_rows
        )

        # variable -= step_size * m_t / (sqrt(v_t) + epsilon), touched rows only.
        delta_rows = step_size * new_m_rows / (tf.math.sqrt(new_v_rows) + eps)
        write_var = tf.raw_ops.ResourceScatterSub(
            resource=var.handle, indices=indices, updates=delta_rows
        )

        return tf.group(write_var, write_m, write_v)


class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Adam with decoupled L2 weight decay.

    Just adding the square of the weights to the loss function is *not* the
    correct way of using L2 regularization/weight decay with Adam, since that
    will interact with the m and v parameters in strange ways.

    Instead we want to decay the weights in a manner that doesn't interact
    with the m/v parameters. This is equivalent to adding the square of the
    weights to the loss with plain (non-momentum) SGD.

    This class does not clip gradients itself.
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay_rate=0.01,
        include_in_weight_decay=None,
        exclude_from_weight_decay=None,
        name="AdamWeightDecay",
        **kwargs
    ):
        """Create the optimizer.

        Args:
            learning_rate: Forwarded to `tf.keras.optimizers.Adam`.
            beta_1: Forwarded to `tf.keras.optimizers.Adam`.
            beta_2: Forwarded to `tf.keras.optimizers.Adam`.
            epsilon: Forwarded to `tf.keras.optimizers.Adam`.
            amsgrad: Forwarded to `tf.keras.optimizers.Adam`.
            weight_decay_rate: Decay coefficient; 0 disables weight decay.
            include_in_weight_decay: Optional iterable of regex patterns. A
                variable whose name matches any of them is always decayed
                (checked before the exclusion patterns).
            exclude_from_weight_decay: Optional iterable of regex patterns. A
                variable whose name matches any of them is not decayed.
            name: Forwarded to `tf.keras.optimizers.Adam`.
            **kwargs: Forwarded to `tf.keras.optimizers.Adam`.
        """
        # Keyword arguments keep each value bound to the intended base-class
        # parameter regardless of the base signature's positional order.
        super(AdamWeightDecay, self).__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs
        )
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Add the decay rate to the per-(device, dtype) coefficients."""
        super(AdamWeightDecay, self)._prepare_local(
            var_device, var_dtype, apply_state)
        # Stored next to "lr_t" so coefficients produced by
        # `_fallback_apply_state` carry it too, and created in the variable's
        # dtype so it multiplies cleanly with non-float32 variables.
        apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
            self.weight_decay_rate,
            dtype=var_dtype,
            name="adam_weight_decay_rate",
        )

    def _decay_weights_op(self, var, learning_rate, apply_state):
        """Return an op subtracting `learning_rate * var * rate` from `var`.

        Returns `tf.no_op()` when `var` is not subject to weight decay.
        """
        if not self._do_use_weight_decay(var.name):
            return tf.no_op()

        var_dtype = var.dtype.base_dtype
        coefficients = (apply_state or {}).get((var.device, var_dtype)) or {}
        rate = coefficients.get("weight_decay_rate")
        if rate is None:
            # No prepared coefficients were supplied for this variable.
            rate = tf.constant(
                self.weight_decay_rate,
                dtype=var_dtype,
                name="adam_weight_decay_rate",
            )
        return var.assign_sub(
            learning_rate * var * rate,
            use_locking=self._use_locking,
        )

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state.

        Returns:
            A `(lr_t, kwargs)` tuple, where `kwargs` holds the `apply_state`
            to forward to the parent `_resource_apply_*` method.
        """
        if apply_state is None:
            apply_state = {}

        key = (var_device, var_dtype)
        coefficients = apply_state.get(key)
        if coefficients is None:
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[key] = coefficients

        return coefficients["lr_t"], dict(apply_state=apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Decay `var` (if applicable), then run the dense Adam update."""
        lr_t, kwargs = self._get_lr(
            var.device, var.dtype.base_dtype, apply_state)
        # Use the state `_get_lr` resolved, which may differ from the argument
        # when the latter was None.
        decay = self._decay_weights_op(
            var, lr_t, kwargs.get("apply_state", apply_state))
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(
                grad, var, **kwargs
            )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Decay `var` (if applicable), then run the sparse Adam update."""
        lr_t, kwargs = self._get_lr(
            var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(
            var, lr_t, kwargs.get("apply_state", apply_state))
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(
                grad, var, indices, **kwargs
            )

    def get_config(self):
        """Return the parent config extended with `weight_decay_rate`."""
        # NOTE(review): the include/exclude patterns are not part of the
        # config, so they are not restored by `from_config` -- confirm whether
        # that is intended.
        config = super(AdamWeightDecay, self).get_config()
        config.update({"weight_decay_rate": self.weight_decay_rate})
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        # Inclusion patterns take precedence over exclusion patterns.
        if self._include_in_weight_decay and any(
            re.search(pattern, param_name) is not None
            for pattern in self._include_in_weight_decay
        ):
            return True

        if self._exclude_from_weight_decay and any(
            re.search(pattern, param_name) is not None
            for pattern in self._exclude_from_weight_decay
        ):
            return False
        return True
