util.single_node_env()  # set up the environment for a standalone, single-node TensorFlow instance on this executor
# load saved_model using default tag and signature
sess = tf.Session()
tf.saved_model.loader.load(sess, ["serve"], args.export)
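# Optional sketch (not in the original example): tf.saved_model.loader.load() also
# returns the loaded MetaGraphDef, so the exported signature can be inspected
# programmatically instead of hard-coding tensor names further below. The signature
# key "serving_default" is an assumption about how this model was exported; use
# `saved_model_cli show --dir <export_dir> --all` to confirm the actual key.
#
#   meta_graph = tf.saved_model.loader.load(sess, ["serve"], args.export)
#   sig = meta_graph.signature_def["serving_default"]
#   input_name = list(sig.inputs.values())[0].name    # e.g. "Placeholder:0"
#   output_name = list(sig.outputs.values())[0].name  # e.g. "dense_2/Softmax:0"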
# parse function for TFRecords
def parse_tfr(example_proto):
  feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                 "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
  features = tf.parse_single_example(example_proto, feature_def)
  norm = tf.constant(255, dtype=tf.float32, shape=(784,))   # 784 == IMAGE_PIXELS * IMAGE_PIXELS for 28x28 MNIST images
  image = tf.div(tf.to_float(features["image"]), norm)      # scale pixel values to [0, 1]
  label = tf.to_float(features["label"])                    # one-hot label vector of length 10
  return (image, label)
# define a new tf.data.Dataset (for inference)
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
ds = ds.shard(num_workers, worker_num)                       # each worker reads a disjoint subset of the input files
ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
ds = ds.map(parse_tfr).batch(10)
iterator = ds.make_one_shot_iterator()
image_label = iterator.get_next(name="inf_image")            # exposes tensors "inf_image:0" (images) and "inf_image:1" (labels)
# create an output file per Spark worker for the predictions
tf.gfile.MakeDirs(args.output)
output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode="w")
while True:
  try:
    # get images and labels from tf.data.Dataset
    img, lbl = sess.run(["inf_image:0", "inf_image:1"])

    # inference by feeding these images and labels into the input tensors
    # you can view the exported model signatures via:
    #   saved_model_cli show --dir <export_dir> --all
    # note that we feed directly into the graph tensors (bypassing the exported signatures)
    # these tensors will be shown in the "name" field of the signature definitions
    outputs = sess.run(["dense_2/Softmax:0"], feed_dict={"Placeholder:0": img})

    # write the predicted digit (argmax of the softmax output) for each image in the batch
    for p in outputs[0]:
      output_file.write("{}\n".format(np.argmax(p)))
  except tf.errors.OutOfRangeError:
    break
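# Driver-side sketch (assumptions, not shown in the original listing): the code above
# is wrapped in a function `inference(worker_num, num_workers, args)` and `args` has
# already been parsed on the driver. Each Spark task then runs one copy of the
# inference loop, and ds.shard() above keeps the copies from reading the same files.
#
#   from pyspark import SparkConf, SparkContext
#   sc = SparkContext(conf=SparkConf().setAppName("inference"))
#   num_workers = 4  # assumed executor count
#   sc.parallelize(range(num_workers), num_workers) \
#     .foreach(lambda worker_num: inference(worker_num, num_workers, args))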