// reduce
ctx, dtype = arrays[0].context, "float32"
norms = [nd.add_n(*g).as_in_context(ctx) for g in norm_groups.values()]
total_norm = nd.add_n(*norms).sqrt()
scale = total_norm / max_norm
// is_finite = 0 if NaN or Inf, 1 otherwise.
After Change
mx.autograd.backward(ls)
def step(self, batch_size, max_norm=None):
Makes one step of parameter update. Should be called after
`fp16_optimizer.backward()`, and outside of `record()` scope.
Parameters
----------
batch_size : int
Batch size of data processed. Gradient will be normalized by `1/batch_size`.
Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
max_norm : NDArray, optional, default is None
Maximum value for the global 2-norm of gradients.
"""
self.fp32_trainer.allreduce_grads()
step_size = batch_size * self._scaler.loss_scale
if max_norm: