# Reject unsupported output formats before doing any tensor work.
if self.out_type not in ("array", "list"):
    raise RuntimeError('The out_type must be one of "array" or "list".')

# Split the static shape into the leading (batch) dimensions and the
# trailing two (signal) dimensions.
# NOTE(review): static shapes may contain None for dimensions unknown at
# graph-construction time -- confirm callers always pass fully-defined shapes.
batch_shape = tuple(input.shape[:-2])
signal_shape = tuple(input.shape[-2:])

# Collapse every batch dimension into a single leading axis.
input = tf.reshape(input, (-1,) + signal_shape)

S = scattering2d(input, self.pad, self.unpad, self.backend, self.J, self.L, self.phi, self.psi,
                 self.max_order, self.out_type)

if self.out_type == "array":
    # Put the original batch dimensions back in front of the last three
    # dimensions of the result.
    scattering_shape = tuple(S.shape[-3:])
    new_shape = batch_shape + scattering_shape
    S = tf.reshape(S, new_shape)
else:
    # S is indexed as a sequence of mappings here; take the last two
    # dimensions of the first entry's "coef" value.
    scattering_shape = tuple(S[0]["coef"].shape[-2:])
After Change
# Use tf.shape to get the dynamic shape of the tf.Tensors at
# execution time.
batch_shape = tf.shape(input)[:-2]
signal_shape = tf.shape(input)[-2:]

# NOTE: Cannot simply concatenate these using + since they are
# tf.Tensors and that would add their values.
input = tf.reshape(input, tf.concat(((-1,), signal_shape), 0))

S = scattering2d(input, self.pad, self.unpad, self.backend, self.J, self.L, self.phi, self.psi,
                 self.max_order, self.out_type)

if self.out_type == "array":
    # Put the original batch dimensions back in front of the last three
    # dimensions of the result.
    scattering_shape = tf.shape(S)[-3:]
    new_shape = tf.concat((batch_shape, scattering_shape), 0)
    S = tf.reshape(S, new_shape)
else:
    # S is indexed as a sequence of mappings here; take the last two
    # dimensions of the first entry's "coef" value.
    scattering_shape = tf.shape(S[0]["coef"])[-2:]