Before Change
indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
# B x num_heads x 1 x num_atoms
logits_t = self.critic(states_t).gather(-2, indices_t).squeeze(-2)
# B x num_heads x num_atoms
all_logits_tp1 = self.target_critic(states_tp1).detach()
After Change
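The updated version keeps the full critic outputs in q_logits_t and q_logits_tp1, so the same tensors can be reused for greedy action selection and for the optional entropy regularization below.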
# B x num_heads x 1 x num_atoms
indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
# B x num_heads x num_actions x num_atoms
q_logits_t = self.critic(states_t)
# B x num_heads x num_atoms
logits_t = q_logits_t.gather(-2, indices_t).squeeze(-2)
# B x num_heads x num_actions x num_atoms
q_logits_tp1 = self.target_critic(states_tp1).detach()
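# expected Q-values: atom values z weighted by their softmax probabilities
# B x num_heads x num_actions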
q_values_tp1 = torch.sum(
    torch.softmax(q_logits_tp1, dim=-1) * self.z, dim=-1
)
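# greedy next-state actions under the target critic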
# B x num_heads x 1
actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
# B x num_heads x 1 x num_atoms
indices_tp1 = \
    actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
# B x num_heads x num_atoms
logits_tp1 = q_logits_tp1.gather(-2, indices_tp1).squeeze(-2)
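# distributional Bellman target: Tz = r + (1 - done) * gammas * z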
atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
value_loss = utils.categorical_loss(
    logits_t.view(-1, self.num_atoms),
    logits_tp1.view(-1, self.num_atoms),
    atoms_target_t.view(-1, self.num_atoms), self.z, self.delta_z,
    self.v_min, self.v_max
)
if self.entropy_regularization is not None:
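    # B x num_heads x num_actions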
    q_values_t = torch.sum(
        torch.softmax(q_logits_t, dim=-1) * self.z, dim=-1
    )
    value_loss -= \
        self.entropy_regularization * self._compute_entropy(q_values_t)
return value_loss
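For reference, utils.categorical_loss presumably implements the categorical (C51) projection loss of Bellemare et al. (2017): the shifted atoms are clamped to [v_min, v_max], each atom's probability mass is split between its two nearest support atoms, and the projected target is compared to the prediction with cross-entropy. A minimal sketch under that assumption (the actual utils implementation may differ):

import torch

def categorical_loss(
    logits_t, logits_tp1, atoms_target_t, z, delta_z, v_min, v_max
):
    # hypothetical sketch; all tensors are (B, num_atoms);
    # z is the fixed support, unused here since v_min and delta_z define it
    probs_tp1 = torch.softmax(logits_tp1, dim=-1)
    # clamp the shifted atoms Tz into the support range
    tz = atoms_target_t.clamp(v_min, v_max)
    # fractional position of each shifted atom on the support grid
    b = (tz - v_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    # keep the full mass when b lands exactly on an atom (lower == upper)
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[lower == upper] += 1
    # split each atom's mass between its two neighbouring support atoms
    proj = torch.zeros_like(probs_tp1)
    proj.scatter_add_(-1, lower, probs_tp1 * (upper.float() - b))
    proj.scatter_add_(-1, upper, probs_tp1 * (b - lower.float()))
    # cross-entropy between the projected target and the predicted distribution
    log_probs_t = torch.log_softmax(logits_t, dim=-1)
    return -(proj * log_probs_t).sum(dim=-1).mean()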
def _quantile_loss(