number of boxes detected for each sample in the batch.
# Keep only the third entry returned by `_get_shape` as `width`; the other
# three entries are discarded.
# NOTE(review): presumably the tensor is rank 4 laid out as
# [batch, height, width, channels] -- confirm against `_get_shape`.
_, _, width, _ = _get_shape(height_width_predictions, 4)
# Combine the row (y) and column (x) indices into a single index per peak.
# NOTE(review): presumably an index into the flattened height*width grid
# with `width` as the row stride -- confirm against the helper.
peak_spatial_indices = flattened_indices_from_row_col_indices(
y_indices, x_indices, width)
# `y_indices` is rebound via `_to_float32` only after it has been consumed by
# the index computation above, so that computation sees the original values.
y_indices = _to_float32(y_indices)
After Change
# tf.gather_nd instead and here we prepare the indices for that.
# Build the index tensor for `tf.gather_nd`: three 1-D components stacked
# along axis 1, giving one (first-dim, y, x) index triple per row.
# NOTE(review): the first component presumably repeats each batch index
# `num_boxes` times -- confirm against `_multi_range`.
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_boxes),
tf.reshape(y_indices, [-1]),
tf.reshape(x_indices, [-1])
], axis=1)
# Gather the height/width predictions at the indexed locations, then regroup
# the flat result to [batch_size, num_boxes, -1].
new_height_width = tf.gather_nd(height_width_predictions, combined_indices)
new_height_width = tf.reshape(new_height_width, [batch_size, num_boxes, -1])
# Same gather and regrouping for the offset predictions, reusing the indices.
new_offsets = tf.gather_nd(offset_predictions, combined_indices)
offsets = tf.reshape(new_offsets, [batch_size, num_boxes, -1])
# Rebind the indices via `_to_float32` only after they have been used to
# build the gather indices above.
y_indices = _to_float32(y_indices)
x_indices = _to_float32(x_indices)