|
@@ -332,7 +332,7 @@ def masked_conv_aff_coupling(input_, mask_in, dim, name,
|
|
|
residual_blocks=residual_blocks,
|
|
|
bottleneck=bottleneck, skip=skip)
|
|
|
mask = tf.mod(mask_channel + mask, 2)
|
|
|
- res = tf.split(axis=res, num_or_size_splits=2, value=3)
|
|
|
+ res = tf.split(axis=3, num_or_size_splits=2, value=res)
|
|
|
shift, log_rescaling = res[-2], res[-1]
|
|
|
scale = variable_on_cpu(
|
|
|
"rescaling_scale", [],
|
|
@@ -486,9 +486,9 @@ def conv_ch_aff_coupling(input_, dim, name,
|
|
|
scope.reuse_variables()
|
|
|
|
|
|
if change_bottom:
|
|
|
- input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
|
|
|
+ input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
|
|
|
else:
|
|
|
- canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
|
|
|
+ canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
|
|
|
shape = input_.get_shape().as_list()
|
|
|
batch_size = shape[0]
|
|
|
height = shape[1]
|
|
@@ -509,7 +509,7 @@ def conv_ch_aff_coupling(input_, dim, name,
|
|
|
train=train, weight_norm=weight_norm,
|
|
|
residual_blocks=residual_blocks,
|
|
|
bottleneck=bottleneck, skip=skip)
|
|
|
- shift, log_rescaling = tf.split(axis=res, num_or_size_splits=2, value=3)
|
|
|
+ shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
|
|
|
scale = variable_on_cpu(
|
|
|
"scale", [],
|
|
|
tf.constant_initializer(1.))
|
|
@@ -570,9 +570,9 @@ def conv_ch_add_coupling(input_, dim, name,
|
|
|
scope.reuse_variables()
|
|
|
|
|
|
if change_bottom:
|
|
|
- input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
|
|
|
+ input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
|
|
|
else:
|
|
|
- canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
|
|
|
+ canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
|
|
|
shape = input_.get_shape().as_list()
|
|
|
channels = shape[3]
|
|
|
res = input_
|
|
@@ -736,8 +736,8 @@ def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
|
|
|
log_diff_1 = log_diff[:, :, :, :channels]
|
|
|
log_diff_2 = log_diff[:, :, :, channels:]
|
|
|
else:
|
|
|
- res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
|
|
|
- log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
|
|
|
+ res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
|
|
|
+ log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
|
|
|
res_1, inc_log_diff = rec_masked_conv_coupling(
|
|
|
input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
|
|
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
|
|
@@ -798,8 +798,8 @@ def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
|
|
|
log_diff_1 = log_diff[:, :, :, :channels]
|
|
|
log_diff_2 = log_diff[:, :, :, channels:]
|
|
|
else:
|
|
|
- res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
|
|
|
- log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
|
|
|
+ res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
|
|
|
+ log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
|
|
|
res_1, log_diff_1 = rec_masked_deconv_coupling(
|
|
|
input_=res_1, hps=hps,
|
|
|
scale_idx=scale_idx + 1, n_scale=n_scale,
|
|
@@ -1305,7 +1305,7 @@ class RealNVP(object):
|
|
|
z_lost = z_complete
|
|
|
for scale_idx in xrange(hps.n_scale - 1):
|
|
|
z_lost = squeeze_2x2_ordered(z_lost)
|
|
|
- z_lost, _ = tf.split(axis=z_lost, num_or_size_splits=2, value=3)
|
|
|
+ z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
|
|
|
z_compressed = z_lost
|
|
|
z_noisy = z_lost
|
|
|
for _ in xrange(scale_idx + 1):
|