"""Turn a n-D tensor into a 2D tensor where
the first dimension is conserved.
"""
x = tf.reshape(x, [-1, np.prod(x.get_shape()[1:].as_list())])
return x
def expand_dims(x, dim=-1):
"""Add a 1-sized dimension at index "dim".
After Change
def flatten(x):
    """Collapse every dimension of ``x`` into a single 1-D tensor.

    Args:
        x: Value accepted by ``tf.reshape`` (a tensor of any rank).

    Returns:
        A rank-1 tensor holding all elements of ``x``. ``tf.reshape``
        keeps the original element order.
    """
    # A lone -1 in the target shape tells TensorFlow to infer the
    # total element count, producing a flat vector.
    flat_shape = [-1]
    return tf.reshape(x, flat_shape)
def batch_flatten(x):
"""Turn a n-D tensor into a 2D tensor where