convolutional.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Simple, end-to-end, LeNet-5-like convolutional MNIST model example.
  16. This should achieve a test error of 0.7%. Please keep this model as simple and
  17. linear as possible, it is meant as a tutorial for simple convolutional models.
  18. Run with --self_test on the command line to execute a short self-test.
  19. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys
import time

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
WORK_DIRECTORY = 'data'
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10
VALIDATION_SIZE = 5000  # Size of the validation set.
SEED = 66478  # Set to None for random seed.
BATCH_SIZE = 64
NUM_EPOCHS = 10
EVAL_BATCH_SIZE = 64
EVAL_FREQUENCY = 100  # Number of steps between evaluations.

FLAGS = None


def data_type():
  """Return the type of the activations, weights, and placeholder variables."""
  if FLAGS.use_fp16:
    return tf.float16
  else:
    return tf.float32


def maybe_download(filename):
  """Download the data from Yann's website, unless it's already here."""
  if not tf.gfile.Exists(WORK_DIRECTORY):
    tf.gfile.MakeDirs(WORK_DIRECTORY)
  filepath = os.path.join(WORK_DIRECTORY, filename)
  if not tf.gfile.Exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    with tf.gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath


def extract_data(filename, num_images):
  """Extract the images into a 4D tensor [image index, y, x, channels].

  Values are rescaled from [0, 255] down to [-0.5, 0.5].
  """
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
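    # The IDX image file starts with a 16-byte header (magic number, image
    # count, row count, column count), which we skip before reading raw pixels.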
    bytestream.read(16)
    buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
    data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
    data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
    data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
    return data


def extract_labels(filename, num_images):
  """Extract the labels into a vector of int64 label IDs."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
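    # The IDX label file starts with an 8-byte header (magic number and item
    # count); each label after that is a single byte.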
    bytestream.read(8)
    buf = bytestream.read(1 * num_images)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
  return labels


def fake_data(num_images):
  """Generate a fake dataset that matches the dimensions of MNIST."""
  data = numpy.ndarray(
      shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
      dtype=numpy.float32)
  labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64)
  for image in xrange(num_images):
    label = image % 2
    data[image, :, :, 0] = label - 0.5
    labels[image] = label
  return data, labels


def error_rate(predictions, labels):
  """Return the error rate based on dense predictions and sparse labels."""
  return 100.0 - (
      100.0 *
      numpy.sum(numpy.argmax(predictions, 1) == labels) /
      predictions.shape[0])


def main(_):
  if FLAGS.self_test:
    print('Running self-test.')
    train_data, train_labels = fake_data(256)
    validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
    test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
    num_epochs = 1
  else:
    # Get the data.
    train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
    train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
    test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
    test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

    # Extract it into numpy arrays.
    train_data = extract_data(train_data_filename, 60000)
    train_labels = extract_labels(train_labels_filename, 60000)
    test_data = extract_data(test_data_filename, 10000)
    test_labels = extract_labels(test_labels_filename, 10000)

    # Generate a validation set.
    validation_data = train_data[:VALIDATION_SIZE, ...]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_data = train_data[VALIDATION_SIZE:, ...]
    train_labels = train_labels[VALIDATION_SIZE:]
    num_epochs = NUM_EPOCHS

  train_size = train_labels.shape[0]

  # This is where training samples and labels are fed to the graph.
  # These placeholder nodes will be fed a batch of training data at each
  # training step using the {feed_dict} argument to the Run() call below.
  train_data_node = tf.placeholder(
      data_type(),
      shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
  train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
  eval_data = tf.placeholder(
      data_type(),
      shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))

  # The variables below hold all the trainable weights. They are passed an
  # initial value which will be assigned when we call:
  # {tf.global_variables_initializer().run()}
  conv1_weights = tf.Variable(
      tf.truncated_normal([5, 5, NUM_CHANNELS, 32],  # 5x5 filter, depth 32.
                          stddev=0.1,
                          seed=SEED, dtype=data_type()))
  conv1_biases = tf.Variable(tf.zeros([32], dtype=data_type()))
  conv2_weights = tf.Variable(tf.truncated_normal(
      [5, 5, 32, 64], stddev=0.1,
      seed=SEED, dtype=data_type()))
  conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=data_type()))
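  # The two 2x2 max-pooling layers below shrink each spatial dimension by a
  # factor of 4 (28 -> 14 -> 7), so the flattened input to the first fully
  # connected layer has IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64 = 7 * 7 * 64
  # values per image.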
  fc1_weights = tf.Variable(  # fully connected, depth 512.
      tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
                          stddev=0.1,
                          seed=SEED,
                          dtype=data_type()))
  fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=data_type()))
  fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS],
                                                stddev=0.1,
                                                seed=SEED,
                                                dtype=data_type()))
  fc2_biases = tf.Variable(tf.constant(
      0.1, shape=[NUM_LABELS], dtype=data_type()))

  # We will replicate the model structure for the training subgraph, as well
  # as the evaluation subgraphs, while sharing the trainable parameters.
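  # The resulting network is:
  #   conv 5x5 depth 32 -> relu -> max pool 2x2 ->
  #   conv 5x5 depth 64 -> relu -> max pool 2x2 ->
  #   fully connected 512 (with dropout while training) ->
  #   fully connected NUM_LABELS.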
  def model(data, train=False):
    """The Model definition."""
    # 2D convolution, with 'SAME' padding (i.e. the output feature map has
    # the same size as the input). Note that {strides} is a 4D array whose
    # shape matches the data layout: [image index, y, x, depth].
    conv = tf.nn.conv2d(data,
                        conv1_weights,
                        strides=[1, 1, 1, 1],
                        padding='SAME')
    # Bias and rectified linear non-linearity.
    relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
    # Max pooling. The kernel size spec {ksize} also follows the layout of
    # the data. Here we have a pooling window of 2, and a stride of 2.
    pool = tf.nn.max_pool(relu,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')
    conv = tf.nn.conv2d(pool,
                        conv2_weights,
                        strides=[1, 1, 1, 1],
                        padding='SAME')
    relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
    pool = tf.nn.max_pool(relu,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')
    # Reshape the feature map cuboid into a 2D matrix to feed it to the
    # fully connected layers.
    pool_shape = pool.get_shape().as_list()
    reshape = tf.reshape(
        pool,
        [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
    # Fully connected layer. Note that the '+' operation automatically
    # broadcasts the biases.
    hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
    # Add a 50% dropout during training only. Dropout also scales
    # activations such that no rescaling is needed at evaluation time.
    if train:
      hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
    return tf.matmul(hidden, fc2_weights) + fc2_biases

  # Training computation: logits + cross-entropy loss.
  logits = model(train_data_node, True)
  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=train_labels_node, logits=logits))
  # L2 regularization for the fully connected parameters.
  regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
                  tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
  # Add the regularization term to the loss.
  loss += 5e-4 * regularizers

  # Optimizer: set up a variable that's incremented once per batch and
  # controls the learning rate decay.
  batch = tf.Variable(0, dtype=data_type())
  # Decay once per epoch, using an exponential schedule starting at 0.01.
  learning_rate = tf.train.exponential_decay(
      0.01,                # Base learning rate.
      batch * BATCH_SIZE,  # Current index into the dataset.
      train_size,          # Decay step.
      0.95,                # Decay rate.
      staircase=True)
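  # With staircase=True this evaluates to
  #   0.01 * 0.95 ** floor(batch * BATCH_SIZE / train_size),
  # i.e. the learning rate drops by 5% after each full pass over the data.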
  # Use simple momentum for the optimization.
  optimizer = tf.train.MomentumOptimizer(learning_rate,
                                         0.9).minimize(loss,
                                                       global_step=batch)

  # Predictions for the current training minibatch.
  train_prediction = tf.nn.softmax(logits)

  # Predictions for the test and validation, which we'll compute less often.
  eval_prediction = tf.nn.softmax(model(eval_data))

  # Small utility function to evaluate a dataset by feeding batches of data to
  # {eval_data} and pulling the results from {eval_predictions}.
  # Saves memory and enables this to run on smaller GPUs.
  def eval_in_batches(data, sess):
    """Get all predictions for a dataset by running it in small batches."""
    size = data.shape[0]
    if size < EVAL_BATCH_SIZE:
      raise ValueError("batch size for evals larger than dataset: %d" % size)
    predictions = numpy.ndarray(shape=(size, NUM_LABELS), dtype=numpy.float32)
    for begin in xrange(0, size, EVAL_BATCH_SIZE):
      end = begin + EVAL_BATCH_SIZE
      if end <= size:
        predictions[begin:end, :] = sess.run(
            eval_prediction,
            feed_dict={eval_data: data[begin:end, ...]})
      else:
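        # The last batch would run past the end of the data, so re-feed the
        # final EVAL_BATCH_SIZE examples and keep only the predictions for the
        # size - begin examples not yet filled in (the negative index
        # begin - size selects exactly those trailing rows).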
        batch_predictions = sess.run(
            eval_prediction,
            feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, ...]})
        predictions[begin:, :] = batch_predictions[begin - size:, :]
    return predictions

  # Create a local session to run the training.
  start_time = time.time()
  with tf.Session() as sess:
    # Run all the initializers to prepare the trainable parameters.
    tf.global_variables_initializer().run()
    print('Initialized!')
    # Loop through training steps.
    for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
      # Compute the offset of the current minibatch in the data.
      # Note that we could use better randomization across epochs.
      offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
      batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
      batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
      # This dictionary maps the batch data (as a numpy array) to the
      # node in the graph it should be fed to.
      feed_dict = {train_data_node: batch_data,
                   train_labels_node: batch_labels}
      # Run the optimizer to update the weights.
      sess.run(optimizer, feed_dict=feed_dict)
      # Print some extra information once we reach the evaluation frequency.
      if step % EVAL_FREQUENCY == 0:
        # Fetch some extra nodes' data.
        l, lr, predictions = sess.run([loss, learning_rate, train_prediction],
                                      feed_dict=feed_dict)
        elapsed_time = time.time() - start_time
        start_time = time.time()
        print('Step %d (epoch %.2f), %.1f ms' %
              (step, float(step) * BATCH_SIZE / train_size,
               1000 * elapsed_time / EVAL_FREQUENCY))
        print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
        print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels))
        print('Validation error: %.1f%%' % error_rate(
            eval_in_batches(validation_data, sess), validation_labels))
        sys.stdout.flush()
    # Finally print the result!
    test_error = error_rate(eval_in_batches(test_data, sess), test_labels)
    print('Test error: %.1f%%' % test_error)
    if FLAGS.self_test:
      print('test_error', test_error)
      assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
          test_error,)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--use_fp16',
      default=False,
      help='Use half floats instead of full floats if True.',
      action='store_true')
  parser.add_argument(
      '--self_test',
      default=False,
      action='store_true',
      help='True if running a self test.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)