39a28aba95b0d8eeb56f7a17b7dad140e601b591,agents/clipped_ppo_agent.py,ClippedPPOAgent,choose_action,#ClippedPPOAgent#Any#Any#,180

Before Change



    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
        // convert to batch so we can run it through the network
        observation = curr_state["observation"]
        observation = np.expand_dims(np.array(observation), 0)

        if self.env.discrete_controls:
            // DISCRETE

After Change


    def choose_action(self, current_state, phase=RunPhase.TRAIN):
        if self.env.discrete_controls:
            // DISCRETE
            _, action_values = self.main_network.online_network.predict(self.tf_input_state(current_state))
            action_values = action_values.squeeze()

            if phase == RunPhase.TRAIN:
                action = self.exploration_policy.get_action(action_values)
            else:
                action = np.argmax(action_values)
            action_info = {"action_probability": action_values[action]}
            // self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
        else:
            // CONTINUOUS
            _, action_values_mean, action_values_std = self.main_network.online_network.predict(self.tf_input_state(current_state))
            action_values_mean = action_values_mean.squeeze()
            action_values_std = action_values_std.squeeze()
            if phase == RunPhase.TRAIN:
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 3

Instances


Project Name: NervanaSystems/coach
Commit Name: 39a28aba95b0d8eeb56f7a17b7dad140e601b591
Time: 2018-02-21
Author: zach.dwiel@intel.com
File Name: agents/clipped_ppo_agent.py
Class Name: ClippedPPOAgent
Method Name: choose_action


Project Name: NervanaSystems/coach
Commit Name: 98f57a0d8781083cbca137fb9d84313b94d9d108
Time: 2018-02-21
Author: zach.dwiel@intel.com
File Name: agents/ddpg_agent.py
Class Name: DDPGAgent
Method Name: choose_action


Project Name: NervanaSystems/coach
Commit Name: 8248caf35eeb43046f2b28937627d43cbf950c9a
Time: 2018-02-21
Author: zach.dwiel@intel.com
File Name: agents/policy_gradients_agent.py
Class Name: PolicyGradientsAgent
Method Name: choose_action


Project Name: NervanaSystems/coach
Commit Name: 8248caf35eeb43046f2b28937627d43cbf950c9a
Time: 2018-02-21
Author: zach.dwiel@intel.com
File Name: agents/ppo_agent.py
Class Name: PPOAgent
Method Name: choose_action


Project Name: NervanaSystems/coach
Commit Name: 39a28aba95b0d8eeb56f7a17b7dad140e601b591
Time: 2018-02-21
Author: zach.dwiel@intel.com
File Name: agents/clipped_ppo_agent.py
Class Name: ClippedPPOAgent
Method Name: choose_action