r2: idempotent tensor of shape [preds.shape[-1]] holding the R2 value for each target.
update_op: op that updates the per-target R2 values given new data.
preds_split = tf.unstack(preds, axis=-1)
targets_split = tf.unstack(targets, axis=-1)
weights_split = tf.unstack(weights, axis=-1)
r2_metrics = [
r2_metric(p, t, w)
for p, t, w in zip(preds_split, targets_split, weights_split)
]
r2_values = [r[0] for r in r2_metrics]
stacked_r2 = tf.stack(r2_values)
update_ops = tf.group(*[r[1] for r in r2_metrics])
return stacked_r2, update_ops
def r2_averaged_over_all_prediction_tasks(preds, targets, weights):
After Change
r2 = 1. - res_ss / tot_ss
update_op = tf.group(res_ss_update, tot_ss_update)
return r2, update_op
def r2_averaged_over_all_prediction_tasks(preds, targets, weights):
Returns ops for the multi-task R2 statistic, following the tf.metrics API.