        self.global_step = global_step
        gen_average_steps += 1
        gen_train_fraction -= 1.0
    self.global_step = global_step + 1

    # Write checkpoints and report progress.
    if discrim_average_steps == checkpoint_interval:
      saver.save(self.session, self.save_file, global_step=self.global_step)
      discrim_loss = discrim_error / max(1, discrim_average_steps)
      gen_loss = gen_error / max(1, gen_average_steps)
      print(
          "Ending global_step %d: generator average loss %g, discriminator average loss %g"
          % (self.global_step, gen_loss, discrim_loss))
After Change
with self._get_tf("Graph").as_default():
  if checkpoint_interval > 0:
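    # tf.train.CheckpointManager keeps at most max_checkpoints_to_keep recent
    # checkpoints in self.model_dir, pruning older ones automatically.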
    manager = tf.train.CheckpointManager(
        self._get_tf("Checkpoint"), self.model_dir, max_checkpoints_to_keep)
  for feed_dict in batches:
    # Every call to fit_generator() will increment global_step, but we only
    # want it to get incremented once for the entire batch, so record the
    # value and keep resetting it.
    global_step = self.global_step

    # Train the discriminator.
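    # Work on a copy of the batch so adding the noise input below does not
    # modify the caller's feed_dict.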
    feed_dict = dict(feed_dict)
    feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
    discrim_error += self.fit_generator(
        [feed_dict],
        submodel=self.discriminator_submodel,
        checkpoint_interval=0)
    self.global_step = global_step
    discrim_average_steps += 1

    # Train the generator.
    if generator_steps > 0.0:
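      # generator_steps may be fractional: the remainder is carried across
      # batches, and one generator update runs for each whole step accumulated.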
      gen_train_fraction += generator_steps
      while gen_train_fraction >= 1.0:
        feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
        gen_error += self.fit_generator(
            [feed_dict],
            submodel=self.generator_submodel,
            checkpoint_interval=0)
        self.global_step = global_step
        gen_average_steps += 1
        gen_train_fraction -= 1.0
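    # Advance global_step exactly once for the whole batch.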
    self.global_step = global_step + 1

    # Write checkpoints and report progress.
    if discrim_average_steps == checkpoint_interval:
      self._exec_with_session(lambda: manager.save())
      discrim_loss = discrim_error / max(1, discrim_average_steps)
      gen_loss = gen_error / max(1, gen_average_steps)
      print(
          "Ending global_step %d: generator average loss %g, discriminator average loss %g"
          % (self.global_step, gen_loss, discrim_loss))