# Atom support shifted by -logprob_tp1, detached so no gradient flows back
# through the target. Shape per the original note: [B; num_atoms]
# (assumes self.z is 1-D [num_atoms] and logprob_tp1 is [B] — TODO confirm).
z_target_tp1 = (self.z[None, :] - logprob_tp1[:, None]).detach()
# Unsqueeze so it's the same for each head: [B; 1; num_atoms]
z_target_tp1 = z_target_tp1.unsqueeze(1)
# Target atoms: rewards_t + (1 - done_t) * gammas * z_target_tp1.
# Presumably done_t is 1 on terminal steps, so the bootstrap term is dropped
# there — verify against the caller.
# NOTE(review): assumes rewards_t, done_t and gammas broadcast against
# [B; 1; num_atoms] (e.g. shaped [B; num_heads; 1]) — confirm.
atoms_target_t = rewards_t + (1 - done_t) * gammas * z_target_tp1
value_loss = [
utils.categorical_loss(
After Change
# [bs; num_atoms] -> unsqueeze so it's the same for each head
# [bs; 1; num_atoms]
# Atom support shifted by -logprob_tp1; .detach() keeps the target out of
# the autograd graph (assumes self.z is [num_atoms] and logprob_tp1 is [bs]
# — TODO confirm).
z_target_tp1 = (
    self.z[None, :] - logprob_tp1[:, None]
).unsqueeze(1).detach()
# [bs; num_heads; num_atoms] -> many-heads view transform
# [{bs * num_heads}; num_atoms]
atoms_target_t = (