fec23179223cbde694b1611046221fc07a851824,deepchem/models/tensorgraph/models/gan.py,GAN,fit_gan,#GAN#Any#Any#Any#Any#Any#,338

Before Change


            self.global_step = global_step
            gen_average_steps += 1
            gen_train_fraction -= 1.0
        self.global_step = global_step + 1

        # Write checkpoints and report progress.

        if discrim_average_steps == checkpoint_interval:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          discrim_loss = discrim_error / max(1, discrim_average_steps)
          gen_loss = gen_error / max(1, gen_average_steps)
          print(
              "Ending global_step %d: generator average loss %g, discriminator average loss %g"

After Change


    with self._get_tf("Graph").as_default():
      if checkpoint_interval > 0:
        manager = tf.train.CheckpointManager(
            self._get_tf("Checkpoint"), self.model_dir, max_checkpoints_to_keep)
      for feed_dict in batches:
        # Every call to fit_generator() will increment global_step, but we only
        # want it to get incremented once for the entire batch, so record the
        # value and keep resetting it.

        global_step = self.global_step

        # Train the discriminator.

        feed_dict = dict(feed_dict)
        feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
        discrim_error += self.fit_generator(
            [feed_dict],
            submodel=self.discriminator_submodel,
            checkpoint_interval=0)
        self.global_step = global_step
        discrim_average_steps += 1

        # Train the generator.

        if generator_steps > 0.0:
          gen_train_fraction += generator_steps
          while gen_train_fraction >= 1.0:
            feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
            gen_error += self.fit_generator(
                [feed_dict],
                submodel=self.generator_submodel,
                checkpoint_interval=0)
            self.global_step = global_step
            gen_average_steps += 1
            gen_train_fraction -= 1.0
        self.global_step = global_step + 1

        # Write checkpoints and report progress.

        if discrim_average_steps == checkpoint_interval:
          self._exec_with_session(lambda: manager.save())
          discrim_loss = discrim_error / max(1, discrim_average_steps)
          gen_loss = gen_error / max(1, gen_average_steps)
          print(
              "Ending global_step %d: generator average loss %g, discriminator average loss %g"
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: deepchem/deepchem
Commit Name: fec23179223cbde694b1611046221fc07a851824
Time: 2019-04-05
Author: peastman@stanford.edu
File Name: deepchem/models/tensorgraph/models/gan.py
Class Name: GAN
Method Name: fit_gan


Project Name: deepchem/deepchem
Commit Name: fec23179223cbde694b1611046221fc07a851824
Time: 2019-04-05
Author: peastman@stanford.edu
File Name: deepchem/models/tensorgraph/tensor_graph.py
Class Name: TensorGraph
Method Name: save_checkpoint


Project Name: deepchem/deepchem
Commit Name: fec23179223cbde694b1611046221fc07a851824
Time: 2019-04-05
Author: peastman@stanford.edu
File Name: deepchem/models/tensorgraph/models/gan.py
Class Name: GAN
Method Name: fit_gan


Project Name: deepchem/deepchem
Commit Name: fec23179223cbde694b1611046221fc07a851824
Time: 2019-04-05
Author: peastman@stanford.edu
File Name: deepchem/models/tensorgraph/tensor_graph.py
Class Name: TensorGraph
Method Name: fit_generator