ctx.num_bwd_passes = num_bwd_passes
with torch.no_grad():
x = input_t.detach() // Makes a detached copy which shares the storage
output = ctx.fn(x)
detached_output = output.detach_() // Detaches y in-place (inbetween computations can now be discarded)
After Change
ctx.num_bwd_passes = num_bwd_passes
ctx.num_inputs = num_inputs
input_t = inputs_and_weights[:num_inputs]
ctx.input_requires_grad = [element.requires_grad for element in input_t]
with torch.no_grad():
// Makes a detached copy which shares the storage