# backward pass.
# TODO(b/171088214): Remove this once control dependencies in
# nested functions are fixed.
with tf.control_dependencies([grad]):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
After Change
# backward pass.
# TODO(b/171088214): Remove this once control dependencies in
# nested functions are fixed.
with tf.control_dependencies([tf.identity(grad)]):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):