def flatten(x):
    """Flatten ``x`` into a single dimension.

    Delegates to ``T.flatten`` (Theano's default ``ndim=1``). When ``x``
    carries a Keras static shape in ``_keras_shape``, the result is annotated
    with the matching 1-D shape so downstream Keras code can keep inferring
    shapes.

    # Arguments
        x: tensor to flatten.

    # Returns
        The flattened tensor. If ``x`` has ``_keras_shape``, the result has
        ``_keras_shape == (total_size,)``. It is ``(None,)`` when any input
        dimension is unknown (``None``).
    """
    y = T.flatten(x)
    if hasattr(x, '_keras_shape'):
        shape = x._keras_shape
        if any(dim is None for dim in shape):
            # A single unknown dimension (e.g. a batch axis) makes the total
            # element count unknown; multiplying by None would raise TypeError.
            y._keras_shape = (None,)
        else:
            # Plain integer product: yields an int (1 for a scalar shape ())
            # rather than a NumPy scalar / float.
            size = 1
            for dim in shape:
                size *= dim
            y._keras_shape = (size,)
    return y
def batch_flatten(x):
Turn a n-D tensor into a 2D tensor where
After Change
def flatten(x):
    """Flatten ``x`` into a single dimension.

    Delegates to ``T.flatten`` (Theano's default ``ndim=1``). When ``x``
    carries a Keras static shape in ``_keras_shape``, the result is annotated
    with the matching 1-D shape so downstream Keras code can keep inferring
    shapes.

    # Arguments
        x: tensor to flatten.

    # Returns
        The flattened tensor. If ``x`` has ``_keras_shape``, the result has
        ``_keras_shape == (total_size,)``. It is ``(None,)`` when any input
        dimension is unknown (``None``).
    """
    y = T.flatten(x)
    if hasattr(x, '_keras_shape'):
        shape = x._keras_shape
        if any(dim is None for dim in shape):
            # np.prod over a shape containing None raises TypeError; an
            # unknown dimension (e.g. a batch axis) makes the size unknown.
            y._keras_shape = (None,)
        else:
            # Plain integer product: np.prod(()) would give the float 1.0 and
            # otherwise a NumPy scalar rather than a Python int.
            size = 1
            for dim in shape:
                size *= dim
            y._keras_shape = (size,)
    return y
def batch_flatten(x):
Turn a n-D tensor into a 2D tensor where