# Get the sets of images and labels for training, validation, and test on MNIST.
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
for i in (1, 2, 3, 4):
start_time = time.time()
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
graph = MNISTGraph(
learning_rate=FLAGS.learning_rate, hidden1=FLAGS.hidden1//2*i,
After Change
Train MNIST for a number of steps.
# Get the sets of images and labels for training, validation, and test on MNIST.
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

# Pick the network geometry to train with. Presumably this helper tries
# several hidden-layer configurations and returns the best one -- confirm
# against its definition.
best_geometry = brute_force_optimal_network_geometry(
    data_sets, FLAGS.training_precision
)
print(best_geometry)

# Element 3 of the result is indexed for the three hidden-layer sizes; the
# meaning of the other elements is not visible from here.
layer_sizes = best_geometry[3]

# Timestamp taken before graph construction; it is not read within this
# excerpt, so it is presumably used for timing further down.
start_time = time.time()

# Build the model inside a fresh default Graph so that its ops do not end up
# in whatever graph was previously the default.
with tf.Graph().as_default():
    graph = MNISTGraph(
        learning_rate=FLAGS.learning_rate,
        hidden1=layer_sizes[0],
        hidden2=layer_sizes[1],
        hidden3=layer_sizes[2],
        batch_size=FLAGS.batch_size,
        train_dir=FLAGS.train_dir,
    )
    # NOTE(review): the original indentation was lost; training is kept inside
    # the graph context on the assumption that train() may still create ops.
    graph.train(data_sets, FLAGS.max_steps, precision=FLAGS.desired_precision)