train_image_classifier.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Generic training script that trains a model using a given dataset."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from tensorflow.python.ops import control_flow_ops
  21. from datasets import dataset_factory
  22. from deployment import model_deploy
  23. from nets import nets_factory
  24. from preprocessing import preprocessing_factory
  25. slim = tf.contrib.slim
  26. tf.app.flags.DEFINE_string(
  27. 'master', '', 'The address of the TensorFlow master to use.')
  28. tf.app.flags.DEFINE_string(
  29. 'train_dir', '/tmp/tfmodel/',
  30. 'Directory where checkpoints and event logs are written to.')
  31. tf.app.flags.DEFINE_integer('num_clones', 1,
  32. 'Number of model clones to deploy.')
  33. tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
  34. 'Use CPUs to deploy clones.')
  35. tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
  36. tf.app.flags.DEFINE_integer(
  37. 'num_ps_tasks', 0,
  38. 'The number of parameter servers. If the value is 0, then the parameters '
  39. 'are handled locally by the worker.')
  40. tf.app.flags.DEFINE_integer(
  41. 'num_readers', 4,
  42. 'The number of parallel readers that read data from the dataset.')
  43. tf.app.flags.DEFINE_integer(
  44. 'num_preprocessing_threads', 4,
  45. 'The number of threads used to create the batches.')
  46. tf.app.flags.DEFINE_integer(
  47. 'log_every_n_steps', 10,
  48. 'The frequency with which logs are print.')
  49. tf.app.flags.DEFINE_integer(
  50. 'save_summaries_secs', 600,
  51. 'The frequency with which summaries are saved, in seconds.')
  52. tf.app.flags.DEFINE_integer(
  53. 'save_interval_secs', 600,
  54. 'The frequency with which the model is saved, in seconds.')
  55. tf.app.flags.DEFINE_integer(
  56. 'task', 0, 'Task id of the replica running the training.')
  57. ######################
  58. # Optimization Flags #
  59. ######################
  60. tf.app.flags.DEFINE_float(
  61. 'weight_decay', 0.00004, 'The weight decay on the model weights.')
  62. tf.app.flags.DEFINE_string(
  63. 'optimizer', 'rmsprop',
  64. 'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
  65. '"ftrl", "momentum", "sgd" or "rmsprop".')
  66. tf.app.flags.DEFINE_float(
  67. 'adadelta_rho', 0.95,
  68. 'The decay rate for adadelta.')
  69. tf.app.flags.DEFINE_float(
  70. 'adagrad_initial_accumulator_value', 0.1,
  71. 'Starting value for the AdaGrad accumulators.')
  72. tf.app.flags.DEFINE_float(
  73. 'adam_beta1', 0.9,
  74. 'The exponential decay rate for the 1st moment estimates.')
  75. tf.app.flags.DEFINE_float(
  76. 'adam_beta2', 0.999,
  77. 'The exponential decay rate for the 2nd moment estimates.')
  78. tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
  79. tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
  80. 'The learning rate power.')
  81. tf.app.flags.DEFINE_float(
  82. 'ftrl_initial_accumulator_value', 0.1,
  83. 'Starting value for the FTRL accumulators.')
  84. tf.app.flags.DEFINE_float(
  85. 'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
  86. tf.app.flags.DEFINE_float(
  87. 'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
  88. tf.app.flags.DEFINE_float(
  89. 'momentum', 0.9,
  90. 'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
  91. tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
  92. tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
  93. #######################
  94. # Learning Rate Flags #
  95. #######################
  96. tf.app.flags.DEFINE_string(
  97. 'learning_rate_decay_type',
  98. 'exponential',
  99. 'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
  100. ' or "polynomial"')
  101. tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
  102. tf.app.flags.DEFINE_float(
  103. 'end_learning_rate', 0.0001,
  104. 'The minimal end learning rate used by a polynomial decay learning rate.')
  105. tf.app.flags.DEFINE_float(
  106. 'label_smoothing', 0.0, 'The amount of label smoothing.')
  107. tf.app.flags.DEFINE_float(
  108. 'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
  109. tf.app.flags.DEFINE_float(
  110. 'num_epochs_per_decay', 2.0,
  111. 'Number of epochs after which learning rate decays.')
  112. tf.app.flags.DEFINE_bool(
  113. 'sync_replicas', False,
  114. 'Whether or not to synchronize the replicas during training.')
  115. tf.app.flags.DEFINE_integer(
  116. 'replicas_to_aggregate', 1,
  117. 'The Number of gradients to collect before updating params.')
  118. tf.app.flags.DEFINE_float(
  119. 'moving_average_decay', None,
  120. 'The decay to use for the moving average.'
  121. 'If left as None, then moving averages are not used.')
  122. #######################
  123. # Dataset Flags #
  124. #######################
  125. tf.app.flags.DEFINE_string(
  126. 'dataset_name', 'imagenet', 'The name of the dataset to load.')
  127. tf.app.flags.DEFINE_string(
  128. 'dataset_split_name', 'train', 'The name of the train/test split.')
  129. tf.app.flags.DEFINE_string(
  130. 'dataset_dir', None, 'The directory where the dataset files are stored.')
  131. tf.app.flags.DEFINE_integer(
  132. 'labels_offset', 0,
  133. 'An offset for the labels in the dataset. This flag is primarily used to '
  134. 'evaluate the VGG and ResNet architectures which do not use a background '
  135. 'class for the ImageNet dataset.')
  136. tf.app.flags.DEFINE_string(
  137. 'model_name', 'inception_v3', 'The name of the architecture to train.')
  138. tf.app.flags.DEFINE_string(
  139. 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  140. 'as `None`, then the model_name flag is used.')
  141. tf.app.flags.DEFINE_integer(
  142. 'batch_size', 32, 'The number of samples in each batch.')
  143. tf.app.flags.DEFINE_integer(
  144. 'train_image_size', None, 'Train image size')
  145. tf.app.flags.DEFINE_integer('max_number_of_steps', None,
  146. 'The maximum number of training steps.')
  147. #####################
  148. # Fine-Tuning Flags #
  149. #####################
  150. tf.app.flags.DEFINE_string(
  151. 'checkpoint_path', None,
  152. 'The path to a checkpoint from which to fine-tune.')
  153. tf.app.flags.DEFINE_string(
  154. 'checkpoint_exclude_scopes', None,
  155. 'Comma-separated list of scopes of variables to exclude when restoring '
  156. 'from a checkpoint.')
  157. tf.app.flags.DEFINE_string(
  158. 'trainable_scopes', None,
  159. 'Comma-separated list of scopes to filter the set of variables to train.'
  160. 'By default, None would train all the variables.')
  161. tf.app.flags.DEFINE_boolean(
  162. 'ignore_missing_vars', False,
  163. 'When restoring a checkpoint would ignore missing variables.')
  164. FLAGS = tf.app.flags.FLAGS
  165. def _configure_learning_rate(num_samples_per_epoch, global_step):
  166. """Configures the learning rate.
  167. Args:
  168. num_samples_per_epoch: The number of samples in each epoch of training.
  169. global_step: The global_step tensor.
  170. Returns:
  171. A `Tensor` representing the learning rate.
  172. Raises:
  173. ValueError: if
  174. """
  175. decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
  176. FLAGS.num_epochs_per_decay)
  177. if FLAGS.sync_replicas:
  178. decay_steps /= FLAGS.replicas_to_aggregate
  179. if FLAGS.learning_rate_decay_type == 'exponential':
  180. return tf.train.exponential_decay(FLAGS.learning_rate,
  181. global_step,
  182. decay_steps,
  183. FLAGS.learning_rate_decay_factor,
  184. staircase=True,
  185. name='exponential_decay_learning_rate')
  186. elif FLAGS.learning_rate_decay_type == 'fixed':
  187. return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  188. elif FLAGS.learning_rate_decay_type == 'polynomial':
  189. return tf.train.polynomial_decay(FLAGS.learning_rate,
  190. global_step,
  191. decay_steps,
  192. FLAGS.end_learning_rate,
  193. power=1.0,
  194. cycle=False,
  195. name='polynomial_decay_learning_rate')
  196. else:
  197. raise ValueError('learning_rate_decay_type [%s] was not recognized',
  198. FLAGS.learning_rate_decay_type)
  199. def _configure_optimizer(learning_rate):
  200. """Configures the optimizer used for training.
  201. Args:
  202. learning_rate: A scalar or `Tensor` learning rate.
  203. Returns:
  204. An instance of an optimizer.
  205. Raises:
  206. ValueError: if FLAGS.optimizer is not recognized.
  207. """
  208. if FLAGS.optimizer == 'adadelta':
  209. optimizer = tf.train.AdadeltaOptimizer(
  210. learning_rate,
  211. rho=FLAGS.adadelta_rho,
  212. epsilon=FLAGS.opt_epsilon)
  213. elif FLAGS.optimizer == 'adagrad':
  214. optimizer = tf.train.AdagradOptimizer(
  215. learning_rate,
  216. initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  217. elif FLAGS.optimizer == 'adam':
  218. optimizer = tf.train.AdamOptimizer(
  219. learning_rate,
  220. beta1=FLAGS.adam_beta1,
  221. beta2=FLAGS.adam_beta2,
  222. epsilon=FLAGS.opt_epsilon)
  223. elif FLAGS.optimizer == 'ftrl':
  224. optimizer = tf.train.FtrlOptimizer(
  225. learning_rate,
  226. learning_rate_power=FLAGS.ftrl_learning_rate_power,
  227. initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
  228. l1_regularization_strength=FLAGS.ftrl_l1,
  229. l2_regularization_strength=FLAGS.ftrl_l2)
  230. elif FLAGS.optimizer == 'momentum':
  231. optimizer = tf.train.MomentumOptimizer(
  232. learning_rate,
  233. momentum=FLAGS.momentum,
  234. name='Momentum')
  235. elif FLAGS.optimizer == 'rmsprop':
  236. optimizer = tf.train.RMSPropOptimizer(
  237. learning_rate,
  238. decay=FLAGS.rmsprop_decay,
  239. momentum=FLAGS.rmsprop_momentum,
  240. epsilon=FLAGS.opt_epsilon)
  241. elif FLAGS.optimizer == 'sgd':
  242. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  243. else:
  244. raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
  245. return optimizer
  246. def _add_variables_summaries(learning_rate):
  247. summaries = []
  248. for variable in slim.get_model_variables():
  249. summaries.append(tf.summary.histogram(variable.op.name, variable))
  250. summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
  251. return summaries
  252. def _get_init_fn():
  253. """Returns a function run by the chief worker to warm-start the training.
  254. Note that the init_fn is only run when initializing the model during the very
  255. first global step.
  256. Returns:
  257. An init function run by the supervisor.
  258. """
  259. if FLAGS.checkpoint_path is None:
  260. return None
  261. # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  262. # ignoring the checkpoint anyway.
  263. if tf.train.latest_checkpoint(FLAGS.train_dir):
  264. tf.logging.info(
  265. 'Ignoring --checkpoint_path because a checkpoint already exists in %s'
  266. % FLAGS.train_dir)
  267. return None
  268. exclusions = []
  269. if FLAGS.checkpoint_exclude_scopes:
  270. exclusions = [scope.strip()
  271. for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
  272. # TODO(sguada) variables.filter_variables()
  273. variables_to_restore = []
  274. for var in slim.get_model_variables():
  275. excluded = False
  276. for exclusion in exclusions:
  277. if var.op.name.startswith(exclusion):
  278. excluded = True
  279. break
  280. if not excluded:
  281. variables_to_restore.append(var)
  282. if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
  283. checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  284. else:
  285. checkpoint_path = FLAGS.checkpoint_path
  286. tf.logging.info('Fine-tuning from %s' % checkpoint_path)
  287. return slim.assign_from_checkpoint_fn(
  288. checkpoint_path,
  289. variables_to_restore,
  290. ignore_missing_vars=FLAGS.ignore_missing_vars)
  291. def _get_variables_to_train():
  292. """Returns a list of variables to train.
  293. Returns:
  294. A list of variables to train by the optimizer.
  295. """
  296. if FLAGS.trainable_scopes is None:
  297. return tf.trainable_variables()
  298. else:
  299. scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
  300. variables_to_train = []
  301. for scope in scopes:
  302. variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
  303. variables_to_train.extend(variables)
  304. return variables_to_train
  305. def main(_):
  306. if not FLAGS.dataset_dir:
  307. raise ValueError('You must supply the dataset directory with --dataset_dir')
  308. tf.logging.set_verbosity(tf.logging.INFO)
  309. with tf.Graph().as_default():
  310. #######################
  311. # Config model_deploy #
  312. #######################
  313. deploy_config = model_deploy.DeploymentConfig(
  314. num_clones=FLAGS.num_clones,
  315. clone_on_cpu=FLAGS.clone_on_cpu,
  316. replica_id=FLAGS.task,
  317. num_replicas=FLAGS.worker_replicas,
  318. num_ps_tasks=FLAGS.num_ps_tasks)
  319. # Create global_step
  320. with tf.device(deploy_config.variables_device()):
  321. global_step = slim.create_global_step()
  322. ######################
  323. # Select the dataset #
  324. ######################
  325. dataset = dataset_factory.get_dataset(
  326. FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
  327. ######################
  328. # Select the network #
  329. ######################
  330. network_fn = nets_factory.get_network_fn(
  331. FLAGS.model_name,
  332. num_classes=(dataset.num_classes - FLAGS.labels_offset),
  333. weight_decay=FLAGS.weight_decay,
  334. is_training=True)
  335. #####################################
  336. # Select the preprocessing function #
  337. #####################################
  338. preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
  339. image_preprocessing_fn = preprocessing_factory.get_preprocessing(
  340. preprocessing_name,
  341. is_training=True)
  342. ##############################################################
  343. # Create a dataset provider that loads data from the dataset #
  344. ##############################################################
  345. with tf.device(deploy_config.inputs_device()):
  346. provider = slim.dataset_data_provider.DatasetDataProvider(
  347. dataset,
  348. num_readers=FLAGS.num_readers,
  349. common_queue_capacity=20 * FLAGS.batch_size,
  350. common_queue_min=10 * FLAGS.batch_size)
  351. [image, label] = provider.get(['image', 'label'])
  352. label -= FLAGS.labels_offset
  353. train_image_size = FLAGS.train_image_size or network_fn.default_image_size
  354. image = image_preprocessing_fn(image, train_image_size, train_image_size)
  355. images, labels = tf.train.batch(
  356. [image, label],
  357. batch_size=FLAGS.batch_size,
  358. num_threads=FLAGS.num_preprocessing_threads,
  359. capacity=5 * FLAGS.batch_size)
  360. labels = slim.one_hot_encoding(
  361. labels, dataset.num_classes - FLAGS.labels_offset)
  362. batch_queue = slim.prefetch_queue.prefetch_queue(
  363. [images, labels], capacity=2 * deploy_config.num_clones)
  364. ####################
  365. # Define the model #
  366. ####################
  367. def clone_fn(batch_queue):
  368. """Allows data parallelism by creating multiple clones of network_fn."""
  369. images, labels = batch_queue.dequeue()
  370. logits, end_points = network_fn(images)
  371. #############################
  372. # Specify the loss function #
  373. #############################
  374. if 'AuxLogits' in end_points:
  375. tf.losses.softmax_cross_entropy(
  376. logits=end_points['AuxLogits'], onehot_labels=labels,
  377. label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
  378. tf.losses.softmax_cross_entropy(
  379. logits=logits, onehot_labels=labels,
  380. label_smoothing=FLAGS.label_smoothing, weights=1.0)
  381. return end_points
  382. # Gather initial summaries.
  383. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
  384. clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
  385. first_clone_scope = deploy_config.clone_scope(0)
  386. # Gather update_ops from the first clone. These contain, for example,
  387. # the updates for the batch_norm variables created by network_fn.
  388. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
  389. # Add summaries for end_points.
  390. end_points = clones[0].outputs
  391. for end_point in end_points:
  392. x = end_points[end_point]
  393. summaries.add(tf.summary.histogram('activations/' + end_point, x))
  394. summaries.add(tf.summary.scalar('sparsity/' + end_point,
  395. tf.nn.zero_fraction(x)))
  396. # Add summaries for losses.
  397. for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
  398. summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
  399. # Add summaries for variables.
  400. for variable in slim.get_model_variables():
  401. summaries.add(tf.summary.histogram(variable.op.name, variable))
  402. #################################
  403. # Configure the moving averages #
  404. #################################
  405. if FLAGS.moving_average_decay:
  406. moving_average_variables = slim.get_model_variables()
  407. variable_averages = tf.train.ExponentialMovingAverage(
  408. FLAGS.moving_average_decay, global_step)
  409. else:
  410. moving_average_variables, variable_averages = None, None
  411. #########################################
  412. # Configure the optimization procedure. #
  413. #########################################
  414. with tf.device(deploy_config.optimizer_device()):
  415. learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
  416. optimizer = _configure_optimizer(learning_rate)
  417. summaries.add(tf.summary.scalar('learning_rate', learning_rate))
  418. if FLAGS.sync_replicas:
  419. # If sync_replicas is enabled, the averaging will be done in the chief
  420. # queue runner.
  421. optimizer = tf.train.SyncReplicasOptimizer(
  422. opt=optimizer,
  423. replicas_to_aggregate=FLAGS.replicas_to_aggregate,
  424. variable_averages=variable_averages,
  425. variables_to_average=moving_average_variables,
  426. replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
  427. total_num_replicas=FLAGS.worker_replicas)
  428. elif FLAGS.moving_average_decay:
  429. # Update ops executed locally by trainer.
  430. update_ops.append(variable_averages.apply(moving_average_variables))
  431. # Variables to train.
  432. variables_to_train = _get_variables_to_train()
  433. # and returns a train_tensor and summary_op
  434. total_loss, clones_gradients = model_deploy.optimize_clones(
  435. clones,
  436. optimizer,
  437. var_list=variables_to_train)
  438. # Add total_loss to summary.
  439. summaries.add(tf.summary.scalar('total_loss', total_loss))
  440. # Create gradient updates.
  441. grad_updates = optimizer.apply_gradients(clones_gradients,
  442. global_step=global_step)
  443. update_ops.append(grad_updates)
  444. update_op = tf.group(*update_ops)
  445. train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
  446. name='train_op')
  447. # Add the summaries from the first clone. These contain the summaries
  448. # created by model_fn and either optimize_clones() or _gather_clone_loss().
  449. summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
  450. first_clone_scope))
  451. # Merge all summaries together.
  452. summary_op = tf.summary.merge(list(summaries), name='summary_op')
  453. ###########################
  454. # Kicks off the training. #
  455. ###########################
  456. slim.learning.train(
  457. train_tensor,
  458. logdir=FLAGS.train_dir,
  459. master=FLAGS.master,
  460. is_chief=(FLAGS.task == 0),
  461. init_fn=_get_init_fn(),
  462. summary_op=summary_op,
  463. number_of_steps=FLAGS.max_number_of_steps,
  464. log_every_n_steps=FLAGS.log_every_n_steps,
  465. save_summaries_secs=FLAGS.save_summaries_secs,
  466. save_interval_secs=FLAGS.save_interval_secs,
  467. sync_optimizer=optimizer if FLAGS.sync_replicas else None)
  468. if __name__ == '__main__':
  469. tf.app.run()