def _interp_fit(self, y0, y1, k, dt):
    """Fit an interpolating polynomial to the results of a Runge-Kutta step.

    Builds the interior point ``y_mid`` from the stage values ``k`` and the
    weights ``self.mid``. It then delegates to the module-level
    ``_interp_fit`` helper, passing ``y0``, ``y1``, ``y_mid``, the first and
    last entries of ``k`` along its final axis, and ``dt``.

    Args:
        y0: state at the start of the step (presumably; the output of
            ``view_as(y0)`` shows ``y_mid`` takes y0's shape).
        y1: state at the end of the step (presumably); forwarded unchanged.
        k: stage values, stacked along the last axis (the code indexes
            ``k[..., 0]`` and ``k[..., -1]`` and contracts that axis via
            ``matmul``).
        dt: step size as a tensor; cast to ``y0``'s dtype before use.

    Returns:
        Whatever the module-level ``_interp_fit`` helper returns.
    """
    # Bring the step size to the state's precision. Everything handed to the
    # module-level helper then shares y0's dtype.
    dt = dt.type_as(y0)
    # self.mid is float64 (per the original note). matmul does not
    # type-promote, so float32 stages times float64 weights would raise.
    # Cast the weights to k's dtype *before* the matmul, not after.
    weights = (dt * self.mid).type_as(k)
    y_mid = y0 + k.matmul(weights).view_as(y0).type_as(y0)
    f0 = k[..., 0]
    f1 = k[..., -1]
    return _interp_fit(y0, y1, y_mid, f0, f1, dt)
After Change
def _interp_fit(self, y0, y1, k, dt):
    """Fit an interpolating polynomial to the results of a Runge-Kutta step.

    Args:
        y0: state at the start of the step (presumably). The interior point
            is reshaped to match it.
        y1: state at the end of the step (presumably); passed through as is.
        k: stage values stacked along the last axis. Its first and last
            slices on that axis are forwarded as the endpoint terms.
        dt: step size tensor; converted to ``y0``'s dtype.

    Returns:
        The result of the module-level ``_interp_fit`` helper, called with
        the endpoint states, the interior point, the two endpoint slices of
        ``k`` and the converted step size.
    """
    # Cast once so the step size and everything derived from it share y0's dtype.
    step = dt.type_as(y0)
    # NOTE(review): assumes self.mid was already converted to y0's dtype
    # elsewhere; matmul would raise on a dtype mismatch with k.
    midpoint_weights = step * self.mid
    y_mid = y0 + k.matmul(midpoint_weights).view_as(y0)
    return _interp_fit(y0, y1, y_mid, k[..., 0], k[..., -1], step)