if step % self.D_update_ratio == 0 and step > self.D_init_iters:
// G
if self.cri_pix:
self.log_dict["l_g_pix"] = l_g_pix.data[0]
if self.cri_fea:
self.log_dict["l_g_fea"] = l_g_fea.data[0]
self.log_dict["l_g_gan"] = l_g_gan.data[0]
// D
After Change
if self.opt["train"]["gan_type"] == "wgan-gp":
self.log_dict["l_d_gp"] = l_d_gp.item()
// D outputs
self.log_dict["D_real"] = torch.mean(pred_d_real.detach())
self.log_dict["D_fake"] = torch.mean(pred_d_fake.detach())
def test(self):