1d4823c0ec446e93d00df8ca654db4b45b63b3d4,rllib/models/tests/test_distributions.py,TestDistributions,test_gumbel_softmax,#TestDistributions#,172
Before Change
# Pre-change version of test_gumbel_softmax: the body ran inside a single
# `eager_mode()` context rather than iterating over several frameworks.
# NOTE(review): this excerpt is truncated after the Box input space is
# built; the remainder of the pre-change body is not shown in this record.
def test_gumbel_softmax(self):
Tests the GumbelSoftmax ActionDistribution (tf-eager only).
with eager_mode():
# Batch of 1000 rows of logits over 5 categories, each in [-1.0, 1.0].
batch_size = 1000
num_categories = 5
input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
After Change
def test_gumbel_softmax(self):
    """Tests the GumbelSoftmax ActionDistribution (tf and tf-eager only).

    The body runs once per (framework, session) pair yielded by
    `framework_iterator(frameworks=["tf", "eager"], session=True)`.
    Whenever a session is yielded (presumably only for graph-mode "tf" --
    confirm against `framework_iterator`), every sampled tensor is
    evaluated through that session before being compared to numpy values.
    """
    for _, sess in framework_iterator(
            frameworks=["tf", "eager"], session=True):
        batch_size = 1000
        num_categories = 5
        input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))

        # Batch of size=n and deterministic.
        inputs = input_space.sample()
        gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)

        expected = softmax(inputs)
        # Deterministic draw: expect the softmax of the logits (the
        # distribution's mean value) for every row.
        out = gumbel_softmax.deterministic_sample()
        # Evaluate through the session exactly like the stochastic sample
        # below, instead of handing a possibly-symbolic tensor to `check`.
        if sess:
            out = sess.run(out)
        check(out, expected)

        # Batch of size=n and non-deterministic -> expect roughly that
        # the max-likelihood (argmax) ints are output (most of the time).
        # NOTE(review): only the *means* of the argmax indices are compared.
        # If Box.sample draws uniformly from [-1, 1] (assumed), both means
        # sit near (num_categories - 1) / 2, so this is a weak check.
        inputs = input_space.sample()
        gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
        expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
        outs = gumbel_softmax.sample()
        if sess:
            outs = sess.run(outs)
        check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
if __name__ == "__main__":
    # Allow running this test module directly: delegate to pytest in
    # verbose mode and propagate its result as the process exit status.
    import sys

    import pytest

    pytest_exit_code = pytest.main(["-v", __file__])
    sys.exit(pytest_exit_code)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 5
Instances
Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/models/tests/test_distributions.py
Class Name: TestDistributions
Method Name: test_gumbel_softmax
Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/agents/dqn/tests/test_dqn.py
Class Name: TestDQN
Method Name: test_dqn_compilation
Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/models/tests/test_distributions.py
Class Name: TestDistributions
Method Name: test_squashed_gaussian