Returns:
A dict of metrics from training.
self._losses = AverageMeter()
self.model.train()
with self.timers["epoch_time"]:
for batch_idx, batch in enumerate(iterator):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
if "loss" in metrics:
self._losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
self.global_step += 1
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
self.scheduler.step()
stats = {
BATCH_COUNT: batch_idx + 1,
"mean_train_loss": self._losses.avg,
"last_train_loss": self._losses.val,
"epoch_time": self.timers["epoch_time"].last
}
stats.update({
timer_tag: timer.mean
for timer_tag, timer in self.timers.items()
})
return stats
def train_batch(self, batch, batch_info):
Computes loss and updates the model over one batch.
After Change
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH: