with torch.no_grad():
adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
adj_params = torch.zeros_like(flat_params)
adj_time = torch.tensor(0.)
time_vjps = []
for i in range(T - 1, 0, -1):
ans_i = tuple(ans_[i] for ans_ in ans)
After Change
# Run the augmented system backwards in time.
# TODO: switch this out so the augmented state simply has no adj_params entry (instead of the scalar placeholder substituted below).
if adj_params.numel() == 0:
adj_params = torch.zeros((), dtype=adj_y[0].dtype, device=adj_y[0].device)
aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
aug_ans = odeint(
augmented_dynamics, aug_y0,
t[i - 1:i + 1].flip(0),