# Remember how many backward passes this context must support.
ctx.num_bwd_passes = num_bwd_passes

# Run the wrapped function without tracking gradients; the reversible
# backward pass will recompute what it needs instead of storing the graph.
with torch.no_grad():
    x = input_t.detach()  # detached copy that shares storage with input_t
    output = ctx.fn(x)

# Normalize to a tuple so single- and multi-output functions are handled
# uniformly by the code below.
if not isinstance(output, tuple):
    output = (output,)

# Detach each output in-place (in-between computations can now be discarded).
detached_output = tuple(element.detach_() for element in output)

# Store the input node once per expected backward pass so each pass can
# consume its own reference.
ctx.input_t = [input_t] * num_bwd_passes