# Build the reference module in channels-last ("NHWC") layout and run it on a
# random float32 input whose shape is self.INPUT_SHAPE (values in [0, 1)).
conv1 = func(name="NHWC", data_format="NHWC")
x = tf.constant(np.random.random(self.INPUT_SHAPE).astype(np.float32))
o1 = conv1(x)

# We will force both modules to share the same weights by creating
# a custom getter that returns the weights from the first conv module when
# tf.get_variable is called.
custom_getter = {"w": create_custom_field_getter(conv1, "w"),
                 "b": create_custom_field_getter(conv1, "b")}
conv2 = func(name="NCHW", data_format="NCHW", custom_getter=custom_getter)

# Feed the same data to the channels-first module: permute the NHWC input to
# NCHW (perm 0,3,1,2) on the way in, then permute the NCHW output back to NHWC
# (perm 0,2,3,1) so it can be compared element-wise with o1.
x_transpose = tf.transpose(x, perm=(0, 3, 1, 2))
o2 = tf.transpose(conv2(x_transpose), perm=(0, 2, 3, 1))
self.checkEquality(o1, o2)
@parameterized.named_parameters(("WithBias", True), ("WithoutBias", False))
def testConv2DDataFormatsBatchNorm(self, use_bias):
Similar to `testConv2DDataFormats`, but this checks BatchNorm support.