loss = closure()
for group in self.param_groups:
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]
lr = group["lr"]
for p in group["params"]:
if p.grad is not None:
params_with_grad.append(p)
d_p_list.append(p.grad)
state = self.state[p]
if "momentum_buffer" not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state["momentum_buffer"])
F.sgd(params_with_grad,
d_p_list,
momentum_buffer_list,
weight_decay,
momentum,
lr,
dampening,
nesterov)
// update momentum_buffers in state
for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
state = self.state[p]
state["momentum_buffer"] = momentum_buffer
return loss