# Optionally turn pred_b into probabilities with a softmax. The float32
# cast is applied only on this branch; when self._softmax is falsy,
# pred_b keeps whatever dtype it arrived with.
if self._softmax:
    pred_b = tf.cast(pred_b, dtype=tf.float32)
    pred_b = tf.nn.softmax(pred_b)

# Reshape this sample's ground truth (and weight map, if any) so that its
# leading dimensions equal every dimension of pred_b except the last; the
# trailing dimension is inferred by tf.reshape via -1. Note the leading
# (spatial) dimensions are kept as-is, not flattened into a single voxel
# axis.
# NOTE(review): presumably the last axis of pred_b is the class axis and
# ground_truth / weight_map are batch-first and aligned with pred_b via
# b_ind — confirm against the caller.
# NOTE(review): get_shape().as_list() yields None for statically unknown
# dimensions; a None inside ref_shape would make tf.reshape fail, so this
# assumes pred_b's leading dimensions are fully defined — confirm.
spatial_shape = pred_b.get_shape().as_list()[:-1]
ref_shape = spatial_shape + [-1]
ground_truth_b = tf.reshape(ground_truth[b_ind], ref_shape)
# A single-channel ground truth drops its trailing axis, leaving only the
# leading dimensions shared with pred_b.
if ground_truth_b.get_shape().as_list()[-1] == 1:
    ground_truth_b = tf.squeeze(ground_truth_b, axis=-1)

# The optional weight map gets the same reshape/squeeze treatment;
# weight_b is None when no weight map was supplied.
if weight_map is not None:
    weight_b = tf.reshape(weight_map[b_ind], ref_shape)
    if weight_b.get_shape().as_list()[-1] == 1:
        weight_b = tf.squeeze(weight_b, axis=-1)
else:
    weight_b = None