9e9c4fd3322b6e8f47572fefdb8fd65018fb96f7,rl_coach/exploration_policies/parameter_noise.py,ParameterNoise,get_action,#ParameterNoise#Any#,62

Before Change



    def get_action(self, action_values: List[ActionType]) -> ActionType:
        if type(self.action_space) == DiscreteActionSpace:
            return np.argmax(action_values)
        elif type(self.action_space) == BoxActionSpace:
            action_values_mean = action_values[0].squeeze()
            action_values_std = action_values[1].squeeze()
            return np.random.normal(action_values_mean, action_values_std)

After Change



    def get_action(self, action_values: List[ActionType]):
        if type(self.action_space) == DiscreteActionSpace:
            action = np.argmax(action_values)
            one_hot_action_probabilities = np.zeros(len(self.action_space.actions))
            one_hot_action_probabilities[action] = 1

            return action, one_hot_action_probabilities
        elif type(self.action_space) == BoxActionSpace:
            action_values_mean = action_values[0].squeeze()
            action_values_std = action_values[1].squeeze()
            return np.random.normal(action_values_mean, action_values_std)
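
The discrete branch of the change above can be sketched in isolation as follows. This is only an illustrative sketch, not the project's code: `greedy_discrete_action` and its `num_actions` parameter are hypothetical names standing in for the method and for `len(self.action_space.actions)`, and the input values are made up.

```python
import numpy as np


def greedy_discrete_action(action_values, num_actions):
    """Hypothetical standalone helper mirroring the discrete branch after the change.

    Returns the greedy action together with a one-hot vector that puts
    probability 1 on that action.
    """
    # Pick the index of the highest action value (greedy choice).
    action = int(np.argmax(action_values))

    # Build a one-hot "probability" vector over all actions.
    one_hot_action_probabilities = np.zeros(num_actions)
    one_hot_action_probabilities[action] = 1

    return action, one_hot_action_probabilities


# Made-up action values for three discrete actions.
action, probs = greedy_discrete_action(np.array([0.1, 0.7, 0.2]), 3)
```

Callers that previously used the return value directly as an action would need to unpack the tuple in the discrete case. In the snippet shown above, the `BoxActionSpace` branch still returns only the sampled action.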
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 13

Instances


Project Name: NervanaSystems/coach
Commit Name: 9e9c4fd3322b6e8f47572fefdb8fd65018fb96f7
Time: 2019-05-27
Author: gal.leibovich@intel.com
File Name: rl_coach/exploration_policies/parameter_noise.py
Class Name: ParameterNoise
Method Name: get_action


Project Name: NervanaSystems/coach
Commit Name: 9e9c4fd3322b6e8f47572fefdb8fd65018fb96f7
Time: 2019-05-27
Author: gal.leibovich@intel.com
File Name: rl_coach/exploration_policies/categorical.py
Class Name: Categorical
Method Name: get_action


Project Name: NervanaSystems/coach
Commit Name: 9e9c4fd3322b6e8f47572fefdb8fd65018fb96f7
Time: 2019-05-27
Author: gal.leibovich@intel.com
File Name: rl_coach/exploration_policies/greedy.py
Class Name: Greedy
Method Name: get_action