if hasattr(self.task_parameters, "checkpoint_restore_dir") and self.task_parameters.checkpoint_restore_dir:
if self.task_parameters.framework_type == Frameworks.tensorflow:
self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
elif self.task_parameters.framework_type == Frameworks.mxnet:
// TODO implement checkpoint restore
pass
else:
raise ValueError("Invalid framework {}".format(self.task_parameters.framework_type))
def occasionally_save_checkpoint(self):
// only the chief process saves checkpoints
if self.task_parameters.checkpoint_save_secs \