// AutoGrad
with ag.record():
    # NOTE(review): indentation was lost in this listing; both statements are
    # assumed to belong inside the recording scope -- confirm against the file.
    # Forward pass: apply `net` to every entry of `data`, in order.
    output = list(map(net, data))
    # One loss per (prediction, label) pair; zip stops at the shorter input.
    loss = [
        loss_fn(prediction, target)
        for prediction, target in zip(output, label)
    ]
// Backpropagation
After Change
# Forward pass: collect one prediction per entry in `data`.
# NOTE(review): the earlier variant of this code runs under `ag.record()`;
# confirm this loop is also executed inside a recording scope.
output = []
for X in data:
    # Merge the two leading axes of X into a single axis and keep the
    # remaining axes unchanged (assumes X has at least two axes -- TODO confirm).
    X = X.reshape((-1,) + X.shape[2:])
    pred = net(X)
    output.append(pred)
# One loss per (prediction, label) pair; zip stops at the shorter input.
loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
// Backpropagation