k = _UncheckedAssign.apply(k, f0, (..., 0))
for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)):
    ti = t0 + alpha_i * dt
    yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0).type_as(y0)  # tableau is float64, so cast back
    f = func(ti, yi)
    k = _UncheckedAssign.apply(k, f, (..., i + 1))

if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()):
    # This property (true for Dormand-Prince) lets us save a few FLOPs: when it holds, the yi
    # from the final loop iteration is already y1, and this recomputation is skipped.
    yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0).type_as(y0)  # tableau is float64, so cast back

y1 = yi
f1 = k[..., -1]
y1_error = k.matmul(dt * tableau.c_error)
return y1, f1, y1_error, k
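The .type_as(y0) casts above exist because of PyTorch's elementwise type promotion: a float64 tableau coefficient multiplied by a float32 tensor produces a float64 result, so every intermediate would silently widen without the cast back. A minimal standalone sketch of that behavior (the values below are illustrative, not taken from the code):

import torch

beta_i = torch.tensor([0.5, 0.5], dtype=torch.float64)  # tableau coefficients are float64
dt = torch.tensor(0.1, dtype=torch.float32)             # solver state is float32
print((beta_i * dt).dtype)                              # torch.float64 -- promoted, hence the cast back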
After Change
t0 = t0.type_as(y0)
dt = dt.type_as(y0)

# We use an unchecked assign to put data into k without incrementing its _version counter, so that the backward
# doesn't throw an (overzealous) error about in-place correctness. We know that it's actually correct.
k = torch.empty(*f0.shape, len(tableau.alpha) + 1, dtype=y0.dtype, device=y0.device)
k = _UncheckedAssign.apply(k, f0, (..., 0))
for i, (alpha_i, beta_i) in enumerate(zip(tableau.alpha, tableau.beta)):
    ti = t0 + alpha_i * dt
    yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0)
    f = func(ti, yi)
    k = _UncheckedAssign.apply(k, f, (..., i + 1))

if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()):
    # This property (true for Dormand-Prince) lets us save a few FLOPs: when it holds, the yi
    # from the final loop iteration is already y1, and this recomputation is skipped.
    yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0)

y1 = yi
f1 = k[..., -1]
y1_error = k.matmul(dt * tableau.c_error)
return y1, f1, y1_error, k
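_UncheckedAssign itself is not shown in this diff. A minimal sketch of how such a helper can be written, assuming the behavior described in the comment above (a custom autograd Function whose forward writes through .data so the buffer's _version counter is untouched, and whose backward slices the incoming gradient back out):

import torch

class _UncheckedAssign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scratch, value, index):
        ctx.index = index
        scratch.data[index] = value  # writing via .data skips the _version bump
        return scratch

    @staticmethod
    def backward(ctx, grad_scratch):
        # The scratch buffer is freshly allocated, so passing its gradient straight
        # through is safe; the assigned value receives the slice of gradient written into it.
        return grad_scratch, grad_scratch[ctx.index], None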
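The condition guarding the final matmul checks what is sometimes called the first-same-as-last (FSAL) structure: when the last row of beta equals c_sol[:-1] and c_sol[-1] is zero, the yi computed on the final loop iteration is already y1. The standard Dormand-Prince (dopri5) coefficients satisfy this; a quick standalone check using the published solution weights:

import torch

c_sol = torch.tensor([35 / 384, 0., 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.],
                     dtype=torch.float64)
beta_last = torch.tensor([35 / 384, 0., 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
                         dtype=torch.float64)
assert c_sol[-1] == 0 and (c_sol[:-1] == beta_last).all()  # the property holds, so the branch is skipped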