```
"""
if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
x_shape = (-1,) + int_shape(x)[1:]
y_shape = int_shape(y)
y_permute_dim = list(range(ndim(y)))
y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
xt = tf.reshape(x, [-1, x_shape[-1]])
After Change
"""
if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
x_shape = []
for i, s in zip(int_shape(x), tf.unpack(tf.shape(x))):
if s is None:
x_shape.append(i)
else: