# Keep the input channel count on the instance; the weight shapes below
# are derived from it.
self._input_channels = input_channels
# Raise a TypeError naming the offending dtype when inputs.dtype is not
# compatible with tf.float32.
if not tf.float32.is_compatible_with(inputs.dtype):
raise TypeError("Input must have dtype tf.float32, but dtype was " +
inputs.dtype.name)
# Depthwise weight shape: (kernel_shape[0], kernel_shape[1],
# input channels, channel multiplier).
depthwise_weight_shape = (self._kernel_shape[0], self._kernel_shape[1],
self._input_channels, self._channel_multiplier)
# Pointwise weight shape: leading dimensions are 1 and 1, the input size is
# channel_multiplier * input_channels, and the last entry is
# self._output_channels.
pointwise_input_size = self._channel_multiplier * self._input_channels
pointwise_weight_shape = (1, 1, pointwise_input_size, self._output_channels)
# NOTE(review): bias_shape is not used in the visible lines; presumably it
# is consumed further down -- confirm.
bias_shape = (self._output_channels,)
# Install a default "w_dw" initializer only when that key is absent from
# self._initializers. fan_in_shape is the first two entries of the
# depthwise shape (the two kernel_shape values).
# NOTE(review): indentation is missing in this excerpt; presumably both
# statements following each `if` belong to its body -- confirm.
if "w_dw" not in self._initializers:
fan_in_shape = depthwise_weight_shape[:2]
self._initializers["w_dw"] = create_weight_initializer(fan_in_shape)
# Likewise for "w_pw": fan_in_shape is the first three entries of the
# pointwise shape, i.e. (1, 1, pointwise_input_size).
if "w_pw" not in self._initializers:
fan_in_shape = pointwise_weight_shape[:3]
self._initializers["w_pw"] = create_weight_initializer(fan_in_shape)
After Change
# Keep the input channel count on the instance; the weight shapes below
# are derived from it.
self._input_channels = input_channels
# NOTE(review): _verify_inputs_dtype is defined outside this excerpt;
# presumably it validates inputs.dtype and raises on unsupported dtypes --
# confirm.
_verify_inputs_dtype(inputs)
# Depthwise weight shape: (kernel_shape[0], kernel_shape[1],
# input channels, channel multiplier).
depthwise_weight_shape = (self._kernel_shape[0], self._kernel_shape[1],
self._input_channels, self._channel_multiplier)
# Pointwise weight shape: leading dimensions are 1 and 1, the input size is
# channel_multiplier * input_channels, and the last entry is
# self._output_channels.
pointwise_input_size = self._channel_multiplier * self._input_channels
pointwise_weight_shape = (1, 1, pointwise_input_size, self._output_channels)
bias_shape = (self._output_channels,)
# Install default initializers only for keys absent from
# self._initializers; each default is built with dtype=inputs.dtype.
# NOTE(review): indentation is missing in this excerpt; presumably the
# statements following each `if` belong to its body -- confirm.
# "w_dw": fan_in_shape is the first two entries of the depthwise shape.
if "w_dw" not in self._initializers:
fan_in_shape = depthwise_weight_shape[:2]
self._initializers["w_dw"] = create_weight_initializer(fan_in_shape,
dtype=inputs.dtype)
# "w_pw": fan_in_shape is the first three entries of the pointwise shape,
# i.e. (1, 1, pointwise_input_size).
if "w_pw" not in self._initializers:
fan_in_shape = pointwise_weight_shape[:3]
self._initializers["w_pw"] = create_weight_initializer(fan_in_shape,
dtype=inputs.dtype)
# "b": a default is only installed when self._use_bias is truthy as well.
if "b" not in self._initializers and self._use_bias:
self._initializers["b"] = create_bias_initializer(bias_shape,
dtype=inputs.dtype)
# Obtain the "w_dw" variable through tf.get_variable using the depthwise
# shape, the inputs' dtype and the initializer stored under the same key.
# The partitioner and regularizer are looked up by key and fall back to
# None when no entry exists.
self._w_dw = tf.get_variable(
"w_dw",
shape=depthwise_weight_shape,
dtype=inputs.dtype,
initializer=self._initializers["w_dw"],
partitioner=self._partitioners.get("w_dw", None),
regularizer=self._regularizers.get("w_dw", None))
# Same pattern for the "w_pw" variable, using the pointwise shape.
self._w_pw = tf.get_variable(
"w_pw",
shape=pointwise_weight_shape,
dtype=inputs.dtype,
initializer=self._initializers["w_pw"],
partitioner=self._partitioners.get("w_pw", None),
regularizer=self._regularizers.get("w_pw", None))