aab3902d4a7d55f5a86058854adc36b8a12c873f,catalyst/dl/callbacks/base.py,OptimizerCallback,on_batch_end,#OptimizerCallback#Any#,202

Before Change


                self._accumulation_counter = 0
        else:
            model = state.model
            model.zero_grad()
            optimizer = state.get_key(
                key="optimizer", inner_key=self.optimizer_key
            )
            loss = state.get_key(key="loss", inner_key=self.optimizer_key)

After Change


    def on_batch_end(self, state):
        loss = state.get_key(key="loss", inner_key=self.loss_key)
        if isinstance(loss, dict):
            loss = list(loss.values())
        if isinstance(loss, list):
            loss = torch.mean(torch.stack(loss))

        if self.prefix is not None:
            state.metrics.add_batch_value(metrics_dict={
                self.prefix: loss.item(),
            })

        if not state.need_backward:
            return

        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )

        # This is very hacky check whether we have AMP optimizer and this may
        # change in future.
        # But alternative solution is to have AmpOptimizerCallback.
        # or expose another c'tor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: Scitator/catalyst
Commit Name: aab3902d4a7d55f5a86058854adc36b8a12c873f
Time: 2019-05-20
Author: ekhvedchenya@gmail.com
File Name: catalyst/dl/callbacks/base.py
Class Name: OptimizerCallback
Method Name: on_batch_end


Project Name: junyanz/BicycleGAN
Commit Name: 6f0eec8ef0a147a80f64f25103089feb33553a06
Time: 2020-06-27
Author: junyanzhu89@gmail.com
File Name: models/bicycle_gan_model.py
Class Name: BiCycleGANModel
Method Name: update_G_and_E