b5e848af46b4a6fc21e718803dbf3d7c72afd72a,ch08/lib/common.py,,calc_loss,#,92

Before Change


def calc_loss(batch, net, tgt_net, gamma, cuda=False):
    """Build the per-sample value tensors used for a loss over one batch.

    NOTE: this is a truncated excerpt. The visible code stops after the
    next-state values are masked; ``gamma`` is not used yet and nothing is
    returned within the lines shown here.

    Args:
        batch: passed straight to ``unpack_batch`` (defined elsewhere), which
            yields ``states, actions, rewards, dones, next_states``.
        net: callable applied to both the current and the next states.
        tgt_net: callable applied to the next states only.
        gamma: unused in the visible portion -- presumably a discount factor
            consumed by the part of the function that is cut off; confirm
            against the full file.
        cuda: when truthy, every tensor built here is moved with ``.cuda()``.
    """
    states, actions, rewards, dones, next_states = unpack_batch(batch)

    # torch.from_numpy is used below, so the unpacked values are expected to
    # be NumPy arrays.
    states_v = Variable(torch.from_numpy(states))
    # volatile=True is the legacy (pre-0.4 PyTorch) Variable flag; presumably
    # used so that no autograd graph is kept for the next-state pass.
    next_states_v = Variable(torch.from_numpy(next_states), volatile=True)
    actions_v = Variable(torch.from_numpy(actions))
    rewards_v = Variable(torch.from_numpy(rewards))
    # Byte mask built from `dones`; used further down to index
    # next_state_values.
    done_mask = torch.ByteTensor(dones)
    if cuda:
        states_v = states_v.cuda()
        next_states_v = next_states_v.cuda()
        actions_v = actions_v.cuda()
        rewards_v = rewards_v.cuda()
        done_mask = done_mask.cuda()

    # For each row of net(states_v), pick the entry at that row's action
    # index (gather along dim 1), then drop the trailing singleton dim.
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    # max(1)[1] is the index of the largest entry along dim 1, i.e. the
    # action `net` rates highest for each next state.
    next_state_actions = net(next_states_v).max(1)[1]
    # The value of that chosen action is read from tgt_net's output rather
    # than from net's: selection and evaluation use different networks.
    next_state_values = tgt_net(next_states_v).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1)
    # Zero the next-state value wherever the done mask is set.
    next_state_values[done_mask] = 0.0

After Change


def calc_loss(batch, net, tgt_net, gamma, device="cpu"):
    """Unpack a batch and place its components on ``device`` as tensors.

    NOTE: this is a truncated excerpt. Only the tensor construction is
    visible; ``net``, ``tgt_net`` and ``gamma`` are not used yet and nothing
    is returned within the lines shown here.

    Args:
        batch: passed straight to ``unpack_batch`` (defined elsewhere), which
            yields ``states, actions, rewards, dones, next_states``.
        net: unused in the visible portion.
        tgt_net: unused in the visible portion.
        gamma: unused in the visible portion.
        device: target handed to ``.to()`` for every tensor built here;
            defaults to ``"cpu"``.
    """
    states, actions, rewards, dones, next_states = unpack_batch(batch)

    # Each component is wrapped with torch.tensor and moved to `device`
    # in a single expression.
    states_v = torch.tensor(states).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    # `dones` becomes a byte tensor -- presumably used as a mask by the part
    # of the function that is cut off; confirm against the full file.
    done_mask = torch.ByteTensor(dones).to(device)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 9

Instances


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: b5e848af46b4a6fc21e718803dbf3d7c72afd72a
Time: 2018-04-27
Author: max.lapan@gmail.com
File Name: ch08/lib/common.py
Class Name:
Method Name: calc_loss


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 7a6e3b93fb4b97af7b06244b768b1fee4b547c17
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch12/libbots/model.py
Class Name:
Method Name: pack_batch_no_out


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 84e2cf87e0fe23fe3432b74ab484b1b71cd729b3
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch18/lib/model.py
Class Name:
Method Name: state_lists_to_batch