Before Change
q_v = crt_net(states_v, actions_v)                    # Q(s, a) from the critic
last_act_v = tgt_act_net.target_model(last_states_v)  # a' from the target actor
q_last_v = tgt_crt_net.target_model(last_states_v, last_act_v)  # Q'(s', a')
q_last_v[dones_mask] = 0.0                            # no bootstrap past episode end
q_ref_v = rewards_v.unsqueeze(dim=-1) + q_last_v * (GAMMA ** REWARD_STEPS)
critic_loss_v = F.mse_loss(q_v, q_ref_v.detach())     # scalar TD-target MSE loss
critic_loss_v.backward()
After Change
crt_opt.zero_grad()
crt_distr_v = crt_net(states_v, actions_v)            # atom logits for (s, a)
last_act_v = tgt_act_net.target_model(last_states_v)  # a' from the target actor
last_distr_v = F.softmax(tgt_crt_net.target_model(last_states_v, last_act_v), dim=1)
proj_distr_v = distr_projection(last_distr_v, rewards_v, dones_mask,
                                gamma=GAMMA ** REWARD_STEPS, cuda=args.cuda)
# Cross-entropy between the projected target distribution and the prediction
prob_dist_v = -F.log_softmax(crt_distr_v, dim=1) * proj_distr_v
critic_loss_v = prob_dist_v.sum(dim=1).mean()
critic_loss_v.backward()
crt_opt.step()
tb_tracker.track("loss_critic", critic_loss_v, frame_idx)
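
The heart of the change is the switch from a scalar TD target to a categorical distribution over a fixed grid of value atoms; distr_projection maps the Bellman-shifted target distribution back onto that grid. Below is a minimal NumPy sketch of this categorical projection (as described in the C51/D4PG papers), assuming hypothetical atom-grid constants Vmin, Vmax and N_ATOMS and a hypothetical name distr_projection_sketch; it omits the cuda handling the real distr_projection in the repository performs and is an illustration, not the repository's exact implementation.

import numpy as np

# Hypothetical atom-grid constants for illustration; the real values
# are defined in the training script.
Vmin, Vmax, N_ATOMS = -10.0, 10.0, 51
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)

def distr_projection_sketch(next_distr, rewards, dones, gamma):
    # next_distr: (batch, N_ATOMS) target-critic probabilities
    # rewards:    (batch,) n-step returns
    # dones:      (batch,) boolean episode-end mask
    batch_size = len(rewards)
    proj_distr = np.zeros((batch_size, N_ATOMS), dtype=np.float32)
    for atom in range(N_ATOMS):
        # Bellman-shift this atom and clip it back onto the support.
        tz = np.clip(rewards + gamma * (Vmin + atom * DELTA_Z), Vmin, Vmax)
        b = (tz - Vmin) / DELTA_Z
        l = np.floor(b).astype(np.int64)
        u = np.ceil(b).astype(np.int64)
        eq = l == u            # shifted atom lands exactly on a grid point
        proj_distr[eq, l[eq]] += next_distr[eq, atom]
        ne = ~eq               # otherwise split mass between the two neighbours
        proj_distr[ne, l[ne]] += next_distr[ne, atom] * (u - b)[ne]
        proj_distr[ne, u[ne]] += next_distr[ne, atom] * (b - l)[ne]
    if dones.any():
        # Terminal transitions: the target collapses to a delta at the reward.
        proj_distr[dones] = 0.0
        tz = np.clip(rewards[dones], Vmin, Vmax)
        b = (tz - Vmin) / DELTA_Z
        l = np.floor(b).astype(np.int64)
        u = np.ceil(b).astype(np.int64)
        eq = l == u
        eq_dones = dones.copy()
        eq_dones[dones] = eq
        if eq_dones.any():
            proj_distr[eq_dones, l[eq]] = 1.0
        ne = ~eq
        ne_dones = dones.copy()
        ne_dones[dones] = ne
        if ne_dones.any():
            proj_distr[ne_dones, l[ne]] = (u - b)[ne]
            proj_distr[ne_dones, u[ne]] = (b - l)[ne]
    return proj_distr

Called with the softmaxed target-critic output, rewards_v, dones_mask and GAMMA ** REWARD_STEPS as in the After Change code above, this returns the projected distribution that the cross-entropy loss is taken against; the delta-at-reward handling of terminal transitions replaces the q_last_v[dones_mask] = 0.0 zeroing used in the Before Change code.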
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 6
Instances
Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 8d52bd0b09152b02e0a5504d33593d0c290b88c7
Time: 2018-02-05
Author: max.lapan@gmail.com
File Name: ch14/06_train_d4pg.py
Class Name:
Method Name:
Project Name: catalyst-team/catalyst
Commit Name: c27dbde9ccec2920f3825538aff07e8533e086ba
Time: 2019-07-24
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _categorical_loss