# Benchmark: time how long the tf.data pipeline takes to deliver each image
# when it is consumed through a one-shot iterator inside a session.
dataset = dataset.batch(batch_size)  # TODO: consider using tf.contrib.data.map_and_batch
dataset = dataset.prefetch(1)  # prefetch 1 batch
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
# In real use `one_element` would be fed into a network; for this demo we
# simply fetch the data.
# int() keeps range() working on Python 2, where round() returns a float.
n_step = int(round(n_epoch * n_data / batch_size))
# The context manager closes the session even if a run() call raises.
with tf.Session() as sess:
    st = time.time()
    for _ in range(n_step):
        _images, _targets = sess.run(one_element)
    # Elapsed time divided by the total number of images fetched
    # (batch_size images per step, n_step steps).
    print("dataset APIs took %fs for each image" % ((time.time() - st) / batch_size / n_step))  # CPU ~ 100%
def example4():
# After Change
# Benchmark: iterate over the dataset directly and report the time per image.
# NOTE(review): `n_step` and `st` are expected to be initialised before this
# point (a step counter and a time.time() start stamp) -- confirm against the
# preceding lines.
for img, target in dataset:
    # Each iteration is presumably one batch; only the count matters here.
    n_step += 1
# Sanity check on the number of iterations. The comparison is exact, so it
# only holds when n_epoch * n_data / batch_size is a whole number.
assert n_step == n_epoch * n_data / batch_size
print("dataset APIs took %fs for each image" % ((time.time() - st) / batch_size / n_step))  # CPU ~ 100%
def example4():