# initialize location and hidden state vectors
h_t = torch.zeros(self.batch_size, self.hidden_size)
l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
h_t, l_t = Variable(h_t), Variable(l_t)
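# l_t holds (x, y) glimpse coordinates normalized to [-1, 1];
# Variable is the pre-0.4 PyTorch autograd wrapper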
for t in range(self.num_glimpses - 1):
    # forward pass through model
    h_t, l_t = self.model(img, l_t, h_t)
    # bookkeeping for later plotting
    self.locs.append(l_t)
# last iteration
probas = self.model(img, l_t, h_t, last=True)
# to be continued
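The change below introduces the pieces needed for a REINFORCE-style update of the location policy: the Gaussian policy mean mu, a learned baseline b_t, and a binary reward R. As a rough, standalone sketch of the loss these quantities typically feed into, using the current torch.distributions API with illustrative names and shapes rather than this repository's code:

import torch
from torch.distributions import Normal

# illustrative sizes: batch of 8, 6 glimpses, 2-D locations, fixed std
B, T, std = 8, 6, 0.17
mu = torch.zeros(B, T, 2, requires_grad=True)    # policy means, one per glimpse
dist = Normal(mu, std)
l = dist.sample()                                # sampled locations (no gradient path)
log_pi = dist.log_prob(l).sum(dim=2)             # (B, T) log-probabilities

R = torch.randint(0, 2, (B,)).float()            # 1 if the image was classified correctly
b_t = torch.zeros(B)                             # baseline predictions (placeholder)

adjusted_reward = R - b_t.detach()               # subtract baseline to reduce variance
loss_reinforce = torch.mean(torch.sum(-log_pi * adjusted_reward.unsqueeze(1), dim=1))
loss_reinforce.backward()                        # gradients flow into mu only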
After Change
l_t = Variable(l_t)
# extract the glimpses
sum_grad_log_pi = 0.
for t in range(self.num_glimpses - 1):
    # forward pass through model
    self.h_t, mu, l_t = self.model(x, l_t, self.h_t)
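    # the model now also returns mu, the mean of the Gaussian location policy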
    # compute gradient of log of policy across batch
    grad_log_pi = (mu - l_t) / (self.std * self.std)
    # accumulate
    sum_grad_log_pi += grad_log_pi
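# note: for a Gaussian policy N(l_t; mu, std^2 I), the gradient of the
# log-density with respect to mu is (l_t - mu) / std^2; the term accumulated
# above is its negative, which presumably matches the sign convention of the
# reinforce loss minimized further down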
# last iteration
self.h_t, mu, l_t, b_t, log_probas = self.model(
img, l_t, self.h_t, last=True
)
# calculate reward
R = (torch.max(log_probas, 1)[1] == y).float()
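# R is 1 for a correctly classified sample and 0 otherwise: a binary
# reward issued once, after the final glimpse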
# compute losses for differentiable modules
self.loss_action = F.nll_loss(log_probas, y)
self.loss_baseline = F.mse_loss(b_t, R)
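# the baseline b_t is regressed onto the reward so that (R - b_t) can serve
# as a lower-variance learning signal for the reinforce term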
# compute reinforce loss