
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a given model on a specified dataset."""
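
# Example invocation (illustrative only; the directory values below are
# hypothetical, and the dataset/model names must be registered in
# dataset_factory and model_factory respectively):
#
#   python train.py \
#     --train_dir=/tmp/tfmodel \
#     --dataset_name=imagenet \
#     --dataset_split_name=train \
#     --dataset_dir=/path/to/dataset \
#     --model_name=inception_v3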

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops
from slim.datasets import dataset_factory
from slim.models import model_deploy
from slim.models import model_factory
from slim.models import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')

tf.app.flags.DEFINE_integer('num_clones', 1,
                            'Number of model clones to deploy.')

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

tf.app.flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 5,
    'The frequency with which logs are printed.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'task', 0, 'Task id of the replica running the training.')

######################
# Optimization Flags #
######################

tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')

tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')

tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')

tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')

tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')

tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')

tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')

tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')

tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

#######################
# Learning Rate Flags #
#######################

tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')

tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')

tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')

tf.app.flags.DEFINE_bool(
    'sync_replicas', False,
    'Whether or not to synchronize the replicas during training.')

tf.app.flags.DEFINE_integer(
    'replicas_to_aggregate', 1,
    'The number of gradients to collect before updating params.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

#######################
# Dataset Flags #
#######################

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')

tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                            'The maximum number of training steps.')

#####################
# Fine-Tuning Flags #
#####################

tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')

tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes to exclude when fine-tuning '
    'from a checkpoint.')

FLAGS = tf.app.flags.FLAGS


def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
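
# Worked example for the schedule above (the sample count is illustrative,
# not taken from any particular dataset): with 50,000 training samples,
# batch_size=32 and num_epochs_per_decay=2.0,
# decay_steps = int(50000 / 32 * 2.0) = 3125, so the default exponential
# schedule multiplies the learning rate by learning_rate_decay_factor every
# 3125 steps (staircase=True).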


def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.rmsprop_momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer
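
# With the default flag values this yields tf.train.RMSPropOptimizer(
# learning_rate, decay=0.9, momentum=0.9, epsilon=1.0); pass --optimizer to
# select one of the other branches above.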


def _add_variables_summaries(learning_rate):
  """Adds summaries for model variables and the learning rate."""
  summaries = []
  for variable in slim.get_model_variables():
    summaries.append(tf.histogram_summary(variable.op.name, variable))
  summaries.append(tf.scalar_summary('training/Learning Rate', learning_rate))
  return summaries


def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint already exists in the train_dir; in that case
  # the --checkpoint_path flag is ignored.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  return slim.assign_from_checkpoint_fn(
      FLAGS.checkpoint_path,
      variables_to_restore)
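
# A common fine-tuning setup (the scope names below are hypothetical and depend
# on the model chosen with --model_name): restore all pre-trained weights
# except the final classification layers, so a new label set can be trained:
#
#   --checkpoint_path=/path/to/pretrained.ckpt \
#   --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits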


def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  with tf.Graph().as_default():
    ######################
    # Config model_deploy#
    ######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    model_fn = model_factory.get_model(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      if FLAGS.train_image_size is None:
        train_image_size = model_fn.default_image_size
      else:
        train_image_size = FLAGS.train_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of the model_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = model_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.histogram_summary(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.scalar_summary('learning_rate', learning_rate,
                                      name='learning_rate'))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # TODO(sguada) Refactor into function that takes the clones and optimizer
    # and returns a train_tensor and summary_op
    total_loss, clones_gradients = model_deploy.optimize_clones(clones,
                                                                optimizer)
    # Add total_loss to summary.
    summaries.add(tf.scalar_summary('total_loss', total_loss,
                                    name='total_loss'))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.merge_summary(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)


if __name__ == '__main__':
  tf.app.run()