            # Copy the weights to CPU rather than moving them off the GPU,
            # so that training can resume while intermediate checkpoints
            # are being saved.
            cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}]
        return cpu_state_dicts

    def _set_model_state_dicts(self, model_state_dicts):
        for model, model_state_dict in zip(self.models, model_state_dicts):
            model.module.load_state_dict(model_state_dict)
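The pattern in the snippet above — cloning each state-dict tensor onto the CPU so the GPU weights stay in place — can be shown in isolation. Below is a minimal, self-contained sketch; the cpu_copy_state_dicts helper and the toy nn.Linear models are illustrative assumptions, not part of the original code.

import torch
import torch.nn as nn

def cpu_copy_state_dicts(models):
    # Clone every tensor onto the CPU; .cpu() returns a copy for GPU
    # tensors, so the originals stay on the device and training can
    # continue while the copies are written out.
    return [{k: v.cpu() for k, v in m.state_dict().items()} for m in models]

models = [nn.Linear(4, 2), nn.Linear(2, 1)]
torch.save(cpu_copy_state_dicts(models), "checkpoint.pt")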
After Change
        # Taking state_dict() from model.module is needed because the
        # models are wrapped in PyTorch DistributedDataParallel.
        return [model.module.state_dict() for model in self.models]

    def _set_model_state_dicts(self, model_state_dicts):
        for model, model_state_dict in zip(self.models, model_state_dicts):
            model.module.load_state_dict(model_state_dict)
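For context on why .module appears here: DistributedDataParallel wraps the original network, and the wrapper's own state_dict() prefixes every key with "module.", so checkpoints are usually taken from the wrapped model instead. The single-process sketch below illustrates this; the gloo process group and the toy nn.Linear model are assumptions for illustration only.

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# A one-process group is enough to construct a DDP wrapper for this demo.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(4, 2))

# The underlying network lives on ddp_model.module; saving and loading
# through it keeps checkpoint keys free of the "module." prefix.
state = ddp_model.module.state_dict()
ddp_model.module.load_state_dict(state)

dist.destroy_process_group()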