train_gamma_normal_def.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Trains a recurrent DEF with gamma latent variables and gaussian weights.
  16. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
from scipy.misc import imsave
import tensorflow as tf

from ops import inference
from ops import model_factory
from ops import rmsprop
from ops import tf_lib
from ops import util

flags = tf.flags
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('logdir', '/tmp/write_logs',
                    'Directory where to write event logs.')
flags.DEFINE_integer('seed', 41312, 'Random seed for TensorFlow and NumPy.')
flags.DEFINE_boolean('delete_logdir', True, 'Whether to clear the log dir.')
flags.DEFINE_string('trials_root_dir',
                    '/tmp/logs',
                    'Root directory where trial logs are written.')
flags.DEFINE_integer(
    'save_summaries_secs', 10,
    'The frequency with which summaries are saved, in seconds.')
flags.DEFINE_integer(
    'save_interval_secs', 10,
    'The frequency with which the model is saved, in seconds.')
flags.DEFINE_integer('max_steps', 200000,
                     'The maximum number of gradient steps.')
flags.DEFINE_integer('print_stats_every', 100,
                     'How often to print training statistics, in steps.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
flags.DEFINE_string('trainer', 'supervisor', 'slim/local/supervisor')
flags.DEFINE_integer('samples_to_save', 1, 'Number of samples to save.')
flags.DEFINE_boolean('check_nans', False, 'Add ops to check for NaNs.')
flags.DEFINE_string('data_path',
                    '/readahead/256M/cns/in-d/home/jaana/binarized_mnist_new',
                    'Where to read the data from.')

FLAGS = flags.FLAGS

sg = tf.contrib.bayesflow.stochastic_graph
distributions = tf.contrib.distributions


def run_training(hparams, train_dir, max_steps, tuner, container='',
                 trainer='supervisor'):
  """Trains a gamma-normal recurrent DEF.

  Args:
    hparams: A tf.HParams object with hyperparameters for training.
    train_dir: Where to store events files and checkpoints.
    max_steps: Integer number of steps to train.
    tuner: An instance of a vizier tuner.
    container: String specifying container for resource sharing.
    trainer: Train locally by loading an hdf5 file or with Supervisor.

  Returns:
    vi: Optionally, the VariationalInference object that has been trained.
    sess: Optionally, the session used for training.

  Raises:
    ValueError: if the ELBO is NaN.
  """
  hps = hparams
  tf.set_random_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  g = tf.Graph()
  if FLAGS.ps_tasks > 0:
    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
  else:
    device_fn = None
  with g.as_default(), g.device(device_fn), tf.container(container):
    if trainer == 'local':
      x_indexes = tf.placeholder(tf.int32, [hps.batch_size])
      x = tf.placeholder(
          tf.float32,
          [hps.batch_size, hps.n_timesteps, hps.timestep_dim, 1])
      data_iterator = util.provide_hdf5_data(
          FLAGS.data_path,
          'train',
          hps.n_examples,
          hps.batch_size,
          hps.n_timesteps,
          hps.timestep_dim,
          hps.dataset)
    else:
      x_indexes, x = util.provide_tfrecords_data(
          FLAGS.data_path,
          'train_labeled',
          hps.batch_size,
          hps.n_timesteps,
          hps.timestep_dim)
    data = {'x': x, 'x_indexes': x_indexes}

    model = model_factory.GammaNormalRDEF(
        n_timesteps=hps.n_timesteps,
        batch_size=hps.batch_size,
        p_z_shape=hps.p_z_shape,
        p_z_mean=hps.p_z_mean,
        p_w_mean_sigma=hps.p_w_mean_sigma,
        fixed_p_z_mean=hps.fixed_p_z_mean,
        p_w_shape_sigma=hps.p_w_shape_sigma,
        z_dim=hps.z_dim,
        use_bias_observations=hps.use_bias_observations,
        n_samples_latents=hps.n_samples_latents,
        dtype=hps.dtype)
    variational = model_factory.GammaNormalRDEFVariational(
        x_indexes=x_indexes,
        n_examples=hps.n_examples,
        n_timesteps=hps.n_timesteps,
        z_dim=hps.z_dim,
        timestep_dim=hps.timestep_dim,
        init_shape_q_z=hps.init_shape_q_z,
        init_mean_q_z=hps.init_mean_q_z,
        init_sigma_q_w_mean=hps.p_w_mean_sigma * hps.init_q_sigma_scale,
        init_sigma_q_w_shape=hps.p_w_shape_sigma * hps.init_q_sigma_scale,
        init_sigma_q_w_0_sigma=hps.p_w_shape_sigma * hps.init_q_sigma_scale,
        fixed_p_z_mean=hps.fixed_p_z_mean,
        fixed_q_z_mean=hps.fixed_q_z_mean,
        fixed_q_w_mean_sigma=hps.fixed_q_w_mean_sigma,
        fixed_q_w_shape_sigma=hps.fixed_q_w_shape_sigma,
        fixed_q_w_0_sigma=hps.fixed_q_w_0_sigma,
        n_samples_latents=hps.n_samples_latents,
        use_bias_observations=hps.use_bias_observations,
        dtype=hps.dtype)
    vi = inference.VariationalInference(model, variational, data)
    vi.build_graph()

    # Build prior and posterior predictive samples.
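    # The prior predictive draws z_1 from the model's recurrent prior (using
    # weight samples from the variational posterior), while the posterior
    # predictive reuses the variational sample of z_1 directly.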
    z_1_prior_sample = model.recurrent_layer_sample(
        variational.sample['w_1_shape'], variational.sample['w_1_mean'],
        hps.batch_size)
    prior_predictive = model.likelihood_sample(
        variational.sample, z_1_prior_sample, hps.batch_size)
    posterior_predictive = model.likelihood_sample(
        variational.sample, variational.sample['z_1'], hps.batch_size)

    # Build summaries.
    float32 = lambda x: tf.cast(x, tf.float32)
    tf.image_summary('prior_predictive',
                     float32(prior_predictive),
                     max_images=10)
    tf.image_summary('posterior_predictive',
                     float32(posterior_predictive),
                     max_images=10)
    tf.scalar_summary('ELBO', vi.scalar_elbo / hps.batch_size)
    tf.scalar_summary('log_p', tf.reduce_mean(vi.log_p))
    tf.scalar_summary('log_q', tf.reduce_mean(vi.log_q))

    global_step = tf.contrib.framework.get_or_create_global_step()

    # Specify optimization scheme.
    optimizer = tf.train.AdamOptimizer(learning_rate=hps.learning_rate)
    if hps.control_variate == 'none':
      train_op = optimizer.minimize(-vi.surrogate_elbo,
                                    global_step=global_step)
    elif hps.control_variate == 'covariance':
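      # The gamma latents are not reparameterizable, so their gradients use
      # the score-function (REINFORCE) estimator with a covariance control
      # variate (presumably what rmsprop.maximize_with_control_variate
      # implements): with learning signal f and score g = grad log q, the
      # estimator (f - a) * g has minimum variance at
      # a = Cov(f * g, g) / Var(g). The Gaussian weights are
      # reparameterizable and get ordinary pathwise gradients via the Adam
      # update below.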
      train_non_reparam = rmsprop.maximize_with_control_variate(
          learning_rate=hps.learning_rate,
          learning_signal=vi.elbo,
          log_prob=vi.log_q,
          variable_list=tf.get_collection('non_reparam_variables'),
          global_step=global_step)
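      # Gradients of embedding lookups arrive as tf.IndexedSlices; keep their
      # dense .values so they can be fetched with sess.run in the variance
      # diagnostics below.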
      grad_tensors = [v.values if 'embedding_lookup' in v.name else v
                      for v in tf.get_collection('non_reparam_variable_grads')]
      train_reparam = optimizer.minimize(
          -tf.reduce_mean(vi.elbo, 0),  # Optimize the mean across samples.
          var_list=tf.get_collection('reparam_variables'))
      train_op = tf.group(train_reparam, train_non_reparam)

    if trainer == 'supervisor':
      global_step = tf.contrib.framework.get_or_create_global_step()
      train_op = optimizer.minimize(-vi.elbo, global_step=global_step)
      summary_op = tf.merge_all_summaries()
      saver = tf.train.Saver()
      sv = tf.train.Supervisor(
          logdir=train_dir,
          is_chief=(FLAGS.task == 0),
          saver=saver,
          summary_op=summary_op,
          global_step=global_step,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_model_secs=FLAGS.save_summaries_secs,
          recovery_wait_secs=5)
      sess = sv.prepare_or_wait_for_session(FLAGS.master)
      sv.start_queue_runners(sess)
      local_step = 0
      while not sv.should_stop():
        _, np_elbo, np_global_step = sess.run(
            [train_op, vi.elbo, global_step])
        if tuner is not None:
          if np.isnan(np_elbo):
            tuner.report_done(infeasible=True, infeasible_reason='ELBO is nan')
            should_stop = True
          else:
            should_stop = tuner.report_measure(float(np_elbo),
                                               global_step=np_global_step)
          if should_stop:
            tuner.report_done()
            sv.request_stop()
        if np_global_step >= max_steps:
          break
        if local_step % FLAGS.print_stats_every == 0:
          print('step %d: %g' % (np_global_step - 1,
                                 np_elbo / hps.batch_size))
        local_step += 1
      sv.stop()
      sess.close()
    elif trainer == 'local':
      sess = tf.InteractiveSession()
      sess.run(tf.initialize_all_variables())
      t0 = time.time()
      if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
      tf.gfile.MakeDirs(train_dir)
      for i in range(max_steps):
        indexes, images = next(data_iterator)
        feed_dict = {x_indexes: indexes, x: images}
        if i % FLAGS.print_stats_every == 0:
          _, np_prior_predictive, np_posterior_predictive = sess.run(
              [train_op, prior_predictive, posterior_predictive],
              feed_dict)
          print('prior_predictive', np_prior_predictive.flatten())
          print('posterior_predictive', np_posterior_predictive.flatten())
          print('data', images.flatten())
          examples_per_s = (hps.batch_size * FLAGS.print_stats_every /
                            (time.time() - t0))
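          # Diagnostics: repeatedly evaluate the ELBO and the score-function
          # gradients on the same batch to estimate the per-variable gradient
          # variance, a measure of how well the control variate is working.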
          q_z = variational.params['z_1'].distribution
          alpha = q_z.alpha
          beta = q_z.beta
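          # The mean of Gamma(alpha, beta) in the shape/rate parameterization
          # is alpha / beta.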
          mean = alpha / beta
          grad_list = []
          elbo_list = []
          for k in range(100):
            elbo_list.append(vi.elbo.eval(feed_dict))
            grads = sess.run(grad_tensors, feed_dict)
            grad_list.append(grads)
          np_elbo = np.mean(np.vstack([np.sum(v, axis=1) for v in elbo_list]))
          if np.isnan(np_elbo):
            raise ValueError('ELBO is NaN. Please keep trying!')
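          # Stack the sampled gradients per variable, then split them per
          # timestep and per latent dimension so the variance can be reported
          # at that granularity.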
          grads_per_var = [
              np.stack([g_sample[var_idx] for g_sample in grad_list])
              for var_idx in range(
                  len(tf.get_collection('non_reparam_variable_grads')))]
          grads_per_timestep = [np.split(g, hps.n_timesteps, axis=2)
                                for g in grads_per_var]
          grads_per_timestep_per_dim = [
              [np.split(g, hps.z_dim, axis=3) for g in g_list]
              for g_list in grads_per_timestep]
          grads_per_timestep_per_dim = [
              sum(g_list, []) for g_list in grads_per_timestep_per_dim]
          print('variance of gradients for each variable: ')
          for var_idx, var in enumerate(
              tf.get_collection('non_reparam_variable_grads')):
            print('variable: %s' % var.name)
            variances = [np.var(g, axis=0) for g in
                         grads_per_timestep_per_dim[var_idx]]
            print('variance is: ', np.stack(variances).flatten())
          print('alpha ', alpha.eval(feed_dict).flatten())
          print('mean ', mean.eval(feed_dict).flatten())
          print('bernoulli p ', np.mean(
              vi.model.p_x_zw_bernoulli_p.eval(feed_dict), axis=0).flatten())
          t0 = time.time()
          print('iter %d\telbo: %.3e\texamples/s: %.3f' % (
              i, np_elbo, examples_per_s))
          for k in range(hps.samples_to_save):
            im_name = 'i_%d_k_%d_' % (i, k)
            prior_name = im_name + 'prior_predictive.jpg'
            posterior_name = im_name + 'posterior_predictive.jpg'
            imsave(os.path.join(train_dir, prior_name),
                   np_prior_predictive[k, :, :, 0])
            imsave(os.path.join(train_dir, posterior_name),
                   np_posterior_predictive[k, :, :, 0])
        else:
          _ = sess.run(train_op, feed_dict)
  return vi, sess


def main(unused_argv):
  """Trains a gamma-normal DEF locally."""
  if tf.gfile.Exists(FLAGS.logdir) and FLAGS.delete_logdir:
    tf.gfile.DeleteRecursively(FLAGS.logdir)
  tf.gfile.MakeDirs(FLAGS.logdir)
  # The HParams are commented in training_params.py.
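  # tf.HParams has moved between TensorFlow versions; fall back to the local
  # copy in tf_lib when the attribute is absent.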
  try:
    hparams = tf.HParams
  except AttributeError:
    hparams = tf_lib.HParams
  hparams = hparams(
      dataset='alternating',
      z_dim=1,
      timestep_dim=1,
      n_timesteps=2,
      batch_size=1,
      n_examples=1,
      samples_to_save=1,
      learning_rate=0.01,
      momentum=0.0,
      n_samples_latents=100,
      p_z_shape=0.1,
      p_z_mean=1.,
      p_w_mean_sigma=5.,
      p_w_shape_sigma=5.,
      init_q_sigma_scale=0.1,
      use_bias_observations=True,
      init_q_z_scale=1.,
      init_shape_q_z=util.softplus(0.1),
      init_mean_q_z=util.softplus(0.01),
      fixed_p_z_mean=False,
      fixed_q_z_mean=False,
      fixed_q_z_shape=False,
      fixed_q_w_mean_sigma=False,
      fixed_q_w_shape_sigma=False,
      fixed_q_w_0_sigma=False,
      dtype='float64',
      control_variate='covariance')
  run_training(hparams, FLAGS.logdir, FLAGS.max_steps, None, trainer='local')


if __name__ == '__main__':
  tf.app.run()