e.train()  # NOTE(review): this line appears truncated in this view (likely the tail of a loop over modules)
# Define an arbitrary reversible function and build the autograd graph for model 1.
# Clone so the block operates on its own copy rather than on X itself.
Xin = X.clone()
# NOTE(review): keep_input=False presumably lets the block release its input after
# the forward pass -- confirm against ReversibleBlock before relying on Xin afterwards.
fn = ReversibleBlock(c1_2, c2_2, keep_input=False, coupling=coupling, adapter=AffineAdapterNaive,
                     implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
Y = fn.forward(Xin)
loss2 = torch.mean(Y)
# Define the reversible function without custom backprop and build the graph for model 2.
# .detach() instead of .data: same history-free view of the clone, but autograd still
# tracks in-place modifications through the version counter.
XX = X.clone().detach()
After Change
e.train()  # NOTE(review): this line appears truncated in this view (likely the tail of a loop over modules)
# Define an arbitrary reversible function and build the autograd graph for model 1.
# requires_grad_() so autograd tracks gradients with respect to this input copy.
Xin = X.clone().requires_grad_()
# NOTE(review): keep_input=False presumably lets the block release its input after
# the forward pass -- confirm against ReversibleBlock before relying on Xin afterwards.
fn = ReversibleBlock(c1_2, c2_2, keep_input=False, coupling=coupling, adapter=AffineAdapterNaive,
                     implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
Y = fn.forward(Xin)
loss2 = torch.mean(Y)
# Define the reversible function without custom backprop and build the graph for model 2.
# detach() drops any history inherited from X, making XX a fresh leaf; requires_grad_()
# then lets XX accumulate its own gradient independently of Xin.
XX = X.clone().detach().requires_grad_()