# Model setup: normalize self.models to a list, validate it, and optionally
# move the models onto the GPU before optimizer creation begins.
# NOTE(review): this is the interior of a method whose header is outside this
# view; the unconditional wrap below is presumably guarded by an enclosing
# check (e.g. "only if not already a list") — confirm against the full source.
self.models = [self.models]
# Fail fast if any entry is not a torch.nn.Module; the message includes the
# offending list for debugging.
# NOTE(review): `assert` is removed under `python -O`, so this check does not
# run in optimized mode — consider raising TypeError if it must always hold.
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
# Move each model onto the GPU via Module.cuda() whenever CUDA is available.
# NOTE(review): there is no opt-out here — models go to the GPU whenever one
# is present, regardless of any caller preference for CPU.
if torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
# Log the transition to the next setup step (optimizer creation, per the
# message); the creation itself is outside this view.
logger.debug("Creating optimizer.")
After Change
# Model setup: normalize self.models to a list, validate it, and move the
# models onto the GPU only when GPU use is requested and CUDA is available.
# NOTE(review): this is the interior of a method whose header is outside this
# view; the unconditional wrap below is presumably guarded by an enclosing
# check (e.g. "only if not already a list") — confirm against the full source.
self.models = [self.models]
# Fail fast if any entry is not a torch.nn.Module; the message includes the
# offending list for debugging.
# NOTE(review): `assert` is removed under `python -O`, so this check does not
# run in optimized mode — consider raising TypeError if it must always hold.
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
# Move each model onto the GPU via Module.cuda() only when both conditions
# hold: the caller opted in (self.use_gpu) and CUDA is actually available.
# Otherwise the models are left where they are.
# NOTE(review): self.use_gpu is presumably set from a constructor argument —
# confirm where it is assigned.
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
# Log the transition to the next setup step (optimizer creation, per the
# message); the creation itself is outside this view.
logger.debug("Creating optimizer.")