f = x.shape[-1].value // c
g = f // c
x = tf.reshape(x, tf.concat([tf.shape(x)[:-1],
tf.constant([g, c, c])], axis=0))
x = tf.reduce_sum(x, axis=-2)
x = reshape(x, tf.concat([tf.shape(x)[:-2],
tf.constant([f])], axis=0), name="gconv")
return batch_norm(x)
After Change
start = (i // c) * c * c + i % c
end = start + c * c
kernel[:, :, start:end:c, i] = 1.
x = conv2d_primitive(x, tf.constant(kernel), strides=[1, 1, 1, 1],
padding="VALID", name="gconv")
return batch_norm(x)