real_nvp_utils.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. r"""Utility functions for Real NVP.
  16. """
  17. # pylint: disable=dangerous-default-value
  18. import numpy
  19. import tensorflow as tf
  20. from tensorflow.python.framework import ops
  21. DEFAULT_BN_LAG = .0


def stable_var(input_, mean=None, axes=[0]):
    """Numerically more stable variance computation."""
    if mean is None:
        mean = tf.reduce_mean(input_, axes)
    res = tf.square(input_ - mean)
    max_sqr = tf.reduce_max(res, axes)
    # rescale by the max squared deviation before averaging, then undo the
    # rescaling; this limits overflow in the intermediate mean
    res /= max_sqr
    res = tf.reduce_mean(res, axes)
    res *= max_sqr

    return res
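

# Example (sketch): comparing against tf.nn.moments on data with a large
# offset; the shape and values here are illustrative assumptions.
#
#   x = tf.constant(1e4 + numpy.random.randn(128, 8).astype("float32"))
#   _, naive_var = tf.nn.moments(x, axes=[0])
#   rescaled_var = stable_var(x)  # same quantity, computed after rescaling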


def variable_on_cpu(name, shape, initializer, trainable=True):
    """Helper to create a Variable stored on CPU memory.

    Args:
        name: name of the variable
        shape: list of ints
        initializer: initializer for Variable
        trainable: boolean defining if the variable is for training

    Returns:
        Variable Tensor
    """
    var = tf.get_variable(
        name, shape, initializer=initializer, trainable=trainable)

    return var


# layers
def conv_layer(input_,
               filter_size,
               dim_in,
               dim_out,
               name,
               stddev=1e-2,
               strides=[1, 1, 1, 1],
               padding="SAME",
               nonlinearity=None,
               bias=False,
               weight_norm=False,
               scale=False):
    """Convolutional layer."""
    with tf.variable_scope(name) as scope:
        weights = variable_on_cpu(
            "weights",
            filter_size + [dim_in, dim_out],
            tf.random_uniform_initializer(
                minval=-stddev, maxval=stddev))
        # weight normalization
        if weight_norm:
            weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
            if scale:
                magnitude = variable_on_cpu(
                    "magnitude", [dim_out],
                    tf.constant_initializer(
                        stddev * numpy.sqrt(
                            dim_in * numpy.prod(filter_size) / 12.)))
                weights *= magnitude
        res = input_
        # handling filter size bigger than image size
        if hasattr(input_, "shape"):
            if input_.get_shape().as_list()[1] < filter_size[0]:
                pad_1 = tf.zeros([
                    input_.get_shape().as_list()[0],
                    filter_size[0] - input_.get_shape().as_list()[1],
                    input_.get_shape().as_list()[2],
                    input_.get_shape().as_list()[3]
                ])
                pad_2 = tf.zeros([
                    input_.get_shape().as_list()[0],
                    filter_size[0],
                    filter_size[1] - input_.get_shape().as_list()[2],
                    input_.get_shape().as_list()[3]
                ])
                res = tf.concat(axis=1, values=[pad_1, res])
                res = tf.concat(axis=2, values=[pad_2, res])
        res = tf.nn.conv2d(
            input=res,
            filter=weights,
            strides=strides,
            padding=padding,
            name=scope.name)
        if hasattr(input_, "shape"):
            if input_.get_shape().as_list()[1] < filter_size[0]:
                res = tf.slice(res, [
                    0, filter_size[0] - input_.get_shape().as_list()[1],
                    filter_size[1] - input_.get_shape().as_list()[2], 0
                ], [-1, -1, -1, -1])
        if bias:
            biases = variable_on_cpu(
                "biases", [dim_out], tf.constant_initializer(0.))
            res = tf.nn.bias_add(res, biases)
        if nonlinearity is not None:
            res = nonlinearity(res)

    return res
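

# Example (sketch): a 3x3 convolution from 3 to 16 channels; the placeholder
# shape and scope name are illustrative assumptions.
#
#   images = tf.placeholder(tf.float32, [16, 32, 32, 3])
#   hidden = conv_layer(images, filter_size=[3, 3], dim_in=3, dim_out=16,
#                       name="conv_example", nonlinearity=tf.nn.relu,
#                       bias=True)  # -> shape [16, 32, 32, 16]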


def max_pool_2x2(input_):
    """Max pooling."""
    return tf.nn.max_pool(
        input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")


def depool_2x2(input_, stride=2):
    """Depooling."""
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
    res = tf.concat(
        axis=2,
        values=[res,
                tf.zeros([batch_size, height, stride - 1, width, 1,
                          channels])])
    res = tf.concat(
        axis=4,
        values=[res,
                tf.zeros([batch_size, height, stride, width, stride - 1,
                          channels])])
    res = tf.reshape(res,
                     [batch_size, stride * height, stride * width, channels])

    return res
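

# Example (sketch): depooling places each input pixel at the top-left of a
# stride x stride block and zero-fills the rest; the placeholder shape is an
# illustrative assumption.
#
#   x = tf.placeholder(tf.float32, [4, 8, 8, 32])
#   y = depool_2x2(x)  # -> shape [4, 16, 16, 32]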


# random flip on a batch of images
def batch_random_flip(input_):
    """Horizontal random flip, drawn independently for each image in the batch."""
    if isinstance(input_, (float, int)):
        return input_
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    res = tf.split(axis=0, num_or_size_splits=batch_size, value=input_)
    res = [elem[0, :, :, :] for elem in res]
    res = [tf.image.random_flip_left_right(elem) for elem in res]
    res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
    res = tf.concat(axis=0, values=res)

    return res
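

# Example (sketch): shape is preserved; each image is flipped (or not) with
# its own coin flip. The placeholder shape is an illustrative assumption.
#
#   images = tf.placeholder(tf.float32, [16, 32, 32, 3])
#   maybe_flipped = batch_random_flip(images)  # -> shape [16, 32, 32, 3]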


# build a one hot representation corresponding to the integer tensor
# the one-hot dimension is appended to the integer tensor shape
def as_one_hot(input_, n_indices):
    """Convert indices to one-hot."""
    shape = input_.get_shape().as_list()
    n_elem = numpy.prod(shape)
    indices = tf.range(n_elem)
    indices = tf.cast(indices, tf.int64)  # `input_` must also hold int64 values
    indices_input = tf.concat(axis=0,
                              values=[indices, tf.reshape(input_, [-1])])
    indices_input = tf.reshape(indices_input, [2, -1])
    indices_input = tf.transpose(indices_input)
    res = tf.sparse_to_dense(
        indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
    res = tf.reshape(res, shape + [n_indices])

    return res
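

# Example (sketch): one-hot encoding a small int64 tensor; the values are
# illustrative.
#
#   labels = tf.constant([[2, 0], [1, 3]], dtype=tf.int64)
#   one_hot = as_one_hot(labels, n_indices=4)  # -> shape [2, 2, 4]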


def squeeze_2x2(input_):
    """Squeezing operation: reshape to convert space to channels."""
    return squeeze_nxn(input_, n_factor=2)


def squeeze_nxn(input_, n_factor=2):
    """Squeezing operation: reshape to convert space to channels."""
    if isinstance(input_, (float, int)):
        return input_
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    if height % n_factor != 0:
        raise ValueError("Height not divisible by %d." % n_factor)
    if width % n_factor != 0:
        raise ValueError("Width not divisible by %d." % n_factor)
    res = tf.reshape(
        input_,
        [batch_size,
         height // n_factor,
         n_factor,
         width // n_factor,
         n_factor,
         channels])
    res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
    res = tf.reshape(
        res,
        [batch_size,
         height // n_factor,
         width // n_factor,
         channels * n_factor * n_factor])

    return res


def unsqueeze_2x2(input_):
    """Unsqueezing operation: reshape to convert channels into space."""
    if isinstance(input_, (float, int)):
        return input_
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    if channels % 4 != 0:
        raise ValueError("Number of channels not divisible by 4.")
    res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
    res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
    res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])

    return res
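

# Example (sketch): squeezing trades spatial resolution for channels, and
# unsqueeze_2x2 inverts the 2x2 case exactly; the placeholder shape is an
# illustrative assumption.
#
#   x = tf.placeholder(tf.float32, [8, 32, 32, 3])
#   y = squeeze_2x2(x)    # -> [8, 16, 16, 12]
#   z = unsqueeze_2x2(y)  # -> [8, 32, 32, 3], recovers x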


# batch norm
def batch_norm(input_,
               dim,
               name,
               scale=True,
               train=True,
               epsilon=1e-8,
               decay=.1,
               axes=[0],
               bn_lag=DEFAULT_BN_LAG):
    """Batch normalization."""
    # create variables
    with tf.variable_scope(name):
        var = variable_on_cpu(
            "var", [dim], tf.constant_initializer(1.), trainable=False)
        mean = variable_on_cpu(
            "mean", [dim], tf.constant_initializer(0.), trainable=False)
        step = variable_on_cpu(
            "step", [], tf.constant_initializer(0.), trainable=False)
        if scale:
            gamma = variable_on_cpu("gamma", [dim],
                                    tf.constant_initializer(1.))
        beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))

    # choose the appropriate moments
    if train:
        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
        cur_mean, cur_var = used_mean, used_var
        if bn_lag > 0.:
            # blend batch moments with the (bias-corrected) moving averages
            used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
            used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
            used_mean /= (1. - bn_lag**(step + 1))
            used_var /= (1. - bn_lag**(step + 1))
    else:
        used_mean, used_var = mean, var
        cur_mean, cur_var = used_mean, used_var

    # normalize
    res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
    # de-normalize
    if scale:
        res *= gamma
    res += beta

    # update variables
    if train:
        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
            with ops.colocate_with(mean):
                new_mean = tf.assign_sub(
                    mean,
                    tf.check_numerics(
                        decay * (mean - cur_mean), "NaN in moving mean."))
        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
            with ops.colocate_with(var):
                new_var = tf.assign_sub(
                    var,
                    tf.check_numerics(
                        decay * (var - cur_var), "NaN in moving variance."))
        with tf.name_scope(name, "IncrementTime", [step]):
            with ops.colocate_with(step):
                new_step = tf.assign_add(step, 1.)
        # the zero-weighted term adds a data dependency on the update ops, so
        # the moving averages are refreshed whenever `res` is evaluated
        res += 0. * new_mean * new_var * new_step

    return res
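

# Example (sketch): normalizing convnet activations over batch and spatial
# axes; the shape and scope name are illustrative assumptions.
#
#   h = tf.placeholder(tf.float32, [16, 8, 8, 64])
#   h_bn = batch_norm(h, dim=64, name="bn_example", train=True,
#                     axes=[0, 1, 2])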


# batch normalization taking into account the volume transformation
def batch_norm_log_diff(input_,
                        dim,
                        name,
                        train=True,
                        epsilon=1e-8,
                        decay=.1,
                        axes=[0],
                        reuse=None,
                        bn_lag=DEFAULT_BN_LAG):
    """Batch normalization with corresponding log determinant Jacobian."""
    if reuse is None:
        reuse = not train
    # create variables
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        var = variable_on_cpu(
            "var", [dim], tf.constant_initializer(1.), trainable=False)
        mean = variable_on_cpu(
            "mean", [dim], tf.constant_initializer(0.), trainable=False)
        step = variable_on_cpu(
            "step", [], tf.constant_initializer(0.), trainable=False)

    # choose the appropriate moments
    if train:
        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
        cur_mean, cur_var = used_mean, used_var
        if bn_lag > 0.:
            used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
            cur_var = used_var
            # blend batch moments with the (bias-corrected) moving averages
            used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
            used_mean /= (1. - bn_lag**(step + 1))
            used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
            used_var /= (1. - bn_lag**(step + 1))
    else:
        used_mean, used_var = mean, var
        cur_mean, cur_var = used_mean, used_var

    # update variables
    if train:
        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
            with ops.colocate_with(mean):
                new_mean = tf.assign_sub(
                    mean,
                    tf.check_numerics(
                        decay * (mean - cur_mean), "NaN in moving mean."))
        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
            with ops.colocate_with(var):
                new_var = tf.assign_sub(
                    var,
                    tf.check_numerics(
                        decay * (var - cur_var), "NaN in moving variance."))
        with tf.name_scope(name, "IncrementTime", [step]):
            with ops.colocate_with(step):
                new_step = tf.assign_add(step, 1.)
        # zero-weighted term: forces the updates to run when the output is used
        used_var += 0. * new_mean * new_var * new_step
    used_var += epsilon

    return used_mean, used_var
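

# Example (sketch): normalizing by sqrt(var) scales each unit by var^{-1/2},
# so each unit contributes -0.5 * log(var) to the log-determinant of the
# Jacobian. A hedged sketch of how a caller might combine the outputs:
#
#   mean, var = batch_norm_log_diff(x, dim=64, name="bn_logdet_example")
#   x_hat = (x - mean) / tf.sqrt(var)           # var already includes epsilon
#   log_det = -.5 * tf.reduce_sum(tf.log(var))  # Jacobian term per example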


def convnet(input_,
            dim_in,
            dim_hid,
            filter_sizes,
            dim_out,
            name,
            use_batch_norm=True,
            train=True,
            nonlinearity=tf.nn.relu):
    """Chaining of convolutional layers."""
    dims_in = [dim_in] + dim_hid[:-1]
    dims_out = dim_hid
    res = input_

    bias = (not use_batch_norm)
    with tf.variable_scope(name):
        for layer_idx in xrange(len(dim_hid)):
            res = conv_layer(
                input_=res,
                filter_size=filter_sizes[layer_idx],
                dim_in=dims_in[layer_idx],
                dim_out=dims_out[layer_idx],
                name="h_%d" % layer_idx,
                stddev=1e-2,
                nonlinearity=None,
                bias=bias)
            if use_batch_norm:
                res = batch_norm(
                    input_=res,
                    dim=dims_out[layer_idx],
                    name="bn_%d" % layer_idx,
                    scale=(nonlinearity == tf.nn.relu),
                    train=train,
                    epsilon=1e-8,
                    axes=[0, 1, 2])
            if nonlinearity is not None:
                res = nonlinearity(res)

        res = conv_layer(
            input_=res,
            filter_size=filter_sizes[-1],
            dim_in=dims_out[-1],
            dim_out=dim_out,
            name="out",
            stddev=1e-2,
            nonlinearity=None)

    return res
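

# Example (sketch): a two-hidden-layer convnet head; note that filter_sizes
# needs len(dim_hid) + 1 entries (the last one is for the output layer). All
# names and shapes below are illustrative assumptions.
#
#   x = tf.placeholder(tf.float32, [16, 32, 32, 3])
#   logits = convnet(x, dim_in=3, dim_hid=[32, 32],
#                    filter_sizes=[[3, 3], [3, 3], [1, 1]], dim_out=10,
#                    name="convnet_example")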


# distributions
# log-likelihood estimation
def standard_normal_ll(input_):
    """Elementwise log-likelihood under a standard Gaussian distribution."""
    res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))

    return res


def standard_normal_sample(shape):
    """Samples from a standard Gaussian distribution."""
    return tf.random_normal(shape)
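

# Example (sketch): summing the elementwise log-likelihood over non-batch axes
# gives a per-example log-density; the shapes are illustrative.
#
#   z = standard_normal_sample([16, 64])
#   ll = tf.reduce_sum(standard_normal_ll(z), axis=1)  # -> shape [16]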


SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
                              [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])


def squeeze_2x2_ordered(input_, reverse=False):
    """Squeezing operation with a controlled ordering."""
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    if reverse:
        if channels % 4 != 0:
            raise ValueError("Number of channels not divisible by 4.")
        channels //= 4  # integer division: `channels` must stay an int
    else:
        if height % 2 != 0:
            raise ValueError("Height not divisible by 2.")
        if width % 2 != 0:
            raise ValueError("Width not divisible by 2.")
    weights = numpy.zeros((2, 2, channels, 4 * channels))
    for idx_ch in xrange(channels):
        slice_2 = slice(idx_ch, (idx_ch + 1))
        slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
        weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
    shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)]
    shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)]
    shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)]
    shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)]
    shuffle_channels = numpy.array(shuffle_channels)
    weights = weights[:, :, :, shuffle_channels].astype("float32")
    if reverse:
        res = tf.nn.conv2d_transpose(
            value=input_,
            filter=weights,
            output_shape=[batch_size, height * 2, width * 2, channels],
            strides=[1, 2, 2, 1],
            padding="SAME",
            name="unsqueeze_2x2")
    else:
        res = tf.nn.conv2d(
            input=input_,
            filter=weights,
            strides=[1, 2, 2, 1],
            padding="SAME",
            name="squeeze_2x2")

    return res
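

# Example (sketch): the ordered squeeze and its reverse form an exact inverse
# pair; the placeholder shape is an illustrative assumption.
#
#   x = tf.placeholder(tf.float32, [8, 32, 32, 3])
#   y = squeeze_2x2_ordered(x)                # -> [8, 16, 16, 12]
#   z = squeeze_2x2_ordered(y, reverse=True)  # -> [8, 32, 32, 3]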