# Compute the PDF and CDF for each weight vector, while ensuring that the CDF
# starts with exactly 0 and ends with exactly 1.
# NOTE(review): assumes weight_sum broadcasts against weights along the last
# axis (e.g. a keepdims sum over it) — confirm against the enclosing function.
pdf = weights / weight_sum
# Accumulate all bins except the last and clamp at 1: rounding in the
# normalization above can push a running sum marginally past 1.
cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
# Pad with an exact 0 in front and an exact 1 at the end, so the CDF always
# spans [0, 1] and has one more entry than there are weights.
cdf = jnp.concatenate(
    [
        jnp.zeros(list(cdf.shape[:-1]) + [1]),
        cdf,
        jnp.ones(list(cdf.shape[:-1]) + [1]),
    ],
    axis=-1,
)
# Take uniform samples
if randomized:
After Change
# Compute the PDF and CDF for each weight vector, while ensuring that the CDF
# starts with exactly 0 and ends with exactly 1.
# NOTE(review): assumes weight_sum broadcasts against weights along the last
# axis (e.g. a keepdims sum over it) — confirm against the enclosing function.
pdf = weights / weight_sum
# Accumulate all bins except the last and clamp at 1: rounding in the
# normalization above can push a running sum marginally past 1.
cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
# Pad with an exact 0 in front and an exact 1 at the end. The pad shape is
# built from cdf.shape explicitly (rather than slicing cdf) so it is still
# width 1 when there is only a single bin and cdf is empty along the last axis.
cdf = jnp.concatenate(
    [
        jnp.zeros(list(cdf.shape[:-1]) + [1]),
        cdf,
        jnp.ones(list(cdf.shape[:-1]) + [1]),
    ],
    axis=-1,
)