y_true = K.clip(y_true, K.epsilon(), 1)
// make sure pred is a probability
y_pred /= K.sum(y_pred, axis=-1, keepdims=True)y_pred = K.clip(y_pred, K.epsilon(), 1)
// Prepare the volumes to operate on
// If we"re doing "hard" Dice, then we will prepare one-hot-based matrices of size
// [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry,
// the entries are either 0 or 1
if self.dice_type == "hard":
// if given predicted probability, transform to "hard max""
if self.input_type == "prob":
if self.approx_hard_max:
y_pred_op = _hard_max(y_pred, axis=-1)
y_true_op = _hard_max(y_true, axis=-1)
else:
y_pred_op = _label_to_one_hot(K.argmax(y_pred, axis=-1), self.nb_labels)
y_true_op = _label_to_one_hot(K.argmax(y_true, axis=-1), self.nb_labels)
// if given predicted label, transform to one hot notation
else:
assert self.input_type == "max_label"
y_pred_op = _label_to_one_hot(y_pred, self.nb_labels)
y_true_op = _label_to_one_hot(y_true, self.nb_labels)
// If we"re doing soft Dice, require prob output, and the data already is as we need it
// [batch_size, nb_voxels, nb_labels]
else:
assert self.input_type == "prob", "cannot do soft dice with max_label input"
y_pred_op = y_pred
y_true_op = y_true
// reshape to [batch_size, nb_voxels, nb_labels]
batch_size = K.shape(y_true)[0]
y_pred_op = K.reshape(y_pred_op, [batch_size, -1, K.shape(y_true)[-1]])
y_true_op = K.reshape(y_true_op, [batch_size, -1, K.shape(y_true)[-1]])
// compute dice for each entry in batch.
// dice will now be [batch_size, nb_labels]
top = 2 * K.sum(y_true_op * y_pred_op, 1)
bottom = K.sum(K.square(y_true_op), 1) + K.sum(K.square(y_pred_op), 1)
// make sure we have no 0s on the bottom. K.epsilon()
bottom = K.maximum(bottom, self.area_reg)
return top / bottom
After Change
// make sure pred is a probability
if self.re_norm:
y_pred = tf.div_no_nan(y_pred, K.sum(y_pred, axis=-1, keepdims=True))
y_pred = K.clip(y_pred, K.epsilon(), 1)
// Prepare the volumes to operate on
// If we"re doing "hard" Dice, then we will prepare one-hot-based matrices of size
// [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry,
// the entries are either 0 or 1
if self.dice_type == "hard":
// if given predicted probability, transform to "hard max""
if self.input_type == "prob":
if self.approx_hard_max:
y_pred_op = _hard_max(y_pred, axis=-1)
y_true_op = _hard_max(y_true, axis=-1)
else:
y_pred_op = _label_to_one_hot(K.argmax(y_pred, axis=-1), self.nb_labels)
y_true_op = _label_to_one_hot(K.argmax(y_true, axis=-1), self.nb_labels)
// if given predicted label, transform to one hot notation
else:
assert self.input_type == "max_label"
y_pred_op = _label_to_one_hot(y_pred, self.nb_labels)
y_true_op = _label_to_one_hot(y_true, self.nb_labels)
// If we"re doing soft Dice, require prob output, and the data already is as we need it
// [batch_size, nb_voxels, nb_labels]
else:
assert self.input_type == "prob", "cannot do soft dice with max_label input"
y_pred_op = y_pred
y_true_op = y_true
// reshape to [batch_size, nb_voxels, nb_labels]
batch_size = K.shape(y_true)[0]
y_pred_op = K.reshape(y_pred_op, [batch_size, -1, K.shape(y_true)[-1]])
y_true_op = K.reshape(y_true_op, [batch_size, -1, K.shape(y_true)[-1]])
// compute dice for each entry in batch.
// dice will now be [batch_size, nb_labels]
top = 2 * K.sum(y_true_op * y_pred_op, 1)
bottom = K.sum(K.square(y_true_op), 1) + K.sum(K.square(y_pred_op), 1)
// make sure we have no 0s on the bottom. K.epsilon()
bottom = K.maximum(bottom, self.area_reg)
return top / bottom