# Reshape to do in a single forward pass for all steps
values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(states_batch), Variable(actions_batch))
old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices]
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs))
adv_targ = Variable(advantages.view(-1, 1)[indices])
surr1 = ratio * adv_targ
After Change
rollouts.cuda()
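# PPO keeps a frozen copy of the policy so the old action log-probabilities can be evaluated later.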
if args.algo == "ppo":
old_model = copy.deepcopy(actor_critic)
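# Main training loop: each update collects num_steps of experience from every parallel environment, then optimizes.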
for j in range(num_updates):
    for step in range(args.num_steps):
        # Sample actions
        value, action = actor_critic.act(Variable(rollouts.states[step], volatile=True))
        cpu_actions = action.data.cpu().numpy()

        # Observe reward and next state
        state, reward, done, info = envs.step(cpu_actions)
        reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
        episode_rewards += reward
        # If done then clean the history of observations.
        masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
        final_rewards *= masks
        final_rewards += (1 - masks) * episode_rewards
        episode_rewards *= masks
        if args.cuda:
            masks = masks.cuda()
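        # Zero the state of environments whose episode just ended (masks are 0.0 where done).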
        current_state *= masks.unsqueeze(2).unsqueeze(2)
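        # Incorporate the new observation into current_state and record the transition in the rollout storage.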
        update_current_state(state)
        rollouts.insert(step, current_state, action.data, value.data, reward, masks)
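    # Bootstrap from the value of the last state and compute discounted returns (optionally with GAE).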
    next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
    rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
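    # A2C and ACKTR perform a single update over the whole rollout.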
if args.algo in ["a2c", "acktr"]:
        values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(rollouts.states[:-1].view(-1, *obs_shape)), Variable(rollouts.actions.view(-1, 1)))
        values = values.view(args.num_steps, args.num_processes, 1)
        action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)
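        # The critic regresses the returns; the policy gradient uses the advantages detached from the value graph.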
        advantages = Variable(rollouts.returns[:-1]) - values
        value_loss = advantages.pow(2).mean()
        action_loss = -(Variable(advantages.data) * action_log_probs).mean()
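        # ACKTR: periodically re-estimate the optimizer's curvature (Fisher) statistics; acc_stats toggles
        # statistic accumulation during this extra backward pass.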
if args.algo == "acktr" and optimizer.steps % optimizer.Ts == 0:
// Sampled fisher, see Martens 2014
actor_critic.zero_grad()
pg_fisher_loss = -action_log_probs.mean()
value_noise = Variable(torch.randn(values.size()))
if args.cuda:
value_noise = value_noise.cuda()
sample_values = values + value_noise
vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()
fisher_loss = pg_fisher_loss + vf_fisher_loss
optimizer.acc_stats = True
fisher_loss.backward(retain_graph=True)
optimizer.acc_stats = False
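        # Combined objective: value loss + policy loss - entropy bonus; A2C additionally clips the gradient norm.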
        optimizer.zero_grad()
        (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()

        if args.algo == "a2c":
            nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)

        optimizer.step()
elif args.algo == "ppo":
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        old_model.load_state_dict(actor_critic.state_dict())
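        # Each epoch draws random minibatches over the flattened (num_processes * num_steps) transitions.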
        for _ in range(args.ppo_epoch):
            sampler = BatchSampler(SubsetRandomSampler(range(args.num_processes * args.num_steps)), args.batch_size * args.num_processes, drop_last=False)
            for indices in sampler:
                indices = torch.LongTensor(indices)
                if args.cuda:
                    indices = indices.cuda()

                states_batch = rollouts.states[:-1].view(-1, *rollouts.states.size()[-3:])[indices]
                actions_batch = rollouts.actions.view(-1, 1)[indices]
                return_batch = rollouts.returns[:-1].view(-1, 1)[indices]
                # Reshape to do in a single forward pass for all steps
                values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(states_batch), Variable(actions_batch))
                _, old_action_log_probs, _ = old_model.evaluate_actions(Variable(states_batch, volatile=True), Variable(actions_batch, volatile=True))
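                # Importance ratio between the current policy and the old policy; surr1 is the unclipped surrogate objective.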
                ratio = torch.exp(action_log_probs - Variable(old_action_log_probs.data))
                adv_targ = Variable(advantages.view(-1, 1)[indices])
                surr1 = ratio * adv_targ