# -*- coding: utf-8 -*-
import tensorflow as tf


class Gradient_Tower(object):
    """Accumulate gradients over several steps before applying them.

    Each call to ``tower_update`` adds one batch of gradients into
    per-variable accumulator variables. When ``n_gradients`` calls have been
    made, the accumulated gradients are passed to
    ``optimizer.apply_gradients`` and the accumulators and step counter are
    reset to zero.

    Note: the accumulated gradients are *summed*, not averaged, so the applied
    update is built from the sum of ``n_gradients`` batches. Scale the loss or
    the learning rate accordingly if an average is intended.
    """

    def __init__(self, optimizer, n_gradients, trainable_variables, *args,
                 **kwargs):
        """Create the tower.

        Args:
            optimizer: Object exposing ``apply_gradients(grads_and_vars)``
                (presumably a ``tf.keras`` optimizer -- confirm at call site).
            n_gradients: Number of ``tower_update`` calls to accumulate before
                the gradients are applied. Must be a positive integer.
            trainable_variables: Variables the gradients correspond to; used
                to shape the accumulators on the first ``tower_update``.
            *args, **kwargs: Accepted for call-site compatibility; unused.

        Raises:
            ValueError: If ``n_gradients`` is a Python int smaller than 1, in
                which case the accumulated gradients would never be applied.
        """
        if isinstance(n_gradients, int) and n_gradients < 1:
            raise ValueError(
                "n_gradients must be >= 1, got {}".format(n_gradients))
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.trainable_variables = trainable_variables
        self.optimizer = optimizer
        # Calls to tower_update since the last apply. Kept as a tf.Variable so
        # the comparison against n_gradients is a tensor op (usable in graphs).
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # True until tower_init has created the accumulator variables.
        self.init = True

    def tower_init(self):
        """Create one zero-filled float32 accumulator per trainable variable."""
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables
        ]
        self.init = False

    def tower_update(self, gradients, trainable_variables):
        """Add one batch of gradients; apply them every ``n_gradients`` calls.

        Args:
            gradients: Sequence of gradients aligned index-for-index with the
                accumulators (i.e. with ``trainable_variables``). ``None``
                entries -- variables that received no gradient -- are skipped.
            trainable_variables: Variables to update; replaces the stored list.

        Raises:
            ValueError: If the number of gradients differs from the number of
                accumulators.
        """
        self.trainable_variables = trainable_variables

        # One-time lazy creation of the accumulators. ``self.init`` is a
        # Python flag, not a tensor, so a plain ``if`` is the right construct;
        # tf.cond is meant for tensor predicates and would trace both branches
        # (including variable creation) when run inside a tf.function.
        if self.init:
            self.tower_init()

        if len(gradients) != len(self.gradient_accumulation):
            raise ValueError(
                "Expected {} gradients, got {}".format(
                    len(self.gradient_accumulation), len(gradients)))

        self.n_acum_step.assign_add(1)
        # Accumulate batch gradients.
        for accumulator, grad in zip(self.gradient_accumulation, gradients):
            # A missing (None) gradient contributes nothing; adding zeros
            # would be a no-op, so just skip it.
            if grad is None:
                continue
            accumulator.assign_add(grad)

        # Once n_acum_step reaches n_gradients, apply the accumulated
        # gradients; otherwise do nothing. tf.cond is correct here because the
        # predicate is a tensor derived from a tf.Variable.
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients),
                self.apply_accu_gradients, lambda: None)

    def apply_accu_gradients(self):
        """Apply the summed gradients, then reset the counter and accumulators."""
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables))

        # Reset for the next accumulation window.
        self.n_acum_step.assign(0)
        for accumulator in self.gradient_accumulation:
            accumulator.assign(tf.zeros_like(accumulator))