# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""A binary to train Inception in a distributed manner using multiple systems.

Please see accompanying README.md for details and instructions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from inception import inception_distributed_train
from inception.imagenet_data import ImagenetData

FLAGS = tf.app.flags.FLAGS


def main(unused_args):
  """Starts the TensorFlow server for this task and, on workers, trains.

  Reads `job_name`, `task_id`, `ps_hosts`, `worker_hosts`, `subset` and
  `train_dir` from FLAGS. A `ps` task blocks in `server.join()`; a `worker`
  task loads the ImageNet dataset and calls
  `inception_distributed_train.train`.

  Args:
    unused_args: Positional arguments supplied by `tf.app.run`; ignored.

  Raises:
    ValueError: If `FLAGS.job_name` is neither 'ps' nor 'worker', or if the
      selected dataset subset reports no data files.
  """
  # Raise explicitly instead of using `assert`: asserts are removed when Python
  # runs with -O, which would let an invalid job name slip through.
  if FLAGS.job_name not in ('ps', 'worker'):
    raise ValueError('job_name must be ps or worker')

  # Extract all the hostnames for the ps and worker jobs to construct the
  # cluster spec.
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  tf.logging.info('PS hosts are: %s', ps_hosts)
  tf.logging.info('Worker hosts are: %s', worker_hosts)

  cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                       'worker': worker_hosts})
  # Build the server from the same ClusterSpec that is handed to train() so
  # the two can never describe different clusters.
  server = tf.train.Server(
      cluster_spec,
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id)

  if FLAGS.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    dataset = ImagenetData(subset=FLAGS.subset)
    if not dataset.data_files():
      raise ValueError('No data files found for subset: %s' % FLAGS.subset)

    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0:
      if not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)

    inception_distributed_train.train(server.target, dataset, cluster_spec)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()