# Gather one sample tensor for every (image, knot set) combination:
# each image in batch_inputs is indexed with each entry of batch_knots,
# images varying slowest. The gathered tensors are then stacked along
# axis 1.
samples = tf.stack(
    [tf.gather_nd(img, knot)
     for img in batch_inputs
     for knot in batch_knots],
    axis=1)
# weighted average over N neighbours
return tf.reduce_sum(
After Change
# NOTE(review): the `if` header guarding this first branch is not visible
# in this excerpt -- confirm its condition against the full file.
# Pair each image with the knot set at the same position in batch_knots.
samples = [tf.gather_nd(img, knot)
for (img, knot) in zip(batch_inputs, batch_knots)]
elif n_coords == 1 and batch_size > 1:
# Reuse the first (presumably the only -- TODO confirm what n_coords
# counts) knot set for every image in the batch.
samples = [tf.gather_nd(img, batch_knots[0])
for img in batch_inputs]
else:
# Any remaining combination is rejected rather than handled.
raise NotImplementedError
# Combine the per-image gathered tensors into one tensor along axis 1.
samples = tf.stack(samples, axis=1)
# weighted average over N neighbours