093aca082e2878f3a28defe9075e7334dfceac70,test/distributions/test_distributions.py,TestDistributionShapes,test_one_hot_categorical_shape,#TestDistributionShapes#,3181

Before Change


        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
        simplex_sample = torch.ones(3, 3) / 3
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))

After Change


        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        sample = torch.tensor([0., 1., 0.]).expand(3, 2, 3)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2,)))
        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
        sample = torch.eye(3)
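The After Change lines replace simplex-valued test inputs with one-hot tensors. Below is a minimal sketch of the shape contract those assertions depend on. It assumes a recent PyTorch in which `OneHotCategorical` declares a one-hot support and raises `ValueError` from `log_prob` when argument validation is enabled. The probability vector `[0.6, 0.3, 0.1]` is an assumption, chosen to give batch_shape `()` and event_shape `(3,)`, which matches the sample sizes asserted above. The test's actual `dist` is not shown in this excerpt.

```python
import torch
from torch.distributions import OneHotCategorical

# Assumption: a single categorical over 3 classes (batch_shape=(), event_shape=(3,)),
# picked to match the sample sizes in the excerpt; validation is enabled explicitly.
dist = OneHotCategorical(probs=torch.tensor([0.6, 0.3, 0.1]), validate_args=True)

assert dist.batch_shape == torch.Size(())
assert dist.event_shape == torch.Size((3,))
assert dist.sample().size() == torch.Size((3,))
assert dist.sample((3, 2)).size() == torch.Size((3, 2, 3))

# One-hot values lie inside the support, and log_prob drops the trailing event dim.
one_hot = torch.tensor([0., 1., 0.]).expand(3, 2, 3)
assert dist.log_prob(one_hot).size() == torch.Size((3, 2))
assert dist.log_prob(torch.eye(3)).size() == torch.Size((3,))
assert dist.log_prob(dist.enumerate_support()).size() == torch.Size((3,))

# A uniform point on the simplex is not one-hot, so a validating log_prob rejects it.
try:
    dist.log_prob(torch.ones(3, 3) / 3)
    rejected = False
except ValueError:
    rejected = True
```

The final check suggests why the Before Change inputs (`tensor_sample_2` normalized to the simplex, and `torch.ones(3, 3) / 3`) were likely swapped out. Under support validation they fall outside the one-hot support, while `torch.eye(3)` and an expanded one-hot row pass.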
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 4

Instances


Project Name: pytorch/pytorch
Commit Name: 093aca082e2878f3a28defe9075e7334dfceac70
Time: 2021-01-05
Author: fritz.obermeyer@gmail.com
File Name: test/distributions/test_distributions.py
Class Name: TestDistributionShapes
Method Name: test_one_hot_categorical_shape


Project Name: PIQuIL/QuCumber
Commit Name: 364948bf0bb46f9bc3809bbe0a63d09348994444
Time: 2018-08-04
Author: emerali@users.noreply.github.com
File Name: tests/positiveGrads_test.py
Class Name:
Method Name: partition


Project Name: PIQuIL/QuCumber
Commit Name: 364948bf0bb46f9bc3809bbe0a63d09348994444
Time: 2018-08-04
Author: emerali@users.noreply.github.com
File Name: tests/complexGrads_test.py
Class Name:
Method Name: partition