imagenet_distributed_train.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""A binary to train Inception in a distributed manner using multiple systems.

Please see accompanying README.md for details and instructions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from inception import inception_distributed_train
from inception.imagenet_data import ImagenetData

FLAGS = tf.app.flags.FLAGS


def main(unused_args):
  assert FLAGS.job_name in ['ps', 'worker'], 'job_name must be ps or worker'

  # Extract all the hostnames for the ps and worker jobs to construct the
  # cluster spec.
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  tf.logging.info('PS hosts are: %s' % ps_hosts)
  tf.logging.info('Worker hosts are: %s' % worker_hosts)

  cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                       'worker': worker_hosts})
  server = tf.train.Server(
      {'ps': ps_hosts,
       'worker': worker_hosts},
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id)

  if FLAGS.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()
    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0:
      if not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)
    inception_distributed_train.train(server.target, dataset, cluster_spec)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
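
# A sketch of how this binary might be launched, one process per machine. The
# flags referenced above via FLAGS (--job_name, --task_id, --ps_hosts,
# --worker_hosts, --train_dir, --subset) are assumed to be defined in the
# imported inception modules; the hostnames, ports, and paths below are
# placeholders, not values from this repository. See README.md for the
# authoritative instructions.
#
#   # On the parameter server machine:
#   python imagenet_distributed_train.py \
#       --job_name=ps \
#       --task_id=0 \
#       --ps_hosts='ps0.example.com:2222' \
#       --worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
#
#   # On each worker machine (--task_id must be unique per worker):
#   python imagenet_distributed_train.py \
#       --job_name=worker \
#       --task_id=0 \
#       --train_dir=/tmp/imagenet_train \
#       --ps_hosts='ps0.example.com:2222' \
#       --worker_hosts='worker0.example.com:2222,worker1.example.com:2222'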