            self._accumulation_counter = 0
    else:
        model = state.model
        model.zero_grad()
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        loss = state.get_key(key="loss", inner_key=self.optimizer_key)
After Change
def on_batch_end(self, state):
    loss = state.get_key(key="loss", inner_key=self.loss_key)
    if isinstance(loss, dict):
        loss = list(loss.values())
    if isinstance(loss, list):
        loss = torch.mean(torch.stack(loss))

    if self.prefix is not None:
        state.metrics.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

    if not state.need_backward:
        return

    self._accumulation_counter += 1
    model = state.model
    optimizer = state.get_key(
        key="optimizer", inner_key=self.optimizer_key
    )
    # This is a very hacky check for whether we have an AMP optimizer,
    # and it may change in the future.
    # An alternative solution is to have an AmpOptimizerCallback
    # or to expose another constructor argument.
    if hasattr(optimizer, "_amp_stash"):
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
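
For context, here is a minimal self-contained sketch of the same pattern outside Catalyst, using a toy model and a made-up accumulation_steps value: losses that arrive as a dict are reduced to a single scalar, backward goes through apex's amp.scale_loss only when the optimizer has been wrapped by amp.initialize (which attaches the _amp_stash attribute the callback checks for), and the optimizer steps every accumulation_steps batches. The plain counter % accumulation_steps test below is an illustration of the intent, not the callback's exact bookkeeping.

import torch
import torch.nn as nn

# Toy setup; the model, data, and accumulation_steps value are hypothetical.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4
accumulation_counter = 0

for step in range(8):
    x = torch.randn(2, 4)
    y = torch.randn(2, 1)

    # A criterion may return several partial losses; stack and average them
    # into a single scalar, mirroring the dict/list handling in on_batch_end.
    losses = {"mse": nn.functional.mse_loss(model(x), y)}
    loss = torch.mean(torch.stack(list(losses.values())))

    accumulation_counter += 1

    # apex's amp.initialize() attaches `_amp_stash` to the optimizer,
    # which is what the hacky check above relies on.
    if hasattr(optimizer, "_amp_stash"):
        from apex import amp  # only importable when apex is installed
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # Step and clear gradients only every accumulation_steps batches.
    if accumulation_counter % accumulation_steps == 0:
        optimizer.step()
        model.zero_grad()
        accumulation_counter = 0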