| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- # Copyright 2016 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # pylint: disable=line-too-long
- """A binary to train Inception in a distributed manner using multiple systems.
- Please see accompanying README.md for details and instructions.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import tensorflow as tf
- from inception import inception_distributed_train
- from inception.imagenet_data import ImagenetData
# Command-line flags read by main(): job_name, ps_hosts, worker_hosts,
# task_id, subset and train_dir. None are defined in this file; presumably the
# imported inception modules define them -- confirm if adding new flags.
FLAGS = tf.app.flags.FLAGS


def _parse_hosts(flag_name, flag_value):
  """Split a comma-separated host-list flag into a list of host strings.

  Surrounding whitespace is stripped and empty entries (from an unset flag,
  a trailing comma or ',,') are dropped; otherwise `''.split(',')` would put
  an empty-string task address into the cluster spec.

  Args:
    flag_name: name of the flag being parsed, used only in the error message.
    flag_value: raw flag string, e.g. 'ps0:2222,ps1:2222'.

  Returns:
    A non-empty list of host strings, in the order given.

  Raises:
    ValueError: if no hosts remain after parsing.
  """
  hosts = [host.strip() for host in flag_value.split(',') if host.strip()]
  if not hosts:
    raise ValueError('--%s must list at least one host, got %r'
                     % (flag_name, flag_value))
  return hosts


def main(unused_args):
  """Run this process as one `ps` or `worker` task of distributed training.

  Builds the cluster description from --ps_hosts/--worker_hosts and starts
  this task's tf.train.Server. A `ps` task then blocks in server.join(); a
  `worker` task loads the ImageNet subset, the chief (task 0) ensures
  --train_dir exists, and training is delegated to
  inception_distributed_train.train().

  Args:
    unused_args: argument list forwarded by tf.app.run(); ignored, since all
      configuration is read from FLAGS.

  Raises:
    ValueError: if --job_name is not 'ps' or 'worker', if a host list is
      empty, or if the dataset reports no data files.
  """
  # Explicit checks rather than `assert`, which is removed under `python -O`.
  if FLAGS.job_name not in ('ps', 'worker'):
    raise ValueError('job_name must be ps or worker, got %r' % FLAGS.job_name)

  # Extract all the hostnames for the ps and worker jobs to construct the
  # cluster spec.
  ps_hosts = _parse_hosts('ps_hosts', FLAGS.ps_hosts)
  worker_hosts = _parse_hosts('worker_hosts', FLAGS.worker_hosts)
  tf.logging.info('PS hosts are: %s', ps_hosts)
  tf.logging.info('Worker hosts are: %s', worker_hosts)

  cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                       'worker': worker_hosts})
  # Give the server the very same ClusterSpec that is passed to train(), so
  # the two descriptions of the cluster cannot drift apart.
  server = tf.train.Server(
      cluster_spec,
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id)

  if FLAGS.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
    return

  # `worker` jobs will actually do the work.
  dataset = ImagenetData(subset=FLAGS.subset)
  if not dataset.data_files():
    raise ValueError('No data files found for subset %r' % FLAGS.subset)

  # Only the chief checks for or creates train_dir.
  if FLAGS.task_id == 0:
    if not tf.gfile.Exists(FLAGS.train_dir):
      tf.gfile.MakeDirs(FLAGS.train_dir)

  inception_distributed_train.train(server.target, dataset, cluster_spec)
if __name__ == '__main__':
  # Raise log verbosity so the INFO-level messages from main() are shown.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Hand the entry point to tf.app.run() explicitly instead of letting it
  # look `main` up in the __main__ module; the function called is the same.
  tf.app.run(main=main)
|