093aca082e2878f3a28defe9075e7334dfceac70,test/distributions/test_distributions.py,TestDistributionShapes,test_one_hot_categorical_shape,#TestDistributionShapes#,3181
Before Change
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
simplex_sample = torch.ones(3, 3) / 3
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
After Change
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
sample = torch.tensor([0., 1., 0.]).expand(3, 2, 3)
self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2,)))
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
sample = torch.eye(3)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 4
Instances
Project Name: pytorch/pytorch
Commit Name: 093aca082e2878f3a28defe9075e7334dfceac70
Time: 2021-01-05
Author: fritz.obermeyer@gmail.com
File Name: test/distributions/test_distributions.py
Class Name: TestDistributionShapes
Method Name: test_one_hot_categorical_shape
Project Name: PIQuIL/QuCumber
Commit Name: 364948bf0bb46f9bc3809bbe0a63d09348994444
Time: 2018-08-04
Author: emerali@users.noreply.github.com
File Name: tests/positiveGrads_test.py
Class Name:
Method Name: partition
Project Name: PIQuIL/QuCumber
Commit Name: 364948bf0bb46f9bc3809bbe0a63d09348994444
Time: 2018-08-04
Author: emerali@users.noreply.github.com
File Name: tests/complexGrads_test.py
Class Name:
Method Name: partition