# Run one optimizer step. The second argument is the batch index looked up
# through `permutation`, scaled by the per-device batch size (presumably an
# offset into the loaded data -- TODO confirm against optimizer.optimize).
batch_fetches = optimizer.optimize(
self.sess, permutation[batch_index] *
self.per_device_batch_size)
# Append each learner stat returned by this step to its per-key list.
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v)
# Guard on the effective log level so `averaged` is only called for logging
# when the debug message will actually be emitted.
if logger.getEffectiveLevel() <= logging.DEBUG:
avg = averaged(iter_extra_fetches)
logger.debug("{} {}".format(i, avg))
# Store the collected stats for this policy, reduced via averaged(..., axis=0).
# NOTE(review): `averaged` is defined elsewhere; its exact reduction semantics
# are not visible here -- confirm.
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
After Change
tree.map_structure_with_path(
lambda p, *s: self._all_tower_reduce(p, *s),
*(batch_fetches["tower_{}".format(tower_num)]
for tower_num in range(len(self.devices)))))
# Reduce mean across all minibatch SGD steps (axis=0 to keep
# all shapes as-is).