Преглед на файлове

Manually fixed many occurrences of tf.split where the `axis` and `value` keyword arguments had been swapped

Neal Wu преди 8 години
родител
ревизия
5c53534305
променени са 2 файла, в които са добавени 15 реда и са изтрити 15 реда
  1. 4 4
      neural_gpu/neural_gpu.py
  2. 11 11
      real_nvp/real_nvp_multiscale_dataset.py

+ 4 - 4
neural_gpu/neural_gpu.py

@@ -211,7 +211,7 @@ def reorder_beam(beam_size, batch_size, beam_val, output, is_first,
   # beam_val is [batch_size x beam_size]; let b = batch_size * beam_size
   # decided is len x b x a x b
   # output is b x out_size; step is b x len x a x b;
-  outputs = tf.split(axis=tf.nn.log_softmax(output), num_or_size_splits=beam_size, value=0)
+  outputs = tf.split(axis=0, num_or_size_splits=beam_size, value=tf.nn.log_softmax(output))
   all_beam_vals, all_beam_idx = [], []
   beam_range = 1 if is_first else beam_size
   for i in xrange(beam_range):
@@ -266,9 +266,9 @@ class NeuralGPU(object):
     self.input = tf.placeholder(tf.int32, name="inp")
     self.target = tf.placeholder(tf.int32, name="tgt")
     self.prev_step = tf.placeholder(tf.float32, name="prev_step")
-    gpu_input = tf.split(axis=self.input, num_or_size_splits=num_gpus, value=0)
-    gpu_target = tf.split(axis=self.target, num_or_size_splits=num_gpus, value=0)
-    gpu_prev_step = tf.split(axis=self.prev_step, num_or_size_splits=num_gpus, value=0)
+    gpu_input = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.input)
+    gpu_target = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.target)
+    gpu_prev_step = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.prev_step)
     batch_size = tf.shape(gpu_input[0])[0]
 
     if backward:

+ 11 - 11
real_nvp/real_nvp_multiscale_dataset.py

@@ -332,7 +332,7 @@ def masked_conv_aff_coupling(input_, mask_in, dim, name,
                      residual_blocks=residual_blocks,
                      bottleneck=bottleneck, skip=skip)
         mask = tf.mod(mask_channel + mask, 2)
-        res = tf.split(axis=res, num_or_size_splits=2, value=3)
+        res = tf.split(axis=3, num_or_size_splits=2, value=res)
         shift, log_rescaling = res[-2], res[-1]
         scale = variable_on_cpu(
             "rescaling_scale", [],
@@ -486,9 +486,9 @@ def conv_ch_aff_coupling(input_, dim, name,
             scope.reuse_variables()
 
         if change_bottom:
-            input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
+            input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
         else:
-            canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
+            canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
         shape = input_.get_shape().as_list()
         batch_size = shape[0]
         height = shape[1]
@@ -509,7 +509,7 @@ def conv_ch_aff_coupling(input_, dim, name,
                      train=train, weight_norm=weight_norm,
                      residual_blocks=residual_blocks,
                      bottleneck=bottleneck, skip=skip)
-        shift, log_rescaling = tf.split(axis=res, num_or_size_splits=2, value=3)
+        shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
         scale = variable_on_cpu(
             "scale", [],
             tf.constant_initializer(1.))
@@ -570,9 +570,9 @@ def conv_ch_add_coupling(input_, dim, name,
             scope.reuse_variables()
 
         if change_bottom:
-            input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
+            input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
         else:
-            canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
+            canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
         shape = input_.get_shape().as_list()
         channels = shape[3]
         res = input_
@@ -736,8 +736,8 @@ def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
                 log_diff_1 = log_diff[:, :, :, :channels]
                 log_diff_2 = log_diff[:, :, :, channels:]
             else:
-                res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
-                log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
+                res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
+                log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
             res_1, inc_log_diff = rec_masked_conv_coupling(
                 input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
                 use_batch_norm=use_batch_norm, weight_norm=weight_norm,
@@ -798,8 +798,8 @@ def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
                 log_diff_1 = log_diff[:, :, :, :channels]
                 log_diff_2 = log_diff[:, :, :, channels:]
             else:
-                res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
-                log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
+                res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
+                log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
             res_1, log_diff_1 = rec_masked_deconv_coupling(
                 input_=res_1, hps=hps,
                 scale_idx=scale_idx + 1, n_scale=n_scale,
@@ -1305,7 +1305,7 @@ class RealNVP(object):
             z_lost = z_complete
             for scale_idx in xrange(hps.n_scale - 1):
                 z_lost = squeeze_2x2_ordered(z_lost)
-                z_lost, _ = tf.split(axis=z_lost, num_or_size_splits=2, value=3)
+                z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
                 z_compressed = z_lost
                 z_noisy = z_lost
                 for _ in xrange(scale_idx + 1):