- # Copyright 2016 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- r"""Script for training, evaluation and sampling for Real NVP.
- $ python real_nvp_multiscale_dataset.py \
- --alsologtostderr \
- --image_size 64 \
- --hpconfig=n_scale=5,base_dim=8 \
- --dataset imnet \
- --data_path [DATA_PATH]
- """
- import time
- from datetime import datetime
- import os
- import numpy
- import tensorflow as tf
- from tensorflow import gfile
- from real_nvp_utils import (
- batch_norm, batch_norm_log_diff, conv_layer,
- squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll,
- standard_normal_sample, unsqueeze_2x2, variable_on_cpu)
- tf.flags.DEFINE_string("master", "local",
- "BNS name of the TensorFlow master, or local.")
- tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale",
- "Directory to which writes logs.")
- tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale",
- "Directory to which writes logs.")
- tf.flags.DEFINE_integer("train_steps", 1000000000000000000,
- "Number of steps to train for.")
- tf.flags.DEFINE_string("data_path", "", "Path to the data.")
- tf.flags.DEFINE_string("mode", "train",
- "Mode of execution. Must be 'train', "
- "'sample' or 'eval'.")
- tf.flags.DEFINE_string("dataset", "imnet",
- "Dataset used. Must be 'imnet', "
- "'celeba' or 'lsun'.")
- tf.flags.DEFINE_integer("recursion_type", 2,
- "Type of the recursion.")
- tf.flags.DEFINE_integer("image_size", 64,
- "Size of the input image.")
- tf.flags.DEFINE_integer("eval_set_size", 0,
- "Size of evaluation dataset.")
- tf.flags.DEFINE_string(
- "hpconfig", "",
- "A comma separated list of hyperparameters for the model. Format is "
- "hp1=value1,hp2=value2,etc. If this FLAG is set, the model will be trained "
- "with the specified hyperparameters, filling in missing hyperparameters "
- "from the default_values in |hyper_params|.")
- FLAGS = tf.flags.FLAGS
- class HParams(object):
- """Dictionary of hyperparameters."""
- def __init__(self, **kwargs):
- self.dict_ = kwargs
- self.__dict__.update(self.dict_)
- def update_config(self, in_string):
- """Update the dictionary with a comma separated list."""
- pairs = in_string.split(",")
- pairs = [pair.split("=") for pair in pairs]
- for key, val in pairs:
- self.dict_[key] = type(self.dict_[key])(val)
- self.__dict__.update(self.dict_)
- return self
- def __getitem__(self, key):
- return self.dict_[key]
- def __setitem__(self, key, val):
- self.dict_[key] = val
- self.__dict__.update(self.dict_)
- def get_default_hparams():
- """Get the default hyperparameters."""
- return HParams(
- batch_size=64,
- residual_blocks=2,
- n_couplings=2,
- n_scale=4,
- learning_rate=0.001,
- momentum=1e-1,
- decay=1e-3,
- l2_coeff=0.00005,
- clip_gradient=100.,
- optimizer="adam",
- dropout_mask=0,
- base_dim=32,
- bottleneck=0,
- use_batch_norm=1,
- alternate=1,
- use_aff=1,
- skip=1,
- data_constraint=.9,
- n_opt=0)
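- # Example usage of the hyperparameter plumbing above (mirrors the --hpconfig
- # flag shown in the module docstring); each value is cast to the type of the
- # corresponding default, so "n_scale=5" becomes the integer 5:
- #   hps = get_default_hparams().update_config("n_scale=5,base_dim=8")
- #   assert hps.n_scale == 5 and hps.base_dim == 8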
- # RESNET UTILS
- def residual_block(input_, dim, name, use_batch_norm=True,
- train=True, weight_norm=True, bottleneck=False):
- """Residual convolutional block."""
- with tf.variable_scope(name):
- res = input_
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_in", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- if bottleneck:
- res = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
- name="h_0", stddev=numpy.sqrt(2. / (dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim,
- name="bn_0", scale=False, train=train,
- epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim,
- dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
- bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_1", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
- name="out", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
- bias=True, weight_norm=weight_norm, scale=True)
- else:
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
- name="h_0", stddev=numpy.sqrt(2. / (dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_0", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
- name="out", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
- bias=True, weight_norm=weight_norm, scale=True)
- res += input_
- return res
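- # The block above follows the pre-activation ordering (BN -> ReLU -> conv):
- # with bottleneck=True it uses a 1x1 -> 3x3 -> 1x1 stack, otherwise two 3x3
- # convolutions, and the result is added back onto the unmodified input.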
- def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True,
- train=True, weight_norm=True, residual_blocks=5,
- bottleneck=False, skip=True):
- """Residual convolutional network."""
- with tf.variable_scope(name):
- res = input_
- if residual_blocks != 0:
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
- name="h_in", stddev=numpy.sqrt(2. / (dim_in)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=False)
- if skip:
- out = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
- name="skip_in", stddev=numpy.sqrt(2. / (dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=True)
- # residual blocks
- for idx_block in xrange(residual_blocks):
- res = residual_block(res, dim, "block_%d" % idx_block,
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- bottleneck=bottleneck)
- if skip:
- out += conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
- name="skip_%d" % idx_block, stddev=numpy.sqrt(2. / (dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=True)
- # outputs
- if skip:
- res = out
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_pre_out", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim,
- dim_out=dim_out,
- name="out", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=True)
- else:
- if bottleneck:
- res = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim,
- name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_0", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim,
- dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None,
- bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_1", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim_out,
- name="out", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=True)
- else:
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
- name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=(not use_batch_norm),
- weight_norm=weight_norm, scale=False)
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=dim, name="bn_0", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.nn.relu(res)
- res = conv_layer(
- input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim_out,
- name="out", stddev=numpy.sqrt(2. / (1. * dim)),
- strides=[1, 1, 1, 1], padding="SAME",
- nonlinearity=None, bias=True,
- weight_norm=weight_norm, scale=True)
- return res
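- # resnet() stacks `residual_blocks` blocks at width `dim`; with skip=True the
- # output is the sum of 1x1 projections taken after the input layer and after
- # every block, followed by BN/ReLU and a final 1x1 to `dim_out`. With
- # residual_blocks=0 it degenerates to a plain two- or three-layer convnet.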
- # COUPLING LAYERS
- # masked convolution implementations
- def masked_conv_aff_coupling(input_, mask_in, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, use_width=1., use_height=1.,
- mask_channel=0., skip=True):
- """Affine coupling with masked convolution."""
- with tf.variable_scope(name) as scope:
- if reverse or (not train):
- scope.reuse_variables()
- shape = input_.get_shape().as_list()
- batch_size = shape[0]
- height = shape[1]
- width = shape[2]
- channels = shape[3]
- # build mask
- mask = use_width * numpy.arange(width)
- mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
- mask = mask.astype("float32")
- mask = tf.mod(mask_in + mask, 2)
- mask = tf.reshape(mask, [-1, height, width, 1])
- if mask.get_shape().as_list()[0] == 1:
- mask = tf.tile(mask, [batch_size, 1, 1, 1])
- res = input_ * tf.mod(mask_channel + mask, 2)
- # initial input
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=channels, name="bn_in", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res *= 2.
- res = tf.concat_v2([res, -res], 3)
- res = tf.concat_v2([res, mask], 3)
- dim_in = 2. * channels + 1
- res = tf.nn.relu(res)
- res = resnet(input_=res, dim_in=dim_in, dim=dim,
- dim_out=2 * channels,
- name="resnet", use_batch_norm=use_batch_norm,
- train=train, weight_norm=weight_norm,
- residual_blocks=residual_blocks,
- bottleneck=bottleneck, skip=skip)
- mask = tf.mod(mask_channel + mask, 2)
- res = tf.split(axis=3, num_or_size_splits=2, value=res)
- shift, log_rescaling = res[-2], res[-1]
- scale = variable_on_cpu(
- "rescaling_scale", [],
- tf.constant_initializer(0.))
- shift = tf.reshape(
- shift, [batch_size, height, width, channels])
- log_rescaling = tf.reshape(
- log_rescaling, [batch_size, height, width, channels])
- log_rescaling = scale * tf.tanh(log_rescaling)
- if not use_batch_norm:
- scale_shift = variable_on_cpu(
- "scale_shift", [],
- tf.constant_initializer(0.))
- log_rescaling += scale_shift
- shift *= (1. - mask)
- log_rescaling *= (1. - mask)
- if reverse:
- res = input_
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res * (1. - mask), dim=channels, name="bn_out",
- train=False, epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res *= tf.exp(.5 * log_var * (1. - mask))
- res += mean * (1. - mask)
- res *= tf.exp(-log_rescaling)
- res -= shift
- log_diff = -log_rescaling
- if use_batch_norm:
- log_diff += .5 * log_var * (1. - mask)
- else:
- res = input_
- res += shift
- res *= tf.exp(log_rescaling)
- log_diff = log_rescaling
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res * (1. - mask), dim=channels, name="bn_out",
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res -= mean * (1. - mask)
- res *= tf.exp(-.5 * log_var * (1. - mask))
- log_diff -= .5 * log_var * (1. - mask)
- return res, log_diff
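- # Ignoring the optional batch-norm whitening of the transformed half, the
- # forward (reverse=False) pass above is, for a binary mask b and network
- # outputs t (shift) and s (log_rescaling), both already masked by (1 - b):
- #   y = b * x + (1 - b) * (x + t) * exp(s)
- # and the inverse (reverse=True) pass is
- #   x = b * y + (1 - b) * (y * exp(-s) - t),
- # with log_diff holding the elementwise log-Jacobian (s, plus the
- # +/- .5 * log_var terms when use_batch_norm is set).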
- def masked_conv_add_coupling(input_, mask_in, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, use_width=1., use_height=1.,
- mask_channel=0., skip=True):
- """Additive coupling with masked convolution."""
- with tf.variable_scope(name) as scope:
- if reverse or (not train):
- scope.reuse_variables()
- shape = input_.get_shape().as_list()
- batch_size = shape[0]
- height = shape[1]
- width = shape[2]
- channels = shape[3]
- # build mask
- mask = use_width * numpy.arange(width)
- mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
- mask = mask.astype("float32")
- mask = tf.mod(mask_in + mask, 2)
- mask = tf.reshape(mask, [-1, height, width, 1])
- if mask.get_shape().as_list()[0] == 1:
- mask = tf.tile(mask, [batch_size, 1, 1, 1])
- res = input_ * tf.mod(mask_channel + mask, 2)
- # initial input
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=channels, name="bn_in", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res *= 2.
- res = tf.concat_v2([res, -res], 3)
- res = tf.concat_v2([res, mask], 3)
- dim_in = 2. * channels + 1
- res = tf.nn.relu(res)
- shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
- name="resnet", use_batch_norm=use_batch_norm,
- train=train, weight_norm=weight_norm,
- residual_blocks=residual_blocks,
- bottleneck=bottleneck, skip=skip)
- mask = tf.mod(mask_channel + mask, 2)
- shift *= (1. - mask)
- # use_batch_norm = False
- if reverse:
- res = input_
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res * (1. - mask),
- dim=channels, name="bn_out", train=False, epsilon=1e-4)
- log_var = tf.log(var)
- res *= tf.exp(.5 * log_var * (1. - mask))
- res += mean * (1. - mask)
- res -= shift
- log_diff = tf.zeros_like(res)
- if use_batch_norm:
- log_diff += .5 * log_var * (1. - mask)
- else:
- res = input_
- res += shift
- log_diff = tf.zeros_like(res)
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res * (1. - mask), dim=channels,
- name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res -= mean * (1. - mask)
- res *= tf.exp(-.5 * log_var * (1. - mask))
- log_diff -= .5 * log_var * (1. - mask)
- return res, log_diff
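- # The additive variant is the same construction with the scale fixed to one,
- # so the coupling itself is volume preserving and log_diff only carries the
- # batch-norm correction terms.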
- def masked_conv_coupling(input_, mask_in, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, use_aff=True,
- use_width=1., use_height=1.,
- mask_channel=0., skip=True):
- """Coupling with masked convolution."""
- if use_aff:
- return masked_conv_aff_coupling(
- input_=input_, mask_in=mask_in, dim=dim, name=name,
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=reverse, residual_blocks=residual_blocks,
- bottleneck=bottleneck, use_width=use_width, use_height=use_height,
- mask_channel=mask_channel, skip=skip)
- else:
- return masked_conv_add_coupling(
- input_=input_, mask_in=mask_in, dim=dim, name=name,
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=reverse, residual_blocks=residual_blocks,
- bottleneck=bottleneck, use_width=use_width, use_height=use_height,
- mask_channel=mask_channel, skip=skip)
- # channel-axis splitting implementations
- def conv_ch_aff_coupling(input_, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, change_bottom=True, skip=True):
- """Affine coupling with channel-wise splitting."""
- with tf.variable_scope(name) as scope:
- if reverse or (not train):
- scope.reuse_variables()
- if change_bottom:
- input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
- else:
- canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
- shape = input_.get_shape().as_list()
- batch_size = shape[0]
- height = shape[1]
- width = shape[2]
- channels = shape[3]
- res = input_
- # initial input
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=channels, name="bn_in", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.concat_v2([res, -res], 3)
- dim_in = 2. * channels
- res = tf.nn.relu(res)
- res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels,
- name="resnet", use_batch_norm=use_batch_norm,
- train=train, weight_norm=weight_norm,
- residual_blocks=residual_blocks,
- bottleneck=bottleneck, skip=skip)
- shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
- scale = variable_on_cpu(
- "scale", [],
- tf.constant_initializer(1.))
- shift = tf.reshape(
- shift, [batch_size, height, width, channels])
- log_rescaling = tf.reshape(
- log_rescaling, [batch_size, height, width, channels])
- log_rescaling = scale * tf.tanh(log_rescaling)
- if not use_batch_norm:
- scale_shift = variable_on_cpu(
- "scale_shift", [],
- tf.constant_initializer(0.))
- log_rescaling += scale_shift
- if reverse:
- res = canvas
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res, dim=channels, name="bn_out", train=False,
- epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res *= tf.exp(.5 * log_var)
- res += mean
- res *= tf.exp(-log_rescaling)
- res -= shift
- log_diff = -log_rescaling
- if use_batch_norm:
- log_diff += .5 * log_var
- else:
- res = canvas
- res += shift
- res *= tf.exp(log_rescaling)
- log_diff = log_rescaling
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res, dim=channels, name="bn_out", train=train,
- epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res -= mean
- res *= tf.exp(-.5 * log_var)
- log_diff -= .5 * log_var
- if change_bottom:
- res = tf.concat_v2([input_, res], 3)
- log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
- else:
- res = tf.concat_v2([res, input_], 3)
- log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
- return res, log_diff
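- # Channel-wise couplings split the input into two halves along the channel
- # axis; `change_bottom` selects which half conditions the other, and the
- # untouched half is concatenated back with a zero log-Jacobian.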
- def conv_ch_add_coupling(input_, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, change_bottom=True, skip=True):
- """Additive coupling with channel-wise splitting."""
- with tf.variable_scope(name) as scope:
- if reverse or (not train):
- scope.reuse_variables()
- if change_bottom:
- input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
- else:
- canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
- shape = input_.get_shape().as_list()
- channels = shape[3]
- res = input_
- # initial input
- if use_batch_norm:
- res = batch_norm(
- input_=res, dim=channels, name="bn_in", scale=False,
- train=train, epsilon=1e-4, axes=[0, 1, 2])
- res = tf.concat_v2([res, -res], 3)
- dim_in = 2. * channels
- res = tf.nn.relu(res)
- shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
- name="resnet", use_batch_norm=use_batch_norm,
- train=train, weight_norm=weight_norm,
- residual_blocks=residual_blocks,
- bottleneck=bottleneck, skip=skip)
- if reverse:
- res = canvas
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res, dim=channels, name="bn_out", train=False,
- epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res *= tf.exp(.5 * log_var)
- res += mean
- res -= shift
- log_diff = tf.zeros_like(res)
- if use_batch_norm:
- log_diff += .5 * log_var
- else:
- res = canvas
- res += shift
- log_diff = tf.zeros_like(res)
- if use_batch_norm:
- mean, var = batch_norm_log_diff(
- input_=res, dim=channels, name="bn_out", train=train,
- epsilon=1e-4, axes=[0, 1, 2])
- log_var = tf.log(var)
- res -= mean
- res *= tf.exp(-.5 * log_var)
- log_diff -= .5 * log_var
- if change_bottom:
- res = tf.concat_v2([input_, res], 3)
- log_diff = tf.concat_v2([tf.zeros_like(log_diff), log_diff], 3)
- else:
- res = tf.concat_v2([res, input_], 3)
- log_diff = tf.concat_v2([log_diff, tf.zeros_like(log_diff)], 3)
- return res, log_diff
- def conv_ch_coupling(input_, dim, name,
- use_batch_norm=True, train=True, weight_norm=True,
- reverse=False, residual_blocks=5,
- bottleneck=False, use_aff=True, change_bottom=True,
- skip=True):
- """Coupling with channel-wise splitting."""
- if use_aff:
- return conv_ch_aff_coupling(
- input_=input_, dim=dim, name=name,
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=reverse, residual_blocks=residual_blocks,
- bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
- else:
- return conv_ch_add_coupling(
- input_=input_, dim=dim, name=name,
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=reverse, residual_blocks=residual_blocks,
- bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
- # RECURSIVE USE OF COUPLING LAYERS
- def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
- use_batch_norm=True, weight_norm=True,
- train=True):
- """Recursion on coupling layers."""
- shape = input_.get_shape().as_list()
- channels = shape[3]
- residual_blocks = hps.residual_blocks
- base_dim = hps.base_dim
- mask = 1.
- use_aff = hps.use_aff
- res = input_
- skip = hps.skip
- log_diff = tf.zeros_like(input_)
- dim = base_dim
- if FLAGS.recursion_type < 4:
- dim *= 2 ** scale_idx
- with tf.variable_scope("scale_%d" % scale_idx):
- # initial coupling layers
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=mask, dim=dim,
- name="coupling_0",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=1. - mask, dim=dim,
- name="coupling_1",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=mask, dim=dim,
- name="coupling_2",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- if scale_idx < (n_scale - 1):
- with tf.variable_scope("scale_%d" % scale_idx):
- res = squeeze_2x2(res)
- log_diff = squeeze_2x2(log_diff)
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=True, dim=2 * dim,
- name="coupling_4",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=False, dim=2 * dim,
- name="coupling_5",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=True, dim=2 * dim,
- name="coupling_6",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True, skip=skip)
- log_diff += inc_log_diff
- res = unsqueeze_2x2(res)
- log_diff = unsqueeze_2x2(log_diff)
- if FLAGS.recursion_type > 1:
- res = squeeze_2x2_ordered(res)
- log_diff = squeeze_2x2_ordered(log_diff)
- if FLAGS.recursion_type > 2:
- res_1 = res[:, :, :, :channels]
- res_2 = res[:, :, :, channels:]
- log_diff_1 = log_diff[:, :, :, :channels]
- log_diff_2 = log_diff[:, :, :, channels:]
- else:
- res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
- log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
- res_1, inc_log_diff = rec_masked_conv_coupling(
- input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- res = tf.concat_v2([res_1, res_2], 3)
- log_diff_1 += inc_log_diff
- log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
- res = squeeze_2x2_ordered(res, reverse=True)
- log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
- else:
- res = squeeze_2x2_ordered(res)
- log_diff = squeeze_2x2_ordered(log_diff)
- res, inc_log_diff = rec_masked_conv_coupling(
- input_=res, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- log_diff += inc_log_diff
- res = squeeze_2x2_ordered(res, reverse=True)
- log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
- else:
- with tf.variable_scope("scale_%d" % scale_idx):
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=1. - mask, dim=dim,
- name="coupling_3",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=False, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- return res, log_diff
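- # Each scale applies three checkerboard-masked couplings and, when it is not
- # the last scale, a squeeze followed by three channel-wise couplings. For
- # FLAGS.recursion_type > 1 half of the resulting variables are then factored
- # out and the recursion continues on the other half; otherwise the recursion
- # is applied to all channels after an ordered squeeze. The last scale ends
- # with a single extra masked coupling ("coupling_3").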
- def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
- use_batch_norm=True, weight_norm=True,
- train=True):
- """Recursion on inverting coupling layers."""
- shape = input_.get_shape().as_list()
- channels = shape[3]
- residual_blocks = hps.residual_blocks
- base_dim = hps.base_dim
- mask = 1.
- use_aff = hps.use_aff
- res = input_
- log_diff = tf.zeros_like(input_)
- skip = hps.skip
- dim = base_dim
- if FLAGS.recursion_type < 4:
- dim *= 2 ** scale_idx
- if scale_idx < (n_scale - 1):
- if FLAGS.recursion_type > 1:
- res = squeeze_2x2_ordered(res)
- log_diff = squeeze_2x2_ordered(log_diff)
- if FLAGS.recursion_type > 2:
- res_1 = res[:, :, :, :channels]
- res_2 = res[:, :, :, channels:]
- log_diff_1 = log_diff[:, :, :, :channels]
- log_diff_2 = log_diff[:, :, :, channels:]
- else:
- res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
- log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
- res_1, log_diff_1 = rec_masked_deconv_coupling(
- input_=res_1, hps=hps,
- scale_idx=scale_idx + 1, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- res = tf.concat_v2([res_1, res_2], 3)
- log_diff = tf.concat_v2([log_diff_1, log_diff_2], 3)
- res = squeeze_2x2_ordered(res, reverse=True)
- log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
- else:
- res = squeeze_2x2_ordered(res)
- log_diff = squeeze_2x2_ordered(log_diff)
- res, log_diff = rec_masked_deconv_coupling(
- input_=res, hps=hps,
- scale_idx=scale_idx + 1, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- res = squeeze_2x2_ordered(res, reverse=True)
- log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
- with tf.variable_scope("scale_%d" % scale_idx):
- res = squeeze_2x2(res)
- log_diff = squeeze_2x2(log_diff)
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=True, dim=2 * dim,
- name="coupling_6",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True, skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=False, dim=2 * dim,
- name="coupling_5",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = conv_ch_coupling(
- input_=res,
- change_bottom=True, dim=2 * dim,
- name="coupling_4",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
- log_diff += inc_log_diff
- res = unsqueeze_2x2(res)
- log_diff = unsqueeze_2x2(log_diff)
- else:
- with tf.variable_scope("scale_%d" % scale_idx):
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=1. - mask, dim=dim,
- name="coupling_3",
- use_batch_norm=use_batch_norm, train=train,
- weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- with tf.variable_scope("scale_%d" % scale_idx):
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=mask, dim=dim,
- name="coupling_2",
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=True,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=1. - mask, dim=dim,
- name="coupling_1",
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- res, inc_log_diff = masked_conv_coupling(
- input_=res,
- mask_in=mask, dim=dim,
- name="coupling_0",
- use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm,
- reverse=True, residual_blocks=residual_blocks,
- bottleneck=hps.bottleneck, use_aff=use_aff,
- use_width=1., use_height=1., skip=skip)
- log_diff += inc_log_diff
- return res, log_diff
- # ENCODER AND DECODER IMPLEMENTATIONS
- # start the recursions
- def encoder(input_, hps, n_scale, use_batch_norm=True,
- weight_norm=True, train=True):
- """Encoding/gaussianization function."""
- res = input_
- log_diff = tf.zeros_like(input_)
- res, inc_log_diff = rec_masked_conv_coupling(
- input_=res, hps=hps, scale_idx=0, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- log_diff += inc_log_diff
- return res, log_diff
- def decoder(input_, hps, n_scale, use_batch_norm=True,
- weight_norm=True, train=True):
- """Decoding/generator function."""
- res, log_diff = rec_masked_deconv_coupling(
- input_=input_, hps=hps, scale_idx=0, n_scale=n_scale,
- use_batch_norm=use_batch_norm, weight_norm=weight_norm,
- train=train)
- return res, log_diff
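- # encoder() maps data x to a latent z = f(x) together with the elementwise
- # log-Jacobian log_diff, and decoder() is its inverse, so the model
- # log-likelihood used below is
- #   log p(x) = standard_normal_ll(f(x)) + sum(log_diff) + transform_cost
- # where transform_cost accounts for the logit preprocessing of the pixels.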
- class RealNVP(object):
- """Real NVP model."""
- def __init__(self, hps, sampling=False):
- # DATA TENSOR INSTANTIATION
- device = "/cpu:0"
- if FLAGS.dataset == "imnet":
- with tf.device(
- tf.train.replica_device_setter(0, worker_device=device)):
- filename_queue = tf.train.string_input_producer(
- gfile.Glob(FLAGS.data_path), num_epochs=None)
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(
- serialized_example,
- features={
- "image_raw": tf.FixedLenFeature([], tf.string),
- })
- image = tf.decode_raw(features["image_raw"], tf.uint8)
- image.set_shape([FLAGS.image_size * FLAGS.image_size * 3])
- image = tf.cast(image, tf.float32)
- if FLAGS.mode == "train":
- images = tf.train.shuffle_batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size,
- # Ensures a minimum amount of shuffling of examples.
- min_after_dequeue=1000)
- else:
- images = tf.train.batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size)
- self.x_orig = x_orig = images
- image_size = FLAGS.image_size
- x_in = tf.reshape(
- x_orig,
- [hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
- x_in = tf.clip_by_value(x_in, 0, 255)
- x_in = (tf.cast(x_in, tf.float32)
- + tf.random_uniform(tf.shape(x_in))) / 256.
- elif FLAGS.dataset == "celeba":
- with tf.device(
- tf.train.replica_device_setter(0, worker_device=device)):
- filename_queue = tf.train.string_input_producer(
- gfile.Glob(FLAGS.data_path), num_epochs=None)
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(
- serialized_example,
- features={
- "image_raw": tf.FixedLenFeature([], tf.string),
- })
- image = tf.decode_raw(features["image_raw"], tf.uint8)
- image.set_shape([218 * 178 * 3]) # 218, 178
- image = tf.cast(image, tf.float32)
- image = tf.reshape(image, [218, 178, 3])
- image = image[40:188, 15:163, :]
- if FLAGS.mode == "train":
- image = tf.image.random_flip_left_right(image)
- images = tf.train.shuffle_batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size,
- min_after_dequeue=1000)
- else:
- images = tf.train.batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size)
- self.x_orig = x_orig = images
- image_size = 64
- x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3])
- x_in = tf.image.resize_images(
- x_in, [64, 64], method=0, align_corners=False)
- x_in = (tf.cast(x_in, tf.float32)
- + tf.random_uniform(tf.shape(x_in))) / 256.
- elif FLAGS.dataset == "lsun":
- with tf.device(
- tf.train.replica_device_setter(0, worker_device=device)):
- filename_queue = tf.train.string_input_producer(
- gfile.Glob(FLAGS.data_path), num_epochs=None)
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(
- serialized_example,
- features={
- "image_raw": tf.FixedLenFeature([], tf.string),
- "height": tf.FixedLenFeature([], tf.int64),
- "width": tf.FixedLenFeature([], tf.int64),
- "depth": tf.FixedLenFeature([], tf.int64)
- })
- image = tf.decode_raw(features["image_raw"], tf.uint8)
- height = tf.reshape((features["height"], tf.int64)[0], [1])
- height = tf.cast(height, tf.int32)
- width = tf.reshape((features["width"], tf.int64)[0], [1])
- width = tf.cast(width, tf.int32)
- depth = tf.reshape((features["depth"], tf.int64)[0], [1])
- depth = tf.cast(depth, tf.int32)
- image = tf.reshape(image, tf.concat_v2([height, width, depth], 0))
- image = tf.random_crop(image, [64, 64, 3])
- if FLAGS.mode == "train":
- image = tf.image.random_flip_left_right(image)
- images = tf.train.shuffle_batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size,
- # Ensures a minimum amount of shuffling of examples.
- min_after_dequeue=1000)
- else:
- images = tf.train.batch(
- [image], batch_size=hps.batch_size, num_threads=1,
- capacity=1000 + 3 * hps.batch_size)
- self.x_orig = x_orig = images
- image_size = 64
- x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3])
- x_in = (tf.cast(x_in, tf.float32)
- + tf.random_uniform(tf.shape(x_in))) / 256.
- else:
- raise ValueError("Unknown dataset.")
- x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3])
- side_shown = int(numpy.sqrt(hps.batch_size))
- shown_x = tf.transpose(
- tf.reshape(
- x_in[:(side_shown * side_shown), :, :, :],
- [side_shown, image_size * side_shown, image_size, 3]),
- [0, 2, 1, 3])
- shown_x = tf.transpose(
- tf.reshape(
- shown_x,
- [1, image_size * side_shown, image_size * side_shown, 3]),
- [0, 2, 1, 3]) * 255.
- tf.summary.image(
- "inputs",
- tf.cast(shown_x, tf.uint8),
- max_outputs=1)
- # restrict the data
- FLAGS.image_size = image_size
- data_constraint = hps.data_constraint
- pre_logit_scale = numpy.log(data_constraint)
- pre_logit_scale -= numpy.log(1. - data_constraint)
- pre_logit_scale = tf.cast(pre_logit_scale, tf.float32)
- logit_x_in = 2. * x_in # [0, 2]
- logit_x_in -= 1. # [-1, 1]
- logit_x_in *= data_constraint # [-.9, .9]
- logit_x_in += 1. # [.1, 1.9]
- logit_x_in /= 2. # [.05, .95]
- # logit the data
- logit_x_in = tf.log(logit_x_in) - tf.log(1. - logit_x_in)
- transform_cost = tf.reduce_sum(
- tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in)
- - tf.nn.softplus(-pre_logit_scale),
- [1, 2, 3])
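- # The preprocessing above maps a pixel x in [0, 1) to y = logit(.05 + .9 * x)
- # so that the model never sees the logit's singularities at 0 and 1. Per
- # dimension, softplus(y) + softplus(-y) = -log(p * (1 - p)) with p = sigmoid(y)
- # and softplus(-pre_logit_scale) = -log(data_constraint), so transform_cost is
- # exactly the log-Jacobian log(dy/dx) of this map, added to log_diff below.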
- # INFERENCE AND COSTS
- z_out, log_diff = encoder(
- input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=True)
- if FLAGS.mode != "train":
- z_out, log_diff = encoder(
- input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- final_shape = [image_size, image_size, 3]
- prior_ll = standard_normal_ll(z_out)
- prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3])
- log_diff = tf.reduce_sum(log_diff, [1, 2, 3])
- log_diff += transform_cost
- cost = -(prior_ll + log_diff)
- self.x_in = x_in
- self.z_out = z_out
- self.cost = cost = tf.reduce_mean(cost)
- l2_reg = sum(
- [tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables()
- if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
- bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.)
- / (image_size * image_size * 3. * numpy.log(2.)))
- self.bit_per_dim = bit_per_dim
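- # bit_per_dim converts the negative log-likelihood `cost` (in nats, on the
- # continuous [0, 1) data) to bits per pixel channel: adding D * log(256)
- # rescales the density to the discrete 8-bit pixel values and dividing by
- # D * log(2), with D = image_size * image_size * 3, converts nats to bits.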
- # OPTIMIZATION
- momentum = 1. - hps.momentum
- decay = 1. - hps.decay
- if hps.optimizer == "adam":
- optimizer = tf.train.AdamOptimizer(
- learning_rate=hps.learning_rate,
- beta1=momentum, beta2=decay, epsilon=1e-08,
- use_locking=False, name="Adam")
- elif hps.optimizer == "rmsprop":
- optimizer = tf.train.RMSPropOptimizer(
- learning_rate=hps.learning_rate, decay=decay,
- momentum=momentum, epsilon=1e-04,
- use_locking=False, name="RMSProp")
- else:
- optimizer = tf.train.MomentumOptimizer(hps.learning_rate,
- momentum=momentum)
- step = tf.get_variable(
- "global_step", [], tf.int64,
- tf.zeros_initializer(),
- trainable=False)
- self.step = step
- grads_and_vars = optimizer.compute_gradients(
- cost + hps.l2_coeff * l2_reg,
- tf.trainable_variables())
- grads, vars_ = zip(*grads_and_vars)
- capped_grads, gradient_norm = tf.clip_by_global_norm(
- grads, clip_norm=hps.clip_gradient)
- gradient_norm = tf.check_numerics(gradient_norm,
- "Gradient norm is NaN or Inf.")
- l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3])
- if not sampling:
- tf.summary.scalar("negative_log_likelihood", tf.reshape(cost, []))
- tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, []))
- tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, []))
- tf.summary.scalar("log_diff", tf.reshape(tf.reduce_mean(log_diff), []))
- tf.summary.scalar("prior_ll", tf.reshape(tf.reduce_mean(prior_ll), []))
- tf.summary.scalar(
- "log_diff_var",
- tf.reshape(tf.reduce_mean(tf.square(log_diff))
- - tf.square(tf.reduce_mean(log_diff)), []))
- tf.summary.scalar(
- "prior_ll_var",
- tf.reshape(tf.reduce_mean(tf.square(prior_ll))
- - tf.square(tf.reduce_mean(prior_ll)), []))
- tf.summary.scalar("l2_z_mean", tf.reshape(tf.reduce_mean(l2_z), []))
- tf.summary.scalar(
- "l2_z_var",
- tf.reshape(tf.reduce_mean(tf.square(l2_z))
- - tf.square(tf.reduce_mean(l2_z)), []))
- capped_grads_and_vars = zip(capped_grads, vars_)
- self.train_step = optimizer.apply_gradients(
- capped_grads_and_vars, global_step=step)
- # SAMPLING AND VISUALIZATION
- if sampling:
- # SAMPLES
- sample = standard_normal_sample([100] + final_shape)
- sample, _ = decoder(
- input_=sample, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=True)
- sample = tf.nn.sigmoid(sample)
- sample = tf.clip_by_value(sample, 0, 1) * 255.
- sample = tf.reshape(sample, [100, image_size, image_size, 3])
- sample = tf.transpose(
- tf.reshape(sample, [10, image_size * 10, image_size, 3]),
- [0, 2, 1, 3])
- sample = tf.transpose(
- tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]),
- [0, 2, 1, 3])
- tf.summary.image(
- "samples",
- tf.cast(sample, tf.uint8),
- max_outputs=1)
- # CONCATENATION
- concatenation, _ = encoder(
- input_=logit_x_in, hps=hps,
- n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- concatenation = tf.reshape(
- concatenation,
- [(side_shown * side_shown), image_size, image_size, 3])
- concatenation = tf.transpose(
- tf.reshape(
- concatenation,
- [side_shown, image_size * side_shown, image_size, 3]),
- [0, 2, 1, 3])
- concatenation = tf.transpose(
- tf.reshape(
- concatenation,
- [1, image_size * side_shown, image_size * side_shown, 3]),
- [0, 2, 1, 3])
- concatenation, _ = decoder(
- input_=concatenation, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- concatenation = tf.nn.sigmoid(concatenation) * 255.
- tf.summary.image(
- "concatenation",
- tf.cast(concatenation, tf.uint8),
- max_outputs=1)
- # MANIFOLD
- # Data basis
- z_u, _ = encoder(
- input_=logit_x_in[:8, :, :, :], hps=hps,
- n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- u_1 = tf.reshape(z_u[0, :, :, :], [-1])
- u_2 = tf.reshape(z_u[1, :, :, :], [-1])
- u_3 = tf.reshape(z_u[2, :, :, :], [-1])
- u_4 = tf.reshape(z_u[3, :, :, :], [-1])
- u_5 = tf.reshape(z_u[4, :, :, :], [-1])
- u_6 = tf.reshape(z_u[5, :, :, :], [-1])
- u_7 = tf.reshape(z_u[6, :, :, :], [-1])
- u_8 = tf.reshape(z_u[7, :, :, :], [-1])
- # 3D dome
- manifold_side = 8
- angle_1 = numpy.arange(manifold_side) * 1. / manifold_side
- angle_2 = numpy.arange(manifold_side) * 1. / manifold_side
- angle_1 *= 2. * numpy.pi
- angle_2 *= 2. * numpy.pi
- angle_1 = angle_1.astype("float32")
- angle_2 = angle_2.astype("float32")
- angle_1 = tf.reshape(angle_1, [1, -1, 1])
- angle_1 += tf.zeros([manifold_side, manifold_side, 1])
- angle_2 = tf.reshape(angle_2, [-1, 1, 1])
- angle_2 += tf.zeros([manifold_side, manifold_side, 1])
- n_angle_3 = 40
- angle_3 = numpy.arange(n_angle_3) * 1. / n_angle_3
- angle_3 *= 2 * numpy.pi
- angle_3 = angle_3.astype("float32")
- angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1])
- angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1])
- manifold = tf.cos(angle_1) * (
- tf.cos(angle_2) * (
- tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2)
- + tf.sin(angle_2) * (
- tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4))
- manifold += tf.sin(angle_1) * (
- tf.cos(angle_2) * (
- tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6)
- + tf.sin(angle_2) * (
- tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8))
- manifold = tf.reshape(
- manifold,
- [n_angle_3 * manifold_side * manifold_side] + final_shape)
- manifold, _ = decoder(
- input_=manifold, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- manifold = tf.nn.sigmoid(manifold)
- manifold = tf.clip_by_value(manifold, 0, 1) * 255.
- manifold = tf.reshape(
- manifold,
- [n_angle_3,
- manifold_side * manifold_side,
- image_size,
- image_size,
- 3])
- manifold = tf.transpose(
- tf.reshape(
- manifold,
- [n_angle_3, manifold_side,
- image_size * manifold_side, image_size, 3]), [0, 1, 3, 2, 4])
- manifold = tf.transpose(
- tf.reshape(
- manifold,
- [n_angle_3, image_size * manifold_side,
- image_size * manifold_side, 3]),
- [0, 2, 1, 3])
- manifold = tf.transpose(manifold, [1, 2, 0, 3])
- manifold = tf.reshape(
- manifold,
- [1, image_size * manifold_side,
- image_size * manifold_side, 3 * n_angle_3])
- tf.summary.image(
- "manifold",
- tf.cast(manifold[:, :, :, :3], tf.uint8),
- max_outputs=1)
- # COMPRESSION
- z_complete, _ = encoder(
- input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps,
- n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- z_compressed_list = [z_complete]
- z_noisy_list = [z_complete]
- z_lost = z_complete
- for scale_idx in xrange(hps.n_scale - 1):
- z_lost = squeeze_2x2_ordered(z_lost)
- z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
- z_compressed = z_lost
- z_noisy = z_lost
- for _ in xrange(scale_idx + 1):
- z_compressed = tf.concat_v2(
- [z_compressed, tf.zeros_like(z_compressed)], 3)
- z_compressed = squeeze_2x2_ordered(
- z_compressed, reverse=True)
- z_noisy = tf.concat_v2(
- [z_noisy, tf.random_normal(
- z_noisy.get_shape().as_list())], 3)
- z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True)
- z_compressed_list.append(z_compressed)
- z_noisy_list.append(z_noisy)
- self.z_reduced = z_lost
- z_compressed = tf.concat_v2(z_compressed_list, 0)
- z_noisy = tf.concat_v2(z_noisy_list, 0)
- noisy_images, _ = decoder(
- input_=z_noisy, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- compressed_images, _ = decoder(
- input_=z_compressed, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=False)
- noisy_images = tf.nn.sigmoid(noisy_images)
- compressed_images = tf.nn.sigmoid(compressed_images)
- noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255.
- noisy_images = tf.reshape(
- noisy_images,
- [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
- noisy_images = tf.transpose(
- tf.reshape(
- noisy_images,
- [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
- [0, 2, 1, 3])
- noisy_images = tf.transpose(
- tf.reshape(
- noisy_images,
- [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
- [0, 2, 1, 3])
- tf.summary.image(
- "noise",
- tf.cast(noisy_images, tf.uint8),
- max_outputs=1)
- compressed_images = tf.clip_by_value(compressed_images, 0, 1) * 255.
- compressed_images = tf.reshape(
- compressed_images,
- [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
- compressed_images = tf.transpose(
- tf.reshape(
- compressed_images,
- [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
- [0, 2, 1, 3])
- compressed_images = tf.transpose(
- tf.reshape(
- compressed_images,
- [1, image_size * hps.n_scale, image_size * hps.n_scale, 3]),
- [0, 2, 1, 3])
- tf.summary.image(
- "compression",
- tf.cast(compressed_images, tf.uint8),
- max_outputs=1)
- # SAMPLES x2
- final_shape[0] *= 2
- final_shape[1] *= 2
- big_sample = standard_normal_sample([25] + final_shape)
- big_sample, _ = decoder(
- input_=big_sample, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=True)
- big_sample = tf.nn.sigmoid(big_sample)
- big_sample = tf.clip_by_value(big_sample, 0, 1) * 255.
- big_sample = tf.reshape(
- big_sample,
- [25, image_size * 2, image_size * 2, 3])
- big_sample = tf.transpose(
- tf.reshape(
- big_sample,
- [5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3])
- big_sample = tf.transpose(
- tf.reshape(
- big_sample,
- [1, image_size * 10, image_size * 10, 3]),
- [0, 2, 1, 3])
- tf.summary.image(
- "big_sample",
- tf.cast(big_sample, tf.uint8),
- max_outputs=1)
- # SAMPLES x10
- final_shape[0] *= 5
- final_shape[1] *= 5
- extra_large = standard_normal_sample([1] + final_shape)
- extra_large, _ = decoder(
- input_=extra_large, hps=hps, n_scale=hps.n_scale,
- use_batch_norm=hps.use_batch_norm, weight_norm=True,
- train=True)
- extra_large = tf.nn.sigmoid(extra_large)
- extra_large = tf.clip_by_value(extra_large, 0, 1) * 255.
- tf.summary.image(
- "extra_large",
- tf.cast(extra_large, tf.uint8),
- max_outputs=1)
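- # When sampling=True the graph above also builds, purely for TensorBoard:
- # "samples", "big_sample" and "extra_large" (decoded Gaussian noise at 1x, 2x
- # and 10x resolution), "concatenation" (inputs encoded, tiled into one grid in
- # latent space, and decoded), "manifold" (a latent surface spanned by eight
- # encoded examples and three angles), and "noise"/"compression" (decoding
- # latents whose factored-out scales are replaced by fresh noise or by zeros).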
- def eval_epoch(self, hps):
- """Evaluate bits/dim."""
- n_eval_dict = {
- "imnet": 50000,
- "lsun": 300,
- "celeba": 19962,
- "svhn": 26032,
- }
- if FLAGS.eval_set_size == 0:
- num_examples_eval = n_eval_dict[FLAGS.dataset]
- else:
- num_examples_eval = FLAGS.eval_set_size
- n_epoch = num_examples_eval / hps.batch_size
- eval_costs = []
- bar_len = 70
- for epoch_idx in xrange(n_epoch):
- n_equal = epoch_idx * bar_len * 1. / n_epoch
- n_equal = numpy.ceil(n_equal)
- n_equal = int(n_equal)
- n_dash = bar_len - n_equal
- progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r"
- print progress_bar,
- cost = self.bit_per_dim.eval()
- eval_costs.append(cost)
- print ""
- return float(numpy.mean(eval_costs))
- def train_model(hps, logdir):
- """Training."""
- with tf.Graph().as_default():
- with tf.device(tf.train.replica_device_setter(0)):
- with tf.variable_scope("model"):
- model = RealNVP(hps)
- saver = tf.train.Saver(tf.global_variables())
- # Build the summary operation from all of the summaries defined above.
- summary_op = tf.summary.merge_all()
- # Build an initialization operation to run below.
- init = tf.global_variables_initializer()
- # Start running operations on the Graph. allow_soft_placement must be set to
- # True because some of the ops do not have GPU implementations.
- sess = tf.Session(config=tf.ConfigProto(
- allow_soft_placement=True,
- log_device_placement=True))
- sess.run(init)
- ckpt_state = tf.train.get_checkpoint_state(logdir)
- if ckpt_state and ckpt_state.model_checkpoint_path:
- print "Loading file %s" % ckpt_state.model_checkpoint_path
- saver.restore(sess, ckpt_state.model_checkpoint_path)
- # Start the queue runners.
- tf.train.start_queue_runners(sess=sess)
- summary_writer = tf.summary.FileWriter(
- logdir,
- graph=sess.graph)
- local_step = 0
- while True:
- fetches = [model.step, model.bit_per_dim, model.train_step]
- # The chief worker evaluates the summaries every 100 steps.
- should_eval_summaries = local_step % 100 == 0
- if should_eval_summaries:
- fetches += [summary_op]
- start_time = time.time()
- outputs = sess.run(fetches)
- global_step_val = outputs[0]
- loss = outputs[1]
- duration = time.time() - start_time
- assert not numpy.isnan(
- loss), 'Model diverged with loss = NaN'
- if local_step % 10 == 0:
- examples_per_sec = hps.batch_size / float(duration)
- format_str = ('%s: step %d, loss = %.2f '
- '(%.1f examples/sec; %.3f '
- 'sec/batch)')
- print format_str % (datetime.now(), global_step_val, loss,
- examples_per_sec, duration)
- if should_eval_summaries:
- summary_str = outputs[-1]
- summary_writer.add_summary(summary_str, global_step_val)
- # Save the model checkpoint periodically.
- if local_step % 1000 == 0 or (local_step + 1) == FLAGS.train_steps:
- checkpoint_path = os.path.join(logdir, 'model.ckpt')
- saver.save(
- sess,
- checkpoint_path,
- global_step=global_step_val)
- if outputs[0] >= FLAGS.train_steps:
- break
- local_step += 1
- def evaluate(hps, logdir, traindir, subset="valid", return_val=False):
- """Evaluation."""
- hps.batch_size = 100
- with tf.Graph().as_default():
- with tf.device("/cpu:0"):
- with tf.variable_scope("model") as var_scope:
- eval_model = RealNVP(hps)
- summary_writer = tf.summary.FileWriter(logdir)
- var_scope.reuse_variables()
- saver = tf.train.Saver()
- sess = tf.Session(config=tf.ConfigProto(
- allow_soft_placement=True,
- log_device_placement=True))
- tf.train.start_queue_runners(sess)
- previous_global_step = 0 # don"t run eval for step = 0
- with sess.as_default():
- while True:
- ckpt_state = tf.train.get_checkpoint_state(traindir)
- if not (ckpt_state and ckpt_state.model_checkpoint_path):
- print "No model to eval yet at %s" % traindir
- time.sleep(30)
- continue
- print "Loading file %s" % ckpt_state.model_checkpoint_path
- saver.restore(sess, ckpt_state.model_checkpoint_path)
- current_step = tf.train.global_step(sess, eval_model.step)
- if current_step == previous_global_step:
- print "Waiting for the checkpoint to be updated."
- time.sleep(30)
- continue
- previous_global_step = current_step
- print "Evaluating..."
- bit_per_dim = eval_model.eval_epoch(hps)
- print ("Epoch: %d, %s -> %.3f bits/dim"
- % (current_step, subset, bit_per_dim))
- print "Writing summary..."
- summary = tf.Summary()
- summary.value.extend(
- [tf.Summary.Value(
- tag="bit_per_dim",
- simple_value=bit_per_dim)])
- summary_writer.add_summary(summary, current_step)
- if return_val:
- return current_step, bit_per_dim
- def sample_from_model(hps, logdir, traindir):
- """Sampling."""
- hps.batch_size = 100
- with tf.Graph().as_default():
- with tf.device("/cpu:0"):
- with tf.variable_scope("model") as var_scope:
- eval_model = RealNVP(hps, sampling=True)
- summary_writer = tf.summary.FileWriter(logdir)
- var_scope.reuse_variables()
- summary_op = tf.summary.merge_all()
- saver = tf.train.Saver()
- sess = tf.Session(config=tf.ConfigProto(
- allow_soft_placement=True,
- log_device_placement=True))
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- previous_global_step = 0 # don"t run eval for step = 0
- initialized = False
- with sess.as_default():
- while True:
- ckpt_state = tf.train.get_checkpoint_state(traindir)
- if not (ckpt_state and ckpt_state.model_checkpoint_path):
- if not initialized:
- print "No model to eval yet at %s" % traindir
- time.sleep(30)
- continue
- else:
- print ("Loading file %s"
- % ckpt_state.model_checkpoint_path)
- saver.restore(sess, ckpt_state.model_checkpoint_path)
- current_step = tf.train.global_step(sess, eval_model.step)
- if current_step == previous_global_step:
- print "Waiting for the checkpoint to be updated."
- time.sleep(30)
- continue
- previous_global_step = current_step
- fetches = [summary_op]
- outputs = sess.run(fetches)
- summary_writer.add_summary(outputs[0], current_step)
- coord.request_stop()
- coord.join(threads)
- def main(unused_argv):
- hps = get_default_hparams().update_config(FLAGS.hpconfig)
- if FLAGS.mode == "train":
- train_model(hps=hps, logdir=FLAGS.logdir)
- elif FLAGS.mode == "sample":
- sample_from_model(hps=hps, logdir=FLAGS.logdir,
- traindir=FLAGS.traindir)
- else:
- hps.batch_size = 100
- evaluate(hps=hps, logdir=FLAGS.logdir,
- traindir=FLAGS.traindir, subset=FLAGS.mode)
- if __name__ == "__main__":
- tf.app.run()