if not self.skip:
self.former = in_data.narrow(1, 0, self.c//2)
self.former_in_mask = torch.mul(self.former, self.mask)
self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach().cuda() // the latter part
self.target = self.target * self.strength
self.loss = self.criterion(self.former_in_mask * self.strength, self.target)
else:
self.loss = 0
After Change
// For each iteration, we input GT, then I. That means we get the self.target in the first forward. And in this forward, self.loss is dummy!
// In the second forward, we input the corresponding I, then self.loss is working as expected. The self.target is the corresponding GT.
self.loss = self.criterion(self.former_in_mask * self.strength, self.target.expand_as(self.former_in_mask).type_as(self.former_in_mask))
self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach() // the latter part
self.target = self.target * self.strength
else:
self.loss = 0
self.output = in_data