ctx = arr.context
groups[ctx].append(arr)
return groups
norm_groups = group_by_ctx(norm_arrays)
// reduce
ctx, dtype = arrays[0].context, "float32"
norms = [nd.add_n(*g).as_in_context(ctx) for g in norm_groups.values()]
total_norm = nd.add_n(*norms).sqrt()
scale = total_norm / max_norm
// is_finite = 0 if NaN or Inf, 1 otherwise.
After Change
Batch size of data processed. Gradient will be normalized by `1/batch_size`.
Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
max_norm : NDArray, optional, default is None
max value for global 2-norm of gradients.
"""
self.fp32_trainer.allreduce_grads()
step_size = batch_size * self._scaler.loss_scale
if max_norm: