tensor_input = True
y0 = (y0,)
func = TupleFunc(func)
params = tuple(func.parameters())
if adjoint_buffers:
params = params + tuple(buffer for buffer in func.buffers() if buffer.requires_grad)
n_tensors = len(y0)
ys = OdeintAdjointMethod.apply(func, t, rtol, atol, method, options, adjoint_rtol, adjoint_atol, adjoint_method,
adjoint_options, n_tensors, *params, *y0)
if tensor_input:
ys = ys[0]
return ys
After Change
adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)
if not tensor_input:
solution = _flat_to_shape(solution, tuple([len(t), *shape] for shape in shapes))
return solution