def main(argv):  # pylint: disable=unused-argument
    """Validate the command-line flags, build the TF records and optionally upload them.

    Args:
        argv: Positional command-line arguments; not used by this function.

    Raises:
        ValueError: If `gcs_upload` is set but `project` or `gcs_output_path`
            is missing, if `gcs_output_path` does not start with "gs://", or
            if `local_scratch_dir` is not provided.
        AssertionError: If `raw_data_dir` is not provided. The exception type
            is kept as-is so existing callers that catch it keep working.
    """
    logging.set_verbosity(logging.INFO)

    # All GCS-related flags are only required when an upload was requested.
    if FLAGS.gcs_upload:
        if FLAGS.project is None:
            raise ValueError("GCS Project must be provided.")
        if FLAGS.gcs_output_path is None:
            raise ValueError("GCS output path must be provided.")
        if not FLAGS.gcs_output_path.startswith("gs://"):
            raise ValueError("GCS output path must start with gs://")

    # NOTE(review): the scratch directory is not read in this function;
    # presumably the conversion helpers use it - confirm against them.
    if FLAGS.local_scratch_dir is None:
        raise ValueError("Scratch directory path must be provided.")

    # The raw data is never downloaded here: it must already be present
    # locally and be pointed to by `raw_data_dir`.
    raw_data_dir = FLAGS.raw_data_dir
    if raw_data_dir is None:
        raise AssertionError(
            "The ImageNet download path is no longer supported. Please download "
            "the .tar files manually and provide the `raw_data_dir`.")

    # Convert the raw data into tf-records
    training_records, validation_records = convert_to_tf_records(raw_data_dir)

    # Upload to GCS
    if FLAGS.gcs_upload:
        upload_to_gcs(training_records, validation_records)
# Only start the application when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app.run(main)