real_nvp_multiscale_dataset.py

  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Script for training, evaluation and sampling for Real NVP.
  16. $ python real_nvp_multiscale_dataset.py \
  17. --alsologtostderr \
  18. --image_size 64 \
  19. --hpconfig=n_scale=5,base_dim=8 \
  20. --dataset imnet \
  21. --data_path [DATA_PATH]
  22. """
  23. import time
  24. from datetime import datetime
  25. import os
  26. import numpy
  27. import tensorflow as tf
  28. from tensorflow import gfile
  29. from real_nvp_utils import (
  30. batch_norm, batch_norm_log_diff, conv_layer,
  31. squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll,
  32. standard_normal_sample, unsqueeze_2x2, variable_on_cpu)
  33. tf.flags.DEFINE_string("master", "local",
  34. "BNS name of the TensorFlow master, or local.")
  35. tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale",
  36. "Directory to which writes logs.")
  37. tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale",
  38. "Directory to which writes logs.")
  39. tf.flags.DEFINE_integer("train_steps", 1000000000000000000,
  40. "Number of steps to train for.")
  41. tf.flags.DEFINE_string("data_path", "", "Path to the data.")
  42. tf.flags.DEFINE_string("mode", "train",
  43. "Mode of execution. Must be 'train', "
  44. "'sample' or 'eval'.")
  45. tf.flags.DEFINE_string("dataset", "imnet",
  46. "Dataset used. Must be 'imnet', "
  47. "'celeba' or 'lsun'.")
  48. tf.flags.DEFINE_integer("recursion_type", 2,
  49. "Type of the recursion.")
  50. tf.flags.DEFINE_integer("image_size", 64,
  51. "Size of the input image.")
  52. tf.flags.DEFINE_integer("eval_set_size", 0,
  53. "Size of evaluation dataset.")
  54. tf.flags.DEFINE_string(
  55. "hpconfig", "",
  56. "A comma separated list of hyperparameters for the model. Format is "
  57. "hp1=value1,hp2=value2,etc. If this FLAG is set, the model will be trained "
  58. "with the specified hyperparameters, filling in missing hyperparameters "
  59. "from the default_values in |hyper_params|.")
  60. FLAGS = tf.flags.FLAGS
  61. class HParams(object):
  62. """Dictionary of hyperparameters."""
  63. def __init__(self, **kwargs):
  64. self.dict_ = kwargs
  65. self.__dict__.update(self.dict_)
  66. def update_config(self, in_string):
  67. """Update the dictionary with a comma separated list."""
  68. pairs = in_string.split(",")
  69. pairs = [pair.split("=") for pair in pairs]
  70. for key, val in pairs:
  71. self.dict_[key] = type(self.dict_[key])(val)
  72. self.__dict__.update(self.dict_)
  73. return self
  74. def __getitem__(self, key):
  75. return self.dict_[key]
  76. def __setitem__(self, key, val):
  77. self.dict_[key] = val
  78. self.__dict__.update(self.dict_)
  79. def get_default_hparams():
  80. """Get the default hyperparameters."""
  81. return HParams(
  82. batch_size=64,
  83. residual_blocks=2,
  84. n_couplings=2,
  85. n_scale=4,
  86. learning_rate=0.001,
  87. momentum=1e-1,
  88. decay=1e-3,
  89. l2_coeff=0.00005,
  90. clip_gradient=100.,
  91. optimizer="adam",
  92. dropout_mask=0,
  93. base_dim=32,
  94. bottleneck=0,
  95. use_batch_norm=1,
  96. alternate=1,
  97. use_aff=1,
  98. skip=1,
  99. data_constraint=.9,
  100. n_opt=0)
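# Illustrative usage of HParams with the --hpconfig flag (hypothetical values):
#   hps = get_default_hparams().update_config("n_scale=5,base_dim=8")
#   # hps.n_scale == 5, hps.base_dim == 8; unspecified fields keep their
#   # defaults, and each value is cast to the type of the default it replaces.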
  101. # RESNET UTILS
  102. def residual_block(input_, dim, name, use_batch_norm=True,
  103. train=True, weight_norm=True, bottleneck=False):
  104. """Residual convolutional block."""
  105. with tf.variable_scope(name):
  106. res = input_
  107. if use_batch_norm:
  108. res = batch_norm(
  109. input_=res, dim=dim, name="bn_in", scale=False,
  110. train=train, epsilon=1e-4, axes=[0, 1, 2])
  111. res = tf.nn.relu(res)
  112. if bottleneck:
  113. res = conv_layer(
  114. input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
  115. name="h_0", stddev=numpy.sqrt(2. / (dim)),
  116. strides=[1, 1, 1, 1], padding="SAME",
  117. nonlinearity=None, bias=(not use_batch_norm),
  118. weight_norm=weight_norm, scale=False)
  119. if use_batch_norm:
  120. res = batch_norm(
  121. input_=res, dim=dim,
  122. name="bn_0", scale=False, train=train,
  123. epsilon=1e-4, axes=[0, 1, 2])
  124. res = tf.nn.relu(res)
  125. res = conv_layer(
  126. input_=res, filter_size=[3, 3], dim_in=dim,
  127. dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
  128. strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
  129. bias=(not use_batch_norm),
  130. weight_norm=weight_norm, scale=False)
  131. if use_batch_norm:
  132. res = batch_norm(
  133. input_=res, dim=dim, name="bn_1", scale=False,
  134. train=train, epsilon=1e-4, axes=[0, 1, 2])
  135. res = tf.nn.relu(res)
  136. res = conv_layer(
  137. input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
  138. name="out", stddev=numpy.sqrt(2. / (1. * dim)),
  139. strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
  140. bias=True, weight_norm=weight_norm, scale=True)
  141. else:
  142. res = conv_layer(
  143. input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
  144. name="h_0", stddev=numpy.sqrt(2. / (dim)),
  145. strides=[1, 1, 1, 1], padding="SAME",
  146. nonlinearity=None, bias=(not use_batch_norm),
  147. weight_norm=weight_norm, scale=False)
  148. if use_batch_norm:
  149. res = batch_norm(
  150. input_=res, dim=dim, name="bn_0", scale=False,
  151. train=train, epsilon=1e-4, axes=[0, 1, 2])
  152. res = tf.nn.relu(res)
  153. res = conv_layer(
  154. input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
  155. name="out", stddev=numpy.sqrt(2. / (1. * dim)),
  156. strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
  157. bias=True, weight_norm=weight_norm, scale=True)
  158. res += input_
  159. return res
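# Note: this is a pre-activation residual block (batch norm -> ReLU -> conv);
# with bottleneck=True it stacks 1x1 -> 3x3 -> 1x1 convolutions, otherwise
# two 3x3 convolutions, and the input is added back to the output.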
  160. def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True,
  161. train=True, weight_norm=True, residual_blocks=5,
  162. bottleneck=False, skip=True):
  163. """Residual convolutional network."""
  164. with tf.variable_scope(name):
  165. res = input_
  166. if residual_blocks != 0:
  167. res = conv_layer(
  168. input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
  169. name="h_in", stddev=numpy.sqrt(2. / (dim_in)),
  170. strides=[1, 1, 1, 1], padding="SAME",
  171. nonlinearity=None, bias=True,
  172. weight_norm=weight_norm, scale=False)
  173. if skip:
  174. out = conv_layer(
  175. input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
  176. name="skip_in", stddev=numpy.sqrt(2. / (dim)),
  177. strides=[1, 1, 1, 1], padding="SAME",
  178. nonlinearity=None, bias=True,
  179. weight_norm=weight_norm, scale=True)
  180. # residual blocks
  181. for idx_block in xrange(residual_blocks):
  182. res = residual_block(res, dim, "block_%d" % idx_block,
  183. use_batch_norm=use_batch_norm, train=train,
  184. weight_norm=weight_norm,
  185. bottleneck=bottleneck)
  186. if skip:
  187. out += conv_layer(
  188. input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
  189. name="skip_%d" % idx_block, stddev=numpy.sqrt(2. / (dim)),
  190. strides=[1, 1, 1, 1], padding="SAME",
  191. nonlinearity=None, bias=True,
  192. weight_norm=weight_norm, scale=True)
  193. # outputs
  194. if skip:
  195. res = out
  196. if use_batch_norm:
  197. res = batch_norm(
  198. input_=res, dim=dim, name="bn_pre_out", scale=False,
  199. train=train, epsilon=1e-4, axes=[0, 1, 2])
  200. res = tf.nn.relu(res)
  201. res = conv_layer(
  202. input_=res, filter_size=[1, 1], dim_in=dim,
  203. dim_out=dim_out,
  204. name="out", stddev=numpy.sqrt(2. / (1. * dim)),
  205. strides=[1, 1, 1, 1], padding="SAME",
  206. nonlinearity=None, bias=True,
  207. weight_norm=weight_norm, scale=True)
  208. else:
  209. if bottleneck:
  210. res = conv_layer(
  211. input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim,
  212. name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
  213. strides=[1, 1, 1, 1], padding="SAME",
  214. nonlinearity=None, bias=(not use_batch_norm),
  215. weight_norm=weight_norm, scale=False)
  216. if use_batch_norm:
  217. res = batch_norm(
  218. input_=res, dim=dim, name="bn_0", scale=False,
  219. train=train, epsilon=1e-4, axes=[0, 1, 2])
  220. res = tf.nn.relu(res)
  221. res = conv_layer(
  222. input_=res, filter_size=[3, 3], dim_in=dim,
  223. dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
  224. strides=[1, 1, 1, 1], padding="SAME",
  225. nonlinearity=None,
  226. bias=(not use_batch_norm),
  227. weight_norm=weight_norm, scale=False)
  228. if use_batch_norm:
  229. res = batch_norm(
  230. input_=res, dim=dim, name="bn_1", scale=False,
  231. train=train, epsilon=1e-4, axes=[0, 1, 2])
  232. res = tf.nn.relu(res)
  233. res = conv_layer(
  234. input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim_out,
  235. name="out", stddev=numpy.sqrt(2. / (1. * dim)),
  236. strides=[1, 1, 1, 1], padding="SAME",
  237. nonlinearity=None, bias=True,
  238. weight_norm=weight_norm, scale=True)
  239. else:
  240. res = conv_layer(
  241. input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
  242. name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
  243. strides=[1, 1, 1, 1], padding="SAME",
  244. nonlinearity=None, bias=(not use_batch_norm),
  245. weight_norm=weight_norm, scale=False)
  246. if use_batch_norm:
  247. res = batch_norm(
  248. input_=res, dim=dim, name="bn_0", scale=False,
  249. train=train, epsilon=1e-4, axes=[0, 1, 2])
  250. res = tf.nn.relu(res)
  251. res = conv_layer(
  252. input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim_out,
  253. name="out", stddev=numpy.sqrt(2. / (1. * dim)),
  254. strides=[1, 1, 1, 1], padding="SAME",
  255. nonlinearity=None, bias=True,
  256. weight_norm=weight_norm, scale=True)
  257. return res
  258. # COUPLING LAYERS
  259. # masked convolution implementations
  260. def masked_conv_aff_coupling(input_, mask_in, dim, name,
  261. use_batch_norm=True, train=True, weight_norm=True,
  262. reverse=False, residual_blocks=5,
  263. bottleneck=False, use_width=1., use_height=1.,
  264. mask_channel=0., skip=True):
  265. """Affine coupling with masked convolution."""
  266. with tf.variable_scope(name) as scope:
  267. if reverse or (not train):
  268. scope.reuse_variables()
  269. shape = input_.get_shape().as_list()
  270. batch_size = shape[0]
  271. height = shape[1]
  272. width = shape[2]
  273. channels = shape[3]
  274. # build mask
  275. mask = use_width * numpy.arange(width)
  276. mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
  277. mask = mask.astype("float32")
  278. mask = tf.mod(mask_in + mask, 2)
  279. mask = tf.reshape(mask, [-1, height, width, 1])
  280. if mask.get_shape().as_list()[0] == 1:
  281. mask = tf.tile(mask, [batch_size, 1, 1, 1])
  282. res = input_ * tf.mod(mask_channel + mask, 2)
  283. # initial input
  284. if use_batch_norm:
  285. res = batch_norm(
  286. input_=res, dim=channels, name="bn_in", scale=False,
  287. train=train, epsilon=1e-4, axes=[0, 1, 2])
  288. res *= 2.
  289. res = tf.concat_v2([res, -res], 3)
  290. res = tf.concat_v2([res, mask], 3)
  291. dim_in = 2. * channels + 1
  292. res = tf.nn.relu(res)
  293. res = resnet(input_=res, dim_in=dim_in, dim=dim,
  294. dim_out=2 * channels,
  295. name="resnet", use_batch_norm=use_batch_norm,
  296. train=train, weight_norm=weight_norm,
  297. residual_blocks=residual_blocks,
  298. bottleneck=bottleneck, skip=skip)
  299. mask = tf.mod(mask_channel + mask, 2)
300. res = tf.split(axis=3, num_or_size_splits=2, value=res)
  301. shift, log_rescaling = res[-2], res[-1]
  302. scale = variable_on_cpu(
  303. "rescaling_scale", [],
  304. tf.constant_initializer(0.))
  305. shift = tf.reshape(
  306. shift, [batch_size, height, width, channels])
  307. log_rescaling = tf.reshape(
  308. log_rescaling, [batch_size, height, width, channels])
  309. log_rescaling = scale * tf.tanh(log_rescaling)
  310. if not use_batch_norm:
  311. scale_shift = variable_on_cpu(
  312. "scale_shift", [],
  313. tf.constant_initializer(0.))
  314. log_rescaling += scale_shift
  315. shift *= (1. - mask)
  316. log_rescaling *= (1. - mask)
  317. if reverse:
  318. res = input_
  319. if use_batch_norm:
  320. mean, var = batch_norm_log_diff(
  321. input_=res * (1. - mask), dim=channels, name="bn_out",
  322. train=False, epsilon=1e-4, axes=[0, 1, 2])
  323. log_var = tf.log(var)
  324. res *= tf.exp(.5 * log_var * (1. - mask))
  325. res += mean * (1. - mask)
  326. res *= tf.exp(-log_rescaling)
  327. res -= shift
  328. log_diff = -log_rescaling
  329. if use_batch_norm:
  330. log_diff += .5 * log_var * (1. - mask)
  331. else:
  332. res = input_
  333. res += shift
  334. res *= tf.exp(log_rescaling)
  335. log_diff = log_rescaling
  336. if use_batch_norm:
  337. mean, var = batch_norm_log_diff(
  338. input_=res * (1. - mask), dim=channels, name="bn_out",
  339. train=train, epsilon=1e-4, axes=[0, 1, 2])
  340. log_var = tf.log(var)
  341. res -= mean * (1. - mask)
  342. res *= tf.exp(-.5 * log_var * (1. - mask))
  343. log_diff -= .5 * log_var * (1. - mask)
  344. return res, log_diff
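# In equations, with b the binary mask and s = log_rescaling, t = shift
# (both computed from the masked input b * x only):
#   forward:  y = b * x + (1 - b) * ((x + t) * exp(s))
#   inverse:  x = b * y + (1 - b) * (y * exp(-s) - t)
# log_diff accumulates the log-Jacobian, i.e. the sum of (1 - b) * s,
# plus the batch-norm correction terms when use_batch_norm is set.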
  345. def masked_conv_add_coupling(input_, mask_in, dim, name,
  346. use_batch_norm=True, train=True, weight_norm=True,
  347. reverse=False, residual_blocks=5,
  348. bottleneck=False, use_width=1., use_height=1.,
  349. mask_channel=0., skip=True):
  350. """Additive coupling with masked convolution."""
  351. with tf.variable_scope(name) as scope:
  352. if reverse or (not train):
  353. scope.reuse_variables()
  354. shape = input_.get_shape().as_list()
  355. batch_size = shape[0]
  356. height = shape[1]
  357. width = shape[2]
  358. channels = shape[3]
  359. # build mask
  360. mask = use_width * numpy.arange(width)
  361. mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
  362. mask = mask.astype("float32")
  363. mask = tf.mod(mask_in + mask, 2)
  364. mask = tf.reshape(mask, [-1, height, width, 1])
  365. if mask.get_shape().as_list()[0] == 1:
  366. mask = tf.tile(mask, [batch_size, 1, 1, 1])
  367. res = input_ * tf.mod(mask_channel + mask, 2)
  368. # initial input
  369. if use_batch_norm:
  370. res = batch_norm(
  371. input_=res, dim=channels, name="bn_in", scale=False,
  372. train=train, epsilon=1e-4, axes=[0, 1, 2])
  373. res *= 2.
  374. res = tf.concat_v2([res, -res], 3)
  375. res = tf.concat_v2([res, mask], 3)
  376. dim_in = 2. * channels + 1
  377. res = tf.nn.relu(res)
  378. shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
  379. name="resnet", use_batch_norm=use_batch_norm,
  380. train=train, weight_norm=weight_norm,
  381. residual_blocks=residual_blocks,
  382. bottleneck=bottleneck, skip=skip)
  383. mask = tf.mod(mask_channel + mask, 2)
  384. shift *= (1. - mask)
  385. # use_batch_norm = False
  386. if reverse:
  387. res = input_
  388. if use_batch_norm:
  389. mean, var = batch_norm_log_diff(
  390. input_=res * (1. - mask),
  391. dim=channels, name="bn_out", train=False, epsilon=1e-4)
  392. log_var = tf.log(var)
  393. res *= tf.exp(.5 * log_var * (1. - mask))
  394. res += mean * (1. - mask)
  395. res -= shift
  396. log_diff = tf.zeros_like(res)
  397. if use_batch_norm:
  398. log_diff += .5 * log_var * (1. - mask)
  399. else:
  400. res = input_
  401. res += shift
  402. log_diff = tf.zeros_like(res)
  403. if use_batch_norm:
  404. mean, var = batch_norm_log_diff(
  405. input_=res * (1. - mask), dim=channels,
  406. name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2])
  407. log_var = tf.log(var)
  408. res -= mean * (1. - mask)
  409. res *= tf.exp(-.5 * log_var * (1. - mask))
  410. log_diff -= .5 * log_var * (1. - mask)
  411. return res, log_diff
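# The additive variant only shifts the non-masked coordinates (y = x + t
# there), so its Jacobian determinant is 1 and log_diff stays zero apart
# from the batch-norm correction terms.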
  412. def masked_conv_coupling(input_, mask_in, dim, name,
  413. use_batch_norm=True, train=True, weight_norm=True,
  414. reverse=False, residual_blocks=5,
  415. bottleneck=False, use_aff=True,
  416. use_width=1., use_height=1.,
  417. mask_channel=0., skip=True):
  418. """Coupling with masked convolution."""
  419. if use_aff:
  420. return masked_conv_aff_coupling(
  421. input_=input_, mask_in=mask_in, dim=dim, name=name,
  422. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  423. reverse=reverse, residual_blocks=residual_blocks,
  424. bottleneck=bottleneck, use_width=use_width, use_height=use_height,
  425. mask_channel=mask_channel, skip=skip)
  426. else:
  427. return masked_conv_add_coupling(
  428. input_=input_, mask_in=mask_in, dim=dim, name=name,
  429. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  430. reverse=reverse, residual_blocks=residual_blocks,
  431. bottleneck=bottleneck, use_width=use_width, use_height=use_height,
  432. mask_channel=mask_channel, skip=skip)
  433. # channel-axis splitting implementations
  434. def conv_ch_aff_coupling(input_, dim, name,
  435. use_batch_norm=True, train=True, weight_norm=True,
  436. reverse=False, residual_blocks=5,
  437. bottleneck=False, change_bottom=True, skip=True):
  438. """Affine coupling with channel-wise splitting."""
  439. with tf.variable_scope(name) as scope:
  440. if reverse or (not train):
  441. scope.reuse_variables()
  442. if change_bottom:
443. input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
444. else:
445. canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
  446. shape = input_.get_shape().as_list()
  447. batch_size = shape[0]
  448. height = shape[1]
  449. width = shape[2]
  450. channels = shape[3]
  451. res = input_
  452. # initial input
  453. if use_batch_norm:
  454. res = batch_norm(
  455. input_=res, dim=channels, name="bn_in", scale=False,
  456. train=train, epsilon=1e-4, axes=[0, 1, 2])
  457. res = tf.concat_v2([res, -res], 3)
  458. dim_in = 2. * channels
  459. res = tf.nn.relu(res)
  460. res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels,
  461. name="resnet", use_batch_norm=use_batch_norm,
  462. train=train, weight_norm=weight_norm,
  463. residual_blocks=residual_blocks,
  464. bottleneck=bottleneck, skip=skip)
465. shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
  466. scale = variable_on_cpu(
  467. "scale", [],
  468. tf.constant_initializer(1.))
  469. shift = tf.reshape(
  470. shift, [batch_size, height, width, channels])
  471. log_rescaling = tf.reshape(
  472. log_rescaling, [batch_size, height, width, channels])
  473. log_rescaling = scale * tf.tanh(log_rescaling)
  474. if not use_batch_norm:
  475. scale_shift = variable_on_cpu(
  476. "scale_shift", [],
  477. tf.constant_initializer(0.))
  478. log_rescaling += scale_shift
  479. if reverse:
  480. res = canvas
  481. if use_batch_norm:
  482. mean, var = batch_norm_log_diff(
  483. input_=res, dim=channels, name="bn_out", train=False,
  484. epsilon=1e-4, axes=[0, 1, 2])
  485. log_var = tf.log(var)
  486. res *= tf.exp(.5 * log_var)
  487. res += mean
  488. res *= tf.exp(-log_rescaling)
  489. res -= shift
  490. log_diff = -log_rescaling
  491. if use_batch_norm:
  492. log_diff += .5 * log_var
  493. else:
  494. res = canvas
  495. res += shift
  496. res *= tf.exp(log_rescaling)
  497. log_diff = log_rescaling
  498. if use_batch_norm:
  499. mean, var = batch_norm_log_diff(
  500. input_=res, dim=channels, name="bn_out", train=train,
  501. epsilon=1e-4, axes=[0, 1, 2])
  502. log_var = tf.log(var)
  503. res -= mean
  504. res *= tf.exp(-.5 * log_var)
  505. log_diff -= .5 * log_var
  506. if change_bottom:
  507. res = tf.concat_v2([input_, res], 3)
  508. log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
  509. else:
  510. res = tf.concat_v2([res, input_], 3)
  511. log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
  512. return res, log_diff
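# Same affine transformation as the masked version, but the unchanged and
# transformed halves are split along the channel axis (these couplings are
# used after a squeeze), with change_bottom selecting which half conditions
# the other.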
  513. def conv_ch_add_coupling(input_, dim, name,
  514. use_batch_norm=True, train=True, weight_norm=True,
  515. reverse=False, residual_blocks=5,
  516. bottleneck=False, change_bottom=True, skip=True):
  517. """Additive coupling with channel-wise splitting."""
  518. with tf.variable_scope(name) as scope:
  519. if reverse or (not train):
  520. scope.reuse_variables()
  521. if change_bottom:
522. input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
523. else:
524. canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
  525. shape = input_.get_shape().as_list()
  526. channels = shape[3]
  527. res = input_
  528. # initial input
  529. if use_batch_norm:
  530. res = batch_norm(
  531. input_=res, dim=channels, name="bn_in", scale=False,
  532. train=train, epsilon=1e-4, axes=[0, 1, 2])
  533. res = tf.concat_v2([res, -res], 3)
  534. dim_in = 2. * channels
  535. res = tf.nn.relu(res)
  536. shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
  537. name="resnet", use_batch_norm=use_batch_norm,
  538. train=train, weight_norm=weight_norm,
  539. residual_blocks=residual_blocks,
  540. bottleneck=bottleneck, skip=skip)
  541. if reverse:
  542. res = canvas
  543. if use_batch_norm:
  544. mean, var = batch_norm_log_diff(
  545. input_=res, dim=channels, name="bn_out", train=False,
  546. epsilon=1e-4, axes=[0, 1, 2])
  547. log_var = tf.log(var)
  548. res *= tf.exp(.5 * log_var)
  549. res += mean
  550. res -= shift
  551. log_diff = tf.zeros_like(res)
  552. if use_batch_norm:
  553. log_diff += .5 * log_var
  554. else:
  555. res = canvas
  556. res += shift
  557. log_diff = tf.zeros_like(res)
  558. if use_batch_norm:
  559. mean, var = batch_norm_log_diff(
  560. input_=res, dim=channels, name="bn_out", train=train,
  561. epsilon=1e-4, axes=[0, 1, 2])
  562. log_var = tf.log(var)
  563. res -= mean
  564. res *= tf.exp(-.5 * log_var)
  565. log_diff -= .5 * log_var
  566. if change_bottom:
  567. res = tf.concat_v2([input_, res], 3)
  568. log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
  569. else:
  570. res = tf.concat_v2([res, input_], 3)
  571. log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
  572. return res, log_diff
  573. def conv_ch_coupling(input_, dim, name,
  574. use_batch_norm=True, train=True, weight_norm=True,
  575. reverse=False, residual_blocks=5,
  576. bottleneck=False, use_aff=True, change_bottom=True,
  577. skip=True):
  578. """Coupling with channel-wise splitting."""
  579. if use_aff:
  580. return conv_ch_aff_coupling(
  581. input_=input_, dim=dim, name=name,
  582. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  583. reverse=reverse, residual_blocks=residual_blocks,
  584. bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
  585. else:
  586. return conv_ch_add_coupling(
  587. input_=input_, dim=dim, name=name,
  588. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  589. reverse=reverse, residual_blocks=residual_blocks,
  590. bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
  591. # RECURSIVE USE OF COUPLING LAYERS
  592. def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
  593. use_batch_norm=True, weight_norm=True,
  594. train=True):
  595. """Recursion on coupling layers."""
  596. shape = input_.get_shape().as_list()
  597. channels = shape[3]
  598. residual_blocks = hps.residual_blocks
  599. base_dim = hps.base_dim
  600. mask = 1.
  601. use_aff = hps.use_aff
  602. res = input_
  603. skip = hps.skip
  604. log_diff = tf.zeros_like(input_)
  605. dim = base_dim
  606. if FLAGS.recursion_type < 4:
  607. dim *= 2 ** scale_idx
  608. with tf.variable_scope("scale_%d" % scale_idx):
  609. # initial coupling layers
  610. res, inc_log_diff = masked_conv_coupling(
  611. input_=res,
  612. mask_in=mask, dim=dim,
  613. name="coupling_0",
  614. use_batch_norm=use_batch_norm, train=train,
  615. weight_norm=weight_norm,
  616. reverse=False, residual_blocks=residual_blocks,
  617. bottleneck=hps.bottleneck, use_aff=use_aff,
  618. use_width=1., use_height=1., skip=skip)
  619. log_diff += inc_log_diff
  620. res, inc_log_diff = masked_conv_coupling(
  621. input_=res,
  622. mask_in=1. - mask, dim=dim,
  623. name="coupling_1",
  624. use_batch_norm=use_batch_norm, train=train,
  625. weight_norm=weight_norm,
  626. reverse=False, residual_blocks=residual_blocks,
  627. bottleneck=hps.bottleneck, use_aff=use_aff,
  628. use_width=1., use_height=1., skip=skip)
  629. log_diff += inc_log_diff
  630. res, inc_log_diff = masked_conv_coupling(
  631. input_=res,
  632. mask_in=mask, dim=dim,
  633. name="coupling_2",
  634. use_batch_norm=use_batch_norm, train=train,
  635. weight_norm=weight_norm,
  636. reverse=False, residual_blocks=residual_blocks,
  637. bottleneck=hps.bottleneck, use_aff=True,
  638. use_width=1., use_height=1., skip=skip)
  639. log_diff += inc_log_diff
  640. if scale_idx < (n_scale - 1):
  641. with tf.variable_scope("scale_%d" % scale_idx):
  642. res = squeeze_2x2(res)
  643. log_diff = squeeze_2x2(log_diff)
  644. res, inc_log_diff = conv_ch_coupling(
  645. input_=res,
  646. change_bottom=True, dim=2 * dim,
  647. name="coupling_4",
  648. use_batch_norm=use_batch_norm, train=train,
  649. weight_norm=weight_norm,
  650. reverse=False, residual_blocks=residual_blocks,
  651. bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
  652. log_diff += inc_log_diff
  653. res, inc_log_diff = conv_ch_coupling(
  654. input_=res,
  655. change_bottom=False, dim=2 * dim,
  656. name="coupling_5",
  657. use_batch_norm=use_batch_norm, train=train,
  658. weight_norm=weight_norm,
  659. reverse=False, residual_blocks=residual_blocks,
  660. bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
  661. log_diff += inc_log_diff
  662. res, inc_log_diff = conv_ch_coupling(
  663. input_=res,
  664. change_bottom=True, dim=2 * dim,
  665. name="coupling_6",
  666. use_batch_norm=use_batch_norm, train=train,
  667. weight_norm=weight_norm,
  668. reverse=False, residual_blocks=residual_blocks,
  669. bottleneck=hps.bottleneck, use_aff=True, skip=skip)
  670. log_diff += inc_log_diff
  671. res = unsqueeze_2x2(res)
  672. log_diff = unsqueeze_2x2(log_diff)
  673. if FLAGS.recursion_type > 1:
  674. res = squeeze_2x2_ordered(res)
  675. log_diff = squeeze_2x2_ordered(log_diff)
  676. if FLAGS.recursion_type > 2:
  677. res_1 = res[:, :, :, :channels]
  678. res_2 = res[:, :, :, channels:]
  679. log_diff_1 = log_diff[:, :, :, :channels]
  680. log_diff_2 = log_diff[:, :, :, channels:]
  681. else:
682. res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
683. log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
  684. res_1, inc_log_diff = rec_masked_conv_coupling(
  685. input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
  686. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  687. train=train)
  688. res = tf.concat_v2([res_1, res_2], 3)
  689. log_diff_1 += inc_log_diff
  690. log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
  691. res = squeeze_2x2_ordered(res, reverse=True)
  692. log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
  693. else:
  694. res = squeeze_2x2_ordered(res)
  695. log_diff = squeeze_2x2_ordered(log_diff)
  696. res, inc_log_diff = rec_masked_conv_coupling(
  697. input_=res, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
  698. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  699. train=train)
  700. log_diff += inc_log_diff
  701. res = squeeze_2x2_ordered(res, reverse=True)
  702. log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
  703. else:
  704. with tf.variable_scope("scale_%d" % scale_idx):
  705. res, inc_log_diff = masked_conv_coupling(
  706. input_=res,
  707. mask_in=1. - mask, dim=dim,
  708. name="coupling_3",
  709. use_batch_norm=use_batch_norm, train=train,
  710. weight_norm=weight_norm,
  711. reverse=False, residual_blocks=residual_blocks,
  712. bottleneck=hps.bottleneck, use_aff=True,
  713. use_width=1., use_height=1., skip=skip)
  714. log_diff += inc_log_diff
  715. return res, log_diff
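# Per scale: three checkerboard-masked couplings, then (except at the last
# scale) a 2x2 squeeze, three channel-wise couplings and an unsqueeze; the
# representation is then either split in two, with one half factored out and
# the other half recursing into the next scale (FLAGS.recursion_type > 1), or
# passed on whole to the recursion (recursion_type <= 1). The last scale
# applies a fourth checkerboard coupling instead.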
  716. def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
  717. use_batch_norm=True, weight_norm=True,
  718. train=True):
  719. """Recursion on inverting coupling layers."""
  720. shape = input_.get_shape().as_list()
  721. channels = shape[3]
  722. residual_blocks = hps.residual_blocks
  723. base_dim = hps.base_dim
  724. mask = 1.
  725. use_aff = hps.use_aff
  726. res = input_
  727. log_diff = tf.zeros_like(input_)
  728. skip = hps.skip
  729. dim = base_dim
  730. if FLAGS.recursion_type < 4:
  731. dim *= 2 ** scale_idx
  732. if scale_idx < (n_scale - 1):
  733. if FLAGS.recursion_type > 1:
  734. res = squeeze_2x2_ordered(res)
  735. log_diff = squeeze_2x2_ordered(log_diff)
  736. if FLAGS.recursion_type > 2:
  737. res_1 = res[:, :, :, :channels]
  738. res_2 = res[:, :, :, channels:]
  739. log_diff_1 = log_diff[:, :, :, :channels]
  740. log_diff_2 = log_diff[:, :, :, channels:]
  741. else:
742. res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
743. log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
  744. res_1, log_diff_1 = rec_masked_deconv_coupling(
  745. input_=res_1, hps=hps,
  746. scale_idx=scale_idx + 1, n_scale=n_scale,
  747. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  748. train=train)
  749. res = tf.concat_v2([res_1, res_2], 3)
  750. log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
  751. res = squeeze_2x2_ordered(res, reverse=True)
  752. log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
  753. else:
  754. res = squeeze_2x2_ordered(res)
  755. log_diff = squeeze_2x2_ordered(log_diff)
  756. res, log_diff = rec_masked_deconv_coupling(
  757. input_=res, hps=hps,
  758. scale_idx=scale_idx + 1, n_scale=n_scale,
  759. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  760. train=train)
  761. res = squeeze_2x2_ordered(res, reverse=True)
  762. log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
  763. with tf.variable_scope("scale_%d" % scale_idx):
  764. res = squeeze_2x2(res)
  765. log_diff = squeeze_2x2(log_diff)
  766. res, inc_log_diff = conv_ch_coupling(
  767. input_=res,
  768. change_bottom=True, dim=2 * dim,
  769. name="coupling_6",
  770. use_batch_norm=use_batch_norm, train=train,
  771. weight_norm=weight_norm,
  772. reverse=True, residual_blocks=residual_blocks,
  773. bottleneck=hps.bottleneck, use_aff=True, skip=skip)
  774. log_diff += inc_log_diff
  775. res, inc_log_diff = conv_ch_coupling(
  776. input_=res,
  777. change_bottom=False, dim=2 * dim,
  778. name="coupling_5",
  779. use_batch_norm=use_batch_norm, train=train,
  780. weight_norm=weight_norm,
  781. reverse=True, residual_blocks=residual_blocks,
  782. bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
  783. log_diff += inc_log_diff
  784. res, inc_log_diff = conv_ch_coupling(
  785. input_=res,
  786. change_bottom=True, dim=2 * dim,
  787. name="coupling_4",
  788. use_batch_norm=use_batch_norm, train=train,
  789. weight_norm=weight_norm,
  790. reverse=True, residual_blocks=residual_blocks,
  791. bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
  792. log_diff += inc_log_diff
  793. res = unsqueeze_2x2(res)
  794. log_diff = unsqueeze_2x2(log_diff)
  795. else:
  796. with tf.variable_scope("scale_%d" % scale_idx):
  797. res, inc_log_diff = masked_conv_coupling(
  798. input_=res,
  799. mask_in=1. - mask, dim=dim,
  800. name="coupling_3",
  801. use_batch_norm=use_batch_norm, train=train,
  802. weight_norm=weight_norm,
  803. reverse=True, residual_blocks=residual_blocks,
  804. bottleneck=hps.bottleneck, use_aff=True,
  805. use_width=1., use_height=1., skip=skip)
  806. log_diff += inc_log_diff
  807. with tf.variable_scope("scale_%d" % scale_idx):
  808. res, inc_log_diff = masked_conv_coupling(
  809. input_=res,
  810. mask_in=mask, dim=dim,
  811. name="coupling_2",
  812. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  813. reverse=True, residual_blocks=residual_blocks,
  814. bottleneck=hps.bottleneck, use_aff=True,
  815. use_width=1., use_height=1., skip=skip)
  816. log_diff += inc_log_diff
  817. res, inc_log_diff = masked_conv_coupling(
  818. input_=res,
  819. mask_in=1. - mask, dim=dim,
  820. name="coupling_1",
  821. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  822. reverse=True, residual_blocks=residual_blocks,
  823. bottleneck=hps.bottleneck, use_aff=use_aff,
  824. use_width=1., use_height=1., skip=skip)
  825. log_diff += inc_log_diff
  826. res, inc_log_diff = masked_conv_coupling(
  827. input_=res,
  828. mask_in=mask, dim=dim,
  829. name="coupling_0",
  830. use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
  831. reverse=True, residual_blocks=residual_blocks,
  832. bottleneck=hps.bottleneck, use_aff=use_aff,
  833. use_width=1., use_height=1., skip=skip)
  834. log_diff += inc_log_diff
  835. return res, log_diff
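# Exact inverse of rec_masked_conv_coupling: the same couplings are applied
# in reverse order with reverse=True, reusing the variables created by the
# forward (encoding) pass.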
  836. # ENCODER AND DECODER IMPLEMENTATIONS
  837. # start the recursions
  838. def encoder(input_, hps, n_scale, use_batch_norm=True,
  839. weight_norm=True, train=True):
  840. """Encoding/gaussianization function."""
  841. res = input_
  842. log_diff = tf.zeros_like(input_)
  843. res, inc_log_diff = rec_masked_conv_coupling(
  844. input_=res, hps=hps, scale_idx=0, n_scale=n_scale,
  845. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  846. train=train)
  847. log_diff += inc_log_diff
  848. return res, log_diff
  849. def decoder(input_, hps, n_scale, use_batch_norm=True,
  850. weight_norm=True, train=True):
  851. """Decoding/generator function."""
  852. res, log_diff = rec_masked_deconv_coupling(
  853. input_=input_, hps=hps, scale_idx=0, n_scale=n_scale,
  854. use_batch_norm=use_batch_norm, weight_norm=weight_norm,
  855. train=train)
  856. return res, log_diff
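# Illustrative sketch (not part of the built graph): encoder and decoder are
# inverses, so for a logit-space batch `x` and hyperparameters `hps` one
# would expect, up to numerical error and assuming the variables already
# exist for reuse:
#   z, _ = encoder(x, hps, hps.n_scale, train=False)
#   x_rec, _ = decoder(z, hps, hps.n_scale, train=False)  # x_rec ~= x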
  857. class RealNVP(object):
  858. """Real NVP model."""
  859. def __init__(self, hps, sampling=False):
  860. # DATA TENSOR INSTANTIATION
  861. device = "/cpu:0"
  862. if FLAGS.dataset == "imnet":
  863. with tf.device(
  864. tf.train.replica_device_setter(0, worker_device=device)):
  865. filename_queue = tf.train.string_input_producer(
  866. gfile.Glob(FLAGS.data_path), num_epochs=None)
  867. reader = tf.TFRecordReader()
  868. _, serialized_example = reader.read(filename_queue)
  869. features = tf.parse_single_example(
  870. serialized_example,
  871. features={
  872. "image_raw": tf.FixedLenFeature([], tf.string),
  873. })
  874. image = tf.decode_raw(features["image_raw"], tf.uint8)
  875. image.set_shape([FLAGS.image_size * FLAGS.image_size * 3])
  876. image = tf.cast(image, tf.float32)
  877. if FLAGS.mode == "train":
  878. images = tf.train.shuffle_batch(
  879. [image], batch_size=hps.batch_size, num_threads=1,
  880. capacity=1000 + 3 * hps.batch_size,
  881. # Ensures a minimum amount of shuffling of examples.
  882. min_after_dequeue=1000)
  883. else:
  884. images = tf.train.batch(
  885. [image], batch_size=hps.batch_size, num_threads=1,
  886. capacity=1000 + 3 * hps.batch_size)
  887. self.x_orig = x_orig = images
  888. image_size = FLAGS.image_size
  889. x_in = tf.reshape(
  890. x_orig,
  891. [hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
  892. x_in = tf.clip_by_value(x_in, 0, 255)
  893. x_in = (tf.cast(x_in, tf.float32)
  894. + tf.random_uniform(tf.shape(x_in))) / 256.
  895. elif FLAGS.dataset == "celeba":
  896. with tf.device(
  897. tf.train.replica_device_setter(0, worker_device=device)):
  898. filename_queue = tf.train.string_input_producer(
  899. gfile.Glob(FLAGS.data_path), num_epochs=None)
  900. reader = tf.TFRecordReader()
  901. _, serialized_example = reader.read(filename_queue)
  902. features = tf.parse_single_example(
  903. serialized_example,
  904. features={
  905. "image_raw": tf.FixedLenFeature([], tf.string),
  906. })
  907. image = tf.decode_raw(features["image_raw"], tf.uint8)
  908. image.set_shape([218 * 178 * 3]) # 218, 178
  909. image = tf.cast(image, tf.float32)
  910. image = tf.reshape(image, [218, 178, 3])
  911. image = image[40:188, 15:163, :]
  912. if FLAGS.mode == "train":
  913. image = tf.image.random_flip_left_right(image)
  914. images = tf.train.shuffle_batch(
  915. [image], batch_size=hps.batch_size, num_threads=1,
  916. capacity=1000 + 3 * hps.batch_size,
  917. min_after_dequeue=1000)
  918. else:
  919. images = tf.train.batch(
  920. [image], batch_size=hps.batch_size, num_threads=1,
  921. capacity=1000 + 3 * hps.batch_size)
  922. self.x_orig = x_orig = images
  923. image_size = 64
  924. x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3])
  925. x_in = tf.image.resize_images(
  926. x_in, [64, 64], method=0, align_corners=False)
  927. x_in = (tf.cast(x_in, tf.float32)
  928. + tf.random_uniform(tf.shape(x_in))) / 256.
  929. elif FLAGS.dataset == "lsun":
  930. with tf.device(
  931. tf.train.replica_device_setter(0, worker_device=device)):
  932. filename_queue = tf.train.string_input_producer(
  933. gfile.Glob(FLAGS.data_path), num_epochs=None)
  934. reader = tf.TFRecordReader()
  935. _, serialized_example = reader.read(filename_queue)
  936. features = tf.parse_single_example(
  937. serialized_example,
  938. features={
  939. "image_raw": tf.FixedLenFeature([], tf.string),
  940. "height": tf.FixedLenFeature([], tf.int64),
  941. "width": tf.FixedLenFeature([], tf.int64),
  942. "depth": tf.FixedLenFeature([], tf.int64)
  943. })
  944. image = tf.decode_raw(features["image_raw"], tf.uint8)
945. height = tf.reshape(features["height"], [1])
946. height = tf.cast(height, tf.int32)
947. width = tf.reshape(features["width"], [1])
948. width = tf.cast(width, tf.int32)
949. depth = tf.reshape(features["depth"], [1])
950. depth = tf.cast(depth, tf.int32)
  951. image = tf.reshape(image, tf.concat_v2([height, width, depth], 0))
  952. image = tf.random_crop(image, [64, 64, 3])
  953. if FLAGS.mode == "train":
  954. image = tf.image.random_flip_left_right(image)
  955. images = tf.train.shuffle_batch(
  956. [image], batch_size=hps.batch_size, num_threads=1,
  957. capacity=1000 + 3 * hps.batch_size,
  958. # Ensures a minimum amount of shuffling of examples.
  959. min_after_dequeue=1000)
  960. else:
  961. images = tf.train.batch(
  962. [image], batch_size=hps.batch_size, num_threads=1,
  963. capacity=1000 + 3 * hps.batch_size)
  964. self.x_orig = x_orig = images
  965. image_size = 64
  966. x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3])
  967. x_in = (tf.cast(x_in, tf.float32)
  968. + tf.random_uniform(tf.shape(x_in))) / 256.
  969. else:
  970. raise ValueError("Unknown dataset.")
  971. x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3])
  972. side_shown = int(numpy.sqrt(hps.batch_size))
  973. shown_x = tf.transpose(
  974. tf.reshape(
  975. x_in[:(side_shown * side_shown), :, :, :],
  976. [side_shown, image_size * side_shown, image_size, 3]),
  977. [0, 2, 1, 3])
  978. shown_x = tf.transpose(
  979. tf.reshape(
  980. shown_x,
  981. [1, image_size * side_shown, image_size * side_shown, 3]),
  982. [0, 2, 1, 3]) * 255.
  983. tf.summary.image(
  984. "inputs",
  985. tf.cast(shown_x, tf.uint8),
  986. max_outputs=1)
  987. # restrict the data
  988. FLAGS.image_size = image_size
  989. data_constraint = hps.data_constraint
  990. pre_logit_scale = numpy.log(data_constraint)
  991. pre_logit_scale -= numpy.log(1. - data_constraint)
  992. pre_logit_scale = tf.cast(pre_logit_scale, tf.float32)
  993. logit_x_in = 2. * x_in # [0, 2]
  994. logit_x_in -= 1. # [-1, 1]
  995. logit_x_in *= data_constraint # [-.9, .9]
  996. logit_x_in += 1. # [.1, 1.9]
  997. logit_x_in /= 2. # [.05, .95]
  998. # logit the data
  999. logit_x_in = tf.log(logit_x_in) - tf.log(1. - logit_x_in)
  1000. transform_cost = tf.reduce_sum(
  1001. tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in)
  1002. - tf.nn.softplus(-pre_logit_scale),
  1003. [1, 2, 3])
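# transform_cost is the log-Jacobian of the preprocessing above: pixels are
# dequantized with uniform noise, mapped to p = data_constraint * (2x - 1) / 2
# + 1/2, then passed through the logit. Per dimension the log-derivative is
#   log(data_constraint) - log(p) - log(1 - p)
#     = softplus(logit_x_in) + softplus(-logit_x_in) - softplus(-pre_logit_scale),
# which is the summand above; it is added to log_diff below so the likelihood
# is measured on the dequantized inputs in [0, 1) rather than in logit space.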
  1004. # INFERENCE AND COSTS
  1005. z_out, log_diff = encoder(
  1006. input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
  1007. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1008. train=True)
  1009. if FLAGS.mode != "train":
  1010. z_out, log_diff = encoder(
  1011. input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
  1012. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1013. train=False)
  1014. final_shape = [image_size, image_size, 3]
  1015. prior_ll = standard_normal_ll(z_out)
  1016. prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3])
  1017. log_diff = tf.reduce_sum(log_diff, [1, 2, 3])
  1018. log_diff += transform_cost
  1019. cost = -(prior_ll + log_diff)
  1020. self.x_in = x_in
  1021. self.z_out = z_out
  1022. self.cost = cost = tf.reduce_mean(cost)
  1023. l2_reg = sum(
  1024. [tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables()
  1025. if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
  1026. bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.)
  1027. / (image_size * image_size * 3. * numpy.log(2.)))
  1028. self.bit_per_dim = bit_per_dim
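# bits/dim converts the mean negative log-likelihood (in nats, on inputs
# scaled to [0, 1)) into bits per pixel-channel of the 8-bit data:
#   bits_per_dim = (cost + D * log(256)) / (D * log(2)), with D = image_size^2 * 3;
# the D * log(256) term accounts for the x / 256 rescaling of the raw bytes.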
  1029. # OPTIMIZATION
  1030. momentum = 1. - hps.momentum
  1031. decay = 1. - hps.decay
  1032. if hps.optimizer == "adam":
  1033. optimizer = tf.train.AdamOptimizer(
  1034. learning_rate=hps.learning_rate,
  1035. beta1=momentum, beta2=decay, epsilon=1e-08,
  1036. use_locking=False, name="Adam")
  1037. elif hps.optimizer == "rmsprop":
  1038. optimizer = tf.train.RMSPropOptimizer(
  1039. learning_rate=hps.learning_rate, decay=decay,
  1040. momentum=momentum, epsilon=1e-04,
  1041. use_locking=False, name="RMSProp")
  1042. else:
  1043. optimizer = tf.train.MomentumOptimizer(hps.learning_rate,
  1044. momentum=momentum)
  1045. step = tf.get_variable(
  1046. "global_step", [], tf.int64,
  1047. tf.zeros_initializer(),
  1048. trainable=False)
  1049. self.step = step
  1050. grads_and_vars = optimizer.compute_gradients(
  1051. cost + hps.l2_coeff * l2_reg,
  1052. tf.trainable_variables())
  1053. grads, vars_ = zip(*grads_and_vars)
  1054. capped_grads, gradient_norm = tf.clip_by_global_norm(
  1055. grads, clip_norm=hps.clip_gradient)
  1056. gradient_norm = tf.check_numerics(gradient_norm,
  1057. "Gradient norm is NaN or Inf.")
  1058. l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3])
  1059. if not sampling:
  1060. tf.summary.scalar("negative_log_likelihood", tf.reshape(cost, []))
  1061. tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, []))
  1062. tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, []))
  1063. tf.summary.scalar("log_diff", tf.reshape(tf.reduce_mean(log_diff), []))
  1064. tf.summary.scalar("prior_ll", tf.reshape(tf.reduce_mean(prior_ll), []))
  1065. tf.summary.scalar(
  1066. "log_diff_var",
  1067. tf.reshape(tf.reduce_mean(tf.square(log_diff))
  1068. - tf.square(tf.reduce_mean(log_diff)), []))
  1069. tf.summary.scalar(
  1070. "prior_ll_var",
  1071. tf.reshape(tf.reduce_mean(tf.square(prior_ll))
  1072. - tf.square(tf.reduce_mean(prior_ll)), []))
  1073. tf.summary.scalar("l2_z_mean", tf.reshape(tf.reduce_mean(l2_z), []))
  1074. tf.summary.scalar(
  1075. "l2_z_var",
  1076. tf.reshape(tf.reduce_mean(tf.square(l2_z))
  1077. - tf.square(tf.reduce_mean(l2_z)), []))
  1078. capped_grads_and_vars = zip(capped_grads, vars_)
  1079. self.train_step = optimizer.apply_gradients(
  1080. capped_grads_and_vars, global_step=step)
  1081. # SAMPLING AND VISUALIZATION
  1082. if sampling:
  1083. # SAMPLES
  1084. sample = standard_normal_sample([100] + final_shape)
  1085. sample, _ = decoder(
  1086. input_=sample, hps=hps, n_scale=hps.n_scale,
  1087. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1088. train=True)
  1089. sample = tf.nn.sigmoid(sample)
  1090. sample = tf.clip_by_value(sample, 0, 1) * 255.
  1091. sample = tf.reshape(sample, [100, image_size, image_size, 3])
  1092. sample = tf.transpose(
  1093. tf.reshape(sample, [10, image_size * 10, image_size, 3]),
  1094. [0, 2, 1, 3])
  1095. sample = tf.transpose(
  1096. tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]),
  1097. [0, 2, 1, 3])
  1098. tf.summary.image(
  1099. "samples",
  1100. tf.cast(sample, tf.uint8),
  1101. max_outputs=1)
  1102. # CONCATENATION
  1103. concatenation, _ = encoder(
  1104. input_=logit_x_in, hps=hps,
  1105. n_scale=hps.n_scale,
  1106. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1107. train=False)
  1108. concatenation = tf.reshape(
  1109. concatenation,
  1110. [(side_shown * side_shown), image_size, image_size, 3])
  1111. concatenation = tf.transpose(
  1112. tf.reshape(
  1113. concatenation,
  1114. [side_shown, image_size * side_shown, image_size, 3]),
  1115. [0, 2, 1, 3])
  1116. concatenation = tf.transpose(
  1117. tf.reshape(
  1118. concatenation,
  1119. [1, image_size * side_shown, image_size * side_shown, 3]),
  1120. [0, 2, 1, 3])
  1121. concatenation, _ = decoder(
  1122. input_=concatenation, hps=hps, n_scale=hps.n_scale,
  1123. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1124. train=False)
  1125. concatenation = tf.nn.sigmoid(concatenation) * 255.
  1126. tf.summary.image(
  1127. "concatenation",
  1128. tf.cast(concatenation, tf.uint8),
  1129. max_outputs=1)
  1130. # MANIFOLD
  1131. # Data basis
  1132. z_u, _ = encoder(
  1133. input_=logit_x_in[:8, :, :, :], hps=hps,
  1134. n_scale=hps.n_scale,
  1135. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1136. train=False)
  1137. u_1 = tf.reshape(z_u[0, :, :, :], [-1])
  1138. u_2 = tf.reshape(z_u[1, :, :, :], [-1])
  1139. u_3 = tf.reshape(z_u[2, :, :, :], [-1])
  1140. u_4 = tf.reshape(z_u[3, :, :, :], [-1])
  1141. u_5 = tf.reshape(z_u[4, :, :, :], [-1])
  1142. u_6 = tf.reshape(z_u[5, :, :, :], [-1])
  1143. u_7 = tf.reshape(z_u[6, :, :, :], [-1])
  1144. u_8 = tf.reshape(z_u[7, :, :, :], [-1])
  1145. # 3D dome
  1146. manifold_side = 8
  1147. angle_1 = numpy.arange(manifold_side) * 1. / manifold_side
  1148. angle_2 = numpy.arange(manifold_side) * 1. / manifold_side
  1149. angle_1 *= 2. * numpy.pi
  1150. angle_2 *= 2. * numpy.pi
  1151. angle_1 = angle_1.astype("float32")
  1152. angle_2 = angle_2.astype("float32")
  1153. angle_1 = tf.reshape(angle_1, [1, -1, 1])
  1154. angle_1 += tf.zeros([manifold_side, manifold_side, 1])
  1155. angle_2 = tf.reshape(angle_2, [-1, 1, 1])
  1156. angle_2 += tf.zeros([manifold_side, manifold_side, 1])
  1157. n_angle_3 = 40
  1158. angle_3 = numpy.arange(n_angle_3) * 1. / n_angle_3
  1159. angle_3 *= 2 * numpy.pi
  1160. angle_3 = angle_3.astype("float32")
  1161. angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1])
  1162. angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1])
  1163. manifold = tf.cos(angle_1) * (
  1164. tf.cos(angle_2) * (
  1165. tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2)
  1166. + tf.sin(angle_2) * (
  1167. tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4))
  1168. manifold += tf.sin(angle_1) * (
  1169. tf.cos(angle_2) * (
  1170. tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6)
  1171. + tf.sin(angle_2) * (
  1172. tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8))
  1173. manifold = tf.reshape(
  1174. manifold,
  1175. [n_angle_3 * manifold_side * manifold_side] + final_shape)
  1176. manifold, _ = decoder(
  1177. input_=manifold, hps=hps, n_scale=hps.n_scale,
  1178. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1179. train=False)
  1180. manifold = tf.nn.sigmoid(manifold)
  1181. manifold = tf.clip_by_value(manifold, 0, 1) * 255.
  1182. manifold = tf.reshape(
  1183. manifold,
  1184. [n_angle_3,
  1185. manifold_side * manifold_side,
  1186. image_size,
  1187. image_size,
  1188. 3])
  1189. manifold = tf.transpose(
  1190. tf.reshape(
  1191. manifold,
  1192. [n_angle_3, manifold_side,
  1193. image_size * manifold_side, image_size, 3]), [0, 1, 3, 2, 4])
  1194. manifold = tf.transpose(
  1195. tf.reshape(
  1196. manifold,
  1197. [n_angle_3, image_size * manifold_side,
  1198. image_size * manifold_side, 3]),
  1199. [0, 2, 1, 3])
  1200. manifold = tf.transpose(manifold, [1, 2, 0, 3])
  1201. manifold = tf.reshape(
  1202. manifold,
  1203. [1, image_size * manifold_side,
  1204. image_size * manifold_side, 3 * n_angle_3])
  1205. tf.summary.image(
  1206. "manifold",
  1207. tf.cast(manifold[:, :, :, :3], tf.uint8),
  1208. max_outputs=1)
  1209. # COMPRESSION
  1210. z_complete, _ = encoder(
  1211. input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps,
  1212. n_scale=hps.n_scale,
  1213. use_batch_norm=hps.use_batch_norm, weight_norm=True,
  1214. train=False)
  1215. z_compressed_list = [z_complete]
  1216. z_noisy_list = [z_complete]
  1217. z_lost = z_complete
  1218. for scale_idx in xrange(hps.n_scale - 1):
  1219. z_lost = squeeze_2x2_ordered(z_lost)
1220. z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
  1221. z_compressed = z_lost
  1222. z_noisy = z_lost
  1223. for _ in xrange(scale_idx + 1):
  1224. z_compressed = tf.concat_v2(
  1225. [z_compressed, tf.zeros_like(z_compressed)], 3)
  1226. z_compressed = squeeze_2x2_ordered(
  1227. z_compressed, reverse=True)
  1228. z_noisy = tf.concat_v2(
  1229. [z_noisy, tf.random_normal(
  1230. z_noisy.get_shape().as_list())], 3)
  1231. z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True)
  1232. z_compressed_list.append(z_compressed)
  1233. z_noisy_list.append(z_noisy)
  1234. self.z_reduced = z_lost
  1235. z_compressed = tf.concat_v2(z_compressed_list, 0)
  1236. z_noisy = tf.concat_v2(z_noisy_list, 0)
      noisy_images, _ = decoder(
          input_=z_noisy, hps=hps, n_scale=hps.n_scale,
          use_batch_norm=hps.use_batch_norm, weight_norm=True,
          train=False)
      compressed_images, _ = decoder(
          input_=z_compressed, hps=hps, n_scale=hps.n_scale,
          use_batch_norm=hps.use_batch_norm, weight_norm=True,
          train=False)
      noisy_images = tf.nn.sigmoid(noisy_images)
      compressed_images = tf.nn.sigmoid(compressed_images)
      # Tile the n_scale * n_scale reconstructions into one image grid per
      # summary ("noise" and "compression").
      noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255.
      noisy_images = tf.reshape(
          noisy_images,
          [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
      noisy_images = tf.transpose(
          tf.reshape(
              noisy_images,
              [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
          [0, 2, 1, 3])
      noisy_images = tf.transpose(
          tf.reshape(
              noisy_images,
              [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
          [0, 2, 1, 3])
      tf.summary.image(
          "noise",
          tf.cast(noisy_images, tf.uint8),
          max_outputs=1)
      compressed_images = tf.clip_by_value(compressed_images, 0, 1) * 255.
      compressed_images = tf.reshape(
          compressed_images,
          [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
      compressed_images = tf.transpose(
          tf.reshape(
              compressed_images,
              [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
          [0, 2, 1, 3])
      compressed_images = tf.transpose(
          tf.reshape(
              compressed_images,
              [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
          [0, 2, 1, 3])
      tf.summary.image(
          "compression",
          tf.cast(compressed_images, tf.uint8),
          max_outputs=1)
      # SAMPLES x2
      final_shape[0] *= 2
      final_shape[1] *= 2
      big_sample = standard_normal_sample([25] + final_shape)
      big_sample, _ = decoder(
          input_=big_sample, hps=hps, n_scale=hps.n_scale,
          use_batch_norm=hps.use_batch_norm, weight_norm=True,
          train=True)
      big_sample = tf.nn.sigmoid(big_sample)
      big_sample = tf.clip_by_value(big_sample, 0, 1) * 255.
      big_sample = tf.reshape(
          big_sample,
          [25, image_size * 2, image_size * 2, 3])
      big_sample = tf.transpose(
          tf.reshape(
              big_sample,
              [5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3])
      big_sample = tf.transpose(
          tf.reshape(
              big_sample,
              [1, image_size * 10, image_size * 10, 3]),
          [0, 2, 1, 3])
      tf.summary.image(
          "big_sample",
          tf.cast(big_sample, tf.uint8),
          max_outputs=1)
      # SAMPLES x10
      # final_shape has already been doubled above, so multiplying by 5 here
      # yields a sample at ten times the training resolution.
      final_shape[0] *= 5
      final_shape[1] *= 5
      extra_large = standard_normal_sample([1] + final_shape)
      extra_large, _ = decoder(
          input_=extra_large, hps=hps, n_scale=hps.n_scale,
          use_batch_norm=hps.use_batch_norm, weight_norm=True,
          train=True)
      extra_large = tf.nn.sigmoid(extra_large)
      extra_large = tf.clip_by_value(extra_large, 0, 1) * 255.
      tf.summary.image(
          "extra_large",
          tf.cast(extra_large, tf.uint8),
          max_outputs=1)

  def eval_epoch(self, hps):
    """Evaluate bits/dim."""
    n_eval_dict = {
        "imnet": 50000,
        "lsun": 300,
        "celeba": 19962,
        "svhn": 26032,
    }
    if FLAGS.eval_set_size == 0:
      num_examples_eval = n_eval_dict[FLAGS.dataset]
    else:
      num_examples_eval = FLAGS.eval_set_size
    n_epoch = num_examples_eval / hps.batch_size
    eval_costs = []
    bar_len = 70
    for epoch_idx in xrange(n_epoch):
      # Draw a simple text progress bar while averaging bits/dim over batches.
      n_equal = epoch_idx * bar_len * 1. / n_epoch
      n_equal = int(numpy.ceil(n_equal))
      n_dash = bar_len - n_equal
      progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r"
      print progress_bar,
      cost = self.bit_per_dim.eval()
      eval_costs.append(cost)
    print ""
    return float(numpy.mean(eval_costs))
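

# [Illustrative sketch, not part of the original model.] The COMPRESSION block
# above halves the retained latent dimensions at every scale: each iteration
# applies a 2x2 squeeze (which preserves the total dimension count) and then
# keeps only half of the squeezed channels. A hypothetical helper mirroring
# that bookkeeping with plain integers (it never touches the TensorFlow graph;
# the default sizes are placeholders) could look like this:
def _compression_kept_fraction_sketch(image_size=32, channels=3, n_scale=4):
  """Fraction of latent dimensions kept after each compression scale."""
  total = image_size * image_size * channels
  dims = total
  kept = [1.0]
  for _ in xrange(n_scale - 1):
    # squeeze_2x2_ordered keeps the dimension count; dropping half of the
    # squeezed channels afterwards halves it.
    dims //= 2
    kept.append(float(dims) / total)
  return kept  # e.g. [1.0, 0.5, 0.25, 0.125] for n_scale=4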


def train_model(hps, logdir):
  """Training."""
  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(0)):
      with tf.variable_scope("model"):
        model = RealNVP(hps)
      saver = tf.train.Saver(tf.global_variables())

      # Build the summary operation from the last tower summaries.
      summary_op = tf.summary.merge_all()

      # Build an initialization operation to run below.
      init = tf.global_variables_initializer()

      # Start running operations on the Graph. allow_soft_placement must be
      # set to True to build towers on GPU, as some of the ops do not have
      # GPU implementations.
      sess = tf.Session(config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=True))
      sess.run(init)

      # Resume from the latest checkpoint in logdir, if any.
      ckpt_state = tf.train.get_checkpoint_state(logdir)
      if ckpt_state and ckpt_state.model_checkpoint_path:
        print "Loading file %s" % ckpt_state.model_checkpoint_path
        saver.restore(sess, ckpt_state.model_checkpoint_path)

      # Start the queue runners.
      tf.train.start_queue_runners(sess=sess)

      summary_writer = tf.summary.FileWriter(
          logdir,
          graph=sess.graph)

      local_step = 0
      while True:
        fetches = [model.step, model.bit_per_dim, model.train_step]
        # The chief worker evaluates the summaries every 100 steps.
        should_eval_summaries = local_step % 100 == 0
        if should_eval_summaries:
          fetches += [summary_op]
        start_time = time.time()
        outputs = sess.run(fetches)
        global_step_val = outputs[0]
        loss = outputs[1]
        duration = time.time() - start_time

        assert not numpy.isnan(loss), 'Model diverged with loss = NaN'

        if local_step % 10 == 0:
          examples_per_sec = hps.batch_size / float(duration)
          format_str = ('%s: step %d, loss = %.2f '
                        '(%.1f examples/sec; %.3f '
                        'sec/batch)')
          print format_str % (datetime.now(), global_step_val, loss,
                              examples_per_sec, duration)

        if should_eval_summaries:
          summary_str = outputs[-1]
          summary_writer.add_summary(summary_str, global_step_val)

        # Save the model checkpoint periodically.
        if local_step % 1000 == 0 or (local_step + 1) == FLAGS.train_steps:
          checkpoint_path = os.path.join(logdir, 'model.ckpt')
          saver.save(
              sess,
              checkpoint_path,
              global_step=global_step_val)

        if outputs[0] >= FLAGS.train_steps:
          break

        local_step += 1


def evaluate(hps, logdir, traindir, subset="valid", return_val=False):
  """Evaluation."""
  hps.batch_size = 100
  with tf.Graph().as_default():
    with tf.device("/cpu:0"):
      with tf.variable_scope("model") as var_scope:
        eval_model = RealNVP(hps)
        summary_writer = tf.summary.FileWriter(logdir)
        var_scope.reuse_variables()
      saver = tf.train.Saver()
      sess = tf.Session(config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=True))
      tf.train.start_queue_runners(sess)
      previous_global_step = 0  # don't run eval for step = 0

      with sess.as_default():
        while True:
          # Poll traindir and evaluate each new checkpoint exactly once.
          ckpt_state = tf.train.get_checkpoint_state(traindir)
          if not (ckpt_state and ckpt_state.model_checkpoint_path):
            print "No model to eval yet at %s" % traindir
            time.sleep(30)
            continue
          print "Loading file %s" % ckpt_state.model_checkpoint_path
          saver.restore(sess, ckpt_state.model_checkpoint_path)
          current_step = tf.train.global_step(sess, eval_model.step)
          if current_step == previous_global_step:
            print "Waiting for the checkpoint to be updated."
            time.sleep(30)
            continue
          previous_global_step = current_step
          print "Evaluating..."
          bit_per_dim = eval_model.eval_epoch(hps)
          print ("Epoch: %d, %s -> %.3f bits/dim"
                 % (current_step, subset, bit_per_dim))
          print "Writing summary..."
          summary = tf.Summary()
          summary.value.extend(
              [tf.Summary.Value(
                  tag="bit_per_dim",
                  simple_value=bit_per_dim)])
          summary_writer.add_summary(summary, current_step)
          if return_val:
            return current_step, bit_per_dim
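

# [Usage sketch, not in the original file.] train_model() and evaluate() are
# normally driven through main() and the command-line FLAGS; a direct,
# programmatic call would look roughly like the helper below. The directory
# paths and the empty hpconfig string are placeholders, and both functions
# still rely on the module-level FLAGS for dataset and step settings.
def _example_train_then_eval_sketch():
  """Hypothetical driver: train, then report validation bits/dim."""
  hps = get_default_hparams().update_config("")
  train_model(hps=hps, logdir="/tmp/real_nvp/train")
  return evaluate(hps=hps, logdir="/tmp/real_nvp/eval_valid",
                  traindir="/tmp/real_nvp/train", subset="valid",
                  return_val=True)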


def sample_from_model(hps, logdir, traindir):
  """Sampling."""
  hps.batch_size = 100
  with tf.Graph().as_default():
    with tf.device("/cpu:0"):
      with tf.variable_scope("model") as var_scope:
        eval_model = RealNVP(hps, sampling=True)
        summary_writer = tf.summary.FileWriter(logdir)
        var_scope.reuse_variables()
      summary_op = tf.summary.merge_all()
      saver = tf.train.Saver()
      sess = tf.Session(config=tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=True))
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      previous_global_step = 0  # don't run eval for step = 0
      initialized = False

      with sess.as_default():
        while True:
          ckpt_state = tf.train.get_checkpoint_state(traindir)
          if not (ckpt_state and ckpt_state.model_checkpoint_path):
            if not initialized:
              print "No model to eval yet at %s" % traindir
              time.sleep(30)
              continue
          else:
            print ("Loading file %s"
                   % ckpt_state.model_checkpoint_path)
            saver.restore(sess, ckpt_state.model_checkpoint_path)
          current_step = tf.train.global_step(sess, eval_model.step)
          if current_step == previous_global_step:
            print "Waiting for the checkpoint to be updated."
            time.sleep(30)
            continue
          previous_global_step = current_step
          # Re-run the sampling summaries for this checkpoint and write them.
          fetches = [summary_op]
          outputs = sess.run(fetches)
          summary_writer.add_summary(outputs[0], current_step)
      coord.request_stop()
      coord.join(threads)
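

# [Illustrative sketch, not part of the original file.] The image summaries
# built above ("manifold", "noise", "compression", "big_sample") all tile a
# batch of n * n images into one (n*H, n*W) grid through pairs of
# reshape/transpose ops. The same bookkeeping in plain NumPy, assuming simple
# row-major tiling (the channel ordering in the model's summaries may differ):
def _tile_image_grid_sketch(images, n):
  """images: numpy array [n * n, height, width, channels] -> one grid image."""
  _, height, width, channels = images.shape
  grid = images.reshape(n, n, height, width, channels)  # rows, cols, H, W, C
  grid = grid.transpose(0, 2, 1, 3, 4)                  # rows, H, cols, W, C
  return grid.reshape(n * height, n * width, channels)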


def main(unused_argv):
  hps = get_default_hparams().update_config(FLAGS.hpconfig)
  if FLAGS.mode == "train":
    train_model(hps=hps, logdir=FLAGS.logdir)
  elif FLAGS.mode == "sample":
    sample_from_model(hps=hps, logdir=FLAGS.logdir,
                      traindir=FLAGS.traindir)
  else:
    hps.batch_size = 100
    evaluate(hps=hps, logdir=FLAGS.logdir,
             traindir=FLAGS.traindir, subset=FLAGS.mode)


if __name__ == "__main__":
  tf.app.run()