# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""
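
# A minimal sketch of a typical invocation; the script file name, dataset
# directory, checkpoint path and scope names below are illustrative
# placeholders rather than values mandated by this script:
#
#   python train_image_classifier.py \
#       --train_dir=/tmp/tfmodel \
#       --dataset_name=flowers \
#       --dataset_split_name=train \
#       --dataset_dir=/tmp/data/flowers \
#       --model_name=inception_v3 \
#       --checkpoint_path=/tmp/checkpoints/inception_v3.ckpt \
#       --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
#       --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits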

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops

from datasets import dataset_factory
from deployment import model_deploy
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')

tf.app.flags.DEFINE_integer('num_clones', 1,
                            'Number of model clones to deploy.')

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

tf.app.flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 10,
    'The frequency with which logs are printed.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'task', 0, 'Task id of the replica running the training.')

######################
# Optimization Flags #
######################

tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')

tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')

tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')

tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')

tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')

tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')

tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')

tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')

tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

#######################
# Learning Rate Flags #
#######################

tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')

tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')

tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')

tf.app.flags.DEFINE_bool(
    'sync_replicas', False,
    'Whether or not to synchronize the replicas during training.')

tf.app.flags.DEFINE_integer(
    'replicas_to_aggregate', 1,
    'The number of gradients to collect before updating params.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

#######################
# Dataset Flags        #
#######################

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
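# Illustrative example (not required by this script): the TF-Slim ImageNet
# dataset reserves label 0 for a "background" class, giving 1001 labels, while
# VGG and ResNet checkpoints predict only the 1000 real classes, so they are
# typically trained and evaluated with --labels_offset=1.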

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')

tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                            'The maximum number of training steps.')

#####################
# Fine-Tuning Flags #
#####################

tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')

tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')

tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train. '
    'By default, None trains all the variables.')

tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', False,
    'Whether to ignore missing variables when restoring from a checkpoint.')

FLAGS = tf.app.flags.FLAGS


def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate
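  # Illustrative arithmetic only: with roughly 1,281,167 ImageNet training
  # samples and the default batch_size=32 and num_epochs_per_decay=2.0, this
  # yields decay_steps = int(1281167 / 32 * 2.0) = 80072, i.e. the learning
  # rate decays about once every two epochs.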

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)


def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.rmsprop_momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer


def _add_variables_summaries(learning_rate):
  """Adds histogram summaries for model variables and the learning rate."""
  summaries = []
  for variable in slim.get_model_variables():
    summaries.append(tf.summary.histogram(variable.op.name, variable))
  summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
  return summaries


def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
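  # Illustrative sketch of the intended usage (checkpoint path and scope names
  # are placeholders): to fine-tune an InceptionV3 checkpoint on a new dataset
  # one would typically pass
  #   --checkpoint_path=/tmp/checkpoints/inception_v3.ckpt \
  #   --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
  # so that everything except the freshly sized logits layers is restored.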
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)


def _get_variables_to_train():
  """Returns a list of variables to train.

  Returns:
    A list of variables to train by the optimizer.
  """
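  # For example (scope names are illustrative), passing
  #   --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
  # limits training to the final classification layers while the rest of the
  # network stays frozen.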
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()
  else:
    scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

  variables_to_train = []
  for scope in scopes:
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    variables_to_train.extend(variables)
  return variables_to_train


def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)
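
    # Sketch of the input pipeline assembled above: FLAGS.num_readers parallel
    # readers pull examples into a common queue, each image is preprocessed and
    # resized, FLAGS.num_preprocessing_threads threads assemble batches, labels
    # are one-hot encoded, and the prefetch queue decouples input production
    # from the clones that consume the batches.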

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
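      # When --label_smoothing is non-zero, tf.losses.softmax_cross_entropy
      # softens the one-hot targets to roughly
      #   (1 - label_smoothing) * onehot_labels + label_smoothing / num_classes,
      # which discourages over-confident logits; the default of 0.0 leaves the
      # targets unchanged.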
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None
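
    # When moving averages are enabled, the shadow variables are either applied
    # through update_ops below (the non-sync case) or handed to
    # SyncReplicasOptimizer so that averaging happens in the chief queue runner.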

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and gather the gradients across all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                      name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)


if __name__ == '__main__':
  tf.app.run()