86df3d0466bb72a566fca457b108ef35a4ff6b14,catalyst/rl/offpolicy/algorithms/ddpg.py,DDPG,_quantile_loss,#DDPG#,111

Before Change


        policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

        # critic loss (quantile regression)
        atoms_t = self.critic(states_t, actions_t)
        # B x num_heads x num_atoms
        atoms_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).detach()
        # B x num_heads x num_atoms

        done_t = done_t[:, None, :]
        # B x 1 x 1
        rewards_t = rewards_t[:, None, :]
        # B x 1 x 1
        gammas = gammas[None, :, None]
        # 1 x num_heads x 1

        atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1

        value_loss = utils.quantile_loss(
            atoms_t.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau, self.num_atoms,

After Change


        policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

        # critic loss (quantile regression)
        atoms_t = self.critic(states_t, actions_t).squeeze_(dim=2)
        # B x num_heads x num_atoms
        atoms_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms

        done_t = done_t[:, None, :]
        # B x 1 x 1
        rewards_t = rewards_t[:, None, :]
        # B x 1 x 1
        gammas = gammas[None, :, None]
        # 1 x num_heads x 1

        atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1

        value_loss = utils.quantile_loss(
            atoms_t.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau, self.num_atoms,
[image: Italian Trulli]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 6

Instances


Project Name: catalyst-team/catalyst
Commit Name: 86df3d0466bb72a566fca457b108ef35a4ff6b14
Time: 2019-07-03
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _quantile_loss


Project Name: catalyst-team/catalyst
Commit Name: 86df3d0466bb72a566fca457b108ef35a4ff6b14
Time: 2019-07-03
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _base_loss


Project Name: catalyst-team/catalyst
Commit Name: 86df3d0466bb72a566fca457b108ef35a4ff6b14
Time: 2019-07-03
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _categorical_loss