1d4823c0ec446e93d00df8ca654db4b45b63b3d4,rllib/models/tests/test_distributions.py,TestDistributions,test_gumbel_softmax,#TestDistributions#,172

Before Change



    def test_gumbel_softmax(self):
        """Tests the GumbelSoftmax ActionDistribution (tf-eager only)."""
        with eager_mode():
            batch_size = 1000
            num_categories = 5
            # Random logits for `batch_size` rows of `num_categories` each.
            input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
            # NOTE(review): this "Before Change" excerpt ends here; the rest
            # of the original body is not shown in this record.

After Change



    def test_gumbel_softmax(self):
        """Tests the GumbelSoftmax ActionDistribution (tf and tf-eager).

        For each framework yielded by `framework_iterator` ("tf" with a
        session, and "eager"), this checks two things:
        - `deterministic_sample()` matches the softmax of the input logits.
        - The mean argmax index of `sample()` is within rtol=0.08 of the
          mean argmax index of the input logits.
        """
        for _, sess in framework_iterator(
                frameworks=["tf", "eager"], session=True):
            batch_size = 1000
            num_categories = 5
            input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))

            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)

            expected = softmax(inputs)
            # A deterministic draw should return exactly the softmax probs.
            out = gumbel_softmax.deterministic_sample()
            # When a session is supplied (graph mode), evaluate the result
            # through it before comparing, exactly as the stochastic branch
            # below does with `sample()`.
            if sess is not None:
                out = sess.run(out)
            check(out, expected)

            # Batch of size=n and non-deterministic -> expect roughly that
            # the max-likelihood (argmax) ints are output (most of the time).
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
            expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
            outs = gumbel_softmax.sample()
            if sess is not None:
                outs = sess.run(outs)
            check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)


if __name__ == "__main__":
    import sys

    import pytest

    # Run this module's tests verbosely and exit with pytest's status code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/models/tests/test_distributions.py
Class Name: TestDistributions
Method Name: test_gumbel_softmax


Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/agents/dqn/tests/test_dqn.py
Class Name: TestDQN
Method Name: test_dqn_compilation


Project Name: ray-project/ray
Commit Name: 1d4823c0ec446e93d00df8ca654db4b45b63b3d4
Time: 2020-04-03
Author: sven@anyscale.io
File Name: rllib/models/tests/test_distributions.py
Class Name: TestDistributions
Method Name: test_squashed_gaussian