"For more information, see ""https://github.com/pytorch/examples/issues/467."))
if not (callable(model_creator) and callable(optimizer_creator)):
    raise ValueError(
        "Must provide a callable model_creator and optimizer_creator.")

if num_replicas is not None:
    raise DeprecationWarning(
        "num_replicas is deprecated. Use num_workers instead.")

if batch_size is not None:
    raise DeprecationWarning(
        "batch_size is deprecated. Use config={'batch_size': N} to "
        "specify a batch size for each worker or "
        "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
        "batch size to be used across all workers.")

if data_loader_args:
    raise ValueError(
        "data_loader_args is deprecated. You can return a "
        "torch.utils.data.DataLoader in data_creator. Ray will "
        "automatically set a DistributedSampler if a DataLoader is "
        "returned and num_workers > 1.")
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.data_creator = data_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls
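
# Without a loss_creator, a custom training operator must supply the loss.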
if not training_operator_cls and not loss_creator:
raise ValueError("If a loss_creator is not provided, you must ""provide a custom training operator.") self.initialization_hook = initialization_hook
self.config = {} if config is None else config
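# "auto" enables GPU use only if CUDA is available on this machine.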
if use_gpu == "auto":
use_gpu = torch.cuda.is_available()