self.seed = seed
def __call__(self, shape, dtype=None):
flat_shape = (shape[0], np.prod(shape[1:]))
if self.seed is not None:
np.random.seed(self.seed)
a = np.random.normal(0.0, 1.0, flat_shape)
After Change
num_rows = 1
for dim in shape[:-1]:
num_rows *= dim
num_cols = shape[-1]
flat_shape = (num_rows, num_cols)
if self.seed is not None:
np.random.seed(self.seed)
a = np.random.normal(0.0, 1.0, flat_shape)