Browse Source

Convert tf.op_scope to tf.name_scope, plus a few other 1.0 upgrade changes

Neal Wu 8 years ago
parent
commit
3f74c7b419

+ 6 - 3
differential_privacy/dp_sgd/dp_optimizer/utils.py

@@ -233,7 +233,8 @@ def BatchClipByL2norm(t, upper_bound, name=None):
   """
 
   assert upper_bound > 0
-  with tf.op_scope([t, upper_bound], name, "batch_clip_by_l2norm") as name:
+  with tf.name_scope(values=[t, upper_bound], name=name,
+                     default_name="batch_clip_by_l2norm") as name:
     saved_shape = tf.shape(t)
     batch_size = tf.slice(saved_shape, [0], [1])
     t2 = tf.reshape(t, tf.concat(axis=0, values=[batch_size, [-1]]))
@@ -264,7 +265,8 @@ def SoftThreshold(t, threshold_ratio, name=None):
   """
 
   assert threshold_ratio >= 0
-  with tf.op_scope([t, threshold_ratio], name, "soft_thresholding") as name:
+  with tf.name_scope(values=[t, threshold_ratio], name=name,
+                     default_name="soft_thresholding") as name:
     saved_shape = tf.shape(t)
     t2 = tf.reshape(t, tf.concat(axis=0, values=[tf.slice(saved_shape, [0], [1]), -1]))
     t_abs = tf.abs(t2)
@@ -286,7 +288,8 @@ def AddGaussianNoise(t, sigma, name=None):
     the noisy tensor.
   """
 
-  with tf.op_scope([t, sigma], name, "add_gaussian_noise") as name:
+  with tf.name_scope(values=[t, sigma], name=name,
+                     default_name="add_gaussian_noise") as name:
     noisy_t = t + tf.random_normal(tf.shape(t), stddev=sigma)
   return noisy_t
 

+ 11 - 8
inception/inception/image_processing.py

@@ -142,11 +142,12 @@ def decode_jpeg(image_buffer, scope=None):
 
   Args:
     image_buffer: scalar string Tensor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.
   Returns:
     3-D float Tensor with values ranging from [0, 1).
   """
-  with tf.op_scope([image_buffer], scope, 'decode_jpeg'):
+  with tf.name_scope(values=[image_buffer], name=scope,
+                     default_name='decode_jpeg'):
     # Decode the string as an RGB JPEG.
     # Note that the resulting image contains an unknown height and width
     # that is set dynamically by decode_jpeg. In other words, the height
@@ -171,11 +172,11 @@ def distort_color(image, thread_id=0, scope=None):
   Args:
     image: Tensor containing single image.
     thread_id: preprocessing thread ID.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.
   Returns:
     color-distorted image
   """
-  with tf.op_scope([image], scope, 'distort_color'):
+  with tf.name_scope(values=[image], name=scope, default_name='distort_color'):
     color_ordering = thread_id % 2
 
     if color_ordering == 0:
@@ -209,11 +210,12 @@ def distort_image(image, height, width, bbox, thread_id=0, scope=None):
       where each coordinate is [0, 1) and the coordinates are arranged
       as [ymin, xmin, ymax, xmax].
     thread_id: integer indicating the preprocessing thread.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.
   Returns:
     3-D float Tensor of distorted image used for training.
   """
-  with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
+  with tf.name_scope(values=[image, height, width, bbox], name=scope,
+                     default_name='distort_image'):
     # Each bounding box has shape [1, num_boxes, box coords] and
     # the coordinates are ordered [ymin, xmin, ymax, xmax].
 
@@ -281,11 +283,12 @@ def eval_image(image, height, width, scope=None):
     image: 3-D float Tensor
     height: integer
     width: integer
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.
   Returns:
     3-D float Tensor of prepared image.
   """
-  with tf.op_scope([image, height, width], scope, 'eval_image'):
+  with tf.name_scope(values=[image, height, width], name=scope,
+                     default_name='eval_image'):
     # Crop the central region of the image with an area containing 87.5% of
     # the original image.
     image = tf.image.central_crop(image, central_fraction=0.875)

+ 4 - 3
learning_to_remember_rare_events/memory.py

@@ -151,8 +151,9 @@ class Memory(object):
 
     if output_given and use_recent_idx:  # add at least one correct memory
       most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
-      hint_pool_idxs = tf.concat([hint_pool_idxs,
-                                  tf.expand_dims(most_recent_hint_idx, 1)], 1)
+      hint_pool_idxs = tf.concat(
+          axis=1,
+          values=[hint_pool_idxs, tf.expand_dims(most_recent_hint_idx, 1)])
     choose_k = tf.shape(hint_pool_idxs)[1]
 
     with tf.device(self.var_cache_device):
@@ -351,7 +352,7 @@ class LSHMemory(Memory):
             self.memory_size - 1), 0)
         for i, idxs in enumerate(hash_slot_idxs)]
 
-    return tf.concat(hint_pool_idxs, 1)
+    return tf.concat(axis=1, values=hint_pool_idxs)
 
   def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                      batch_size, use_recent_idx, intended_output):

+ 1 - 1
learning_to_remember_rare_events/train.py

@@ -168,7 +168,7 @@ class Trainer(object):
     self.model.setup()
 
     sess = tf.Session()
-    sess.run(tf.initialize_all_variables())
+    sess.run(tf.global_variables_initializer())
 
     saver = tf.train.Saver(max_to_keep=10)
     ckpt = None

+ 3 - 2
swivel/swivel.py

@@ -307,8 +307,9 @@ class SwivelModel(object):
 
     with tf.device('/cpu:0'):
       # ===== MERGE LOSSES =====
-      l2_loss = tf.reduce_mean(tf.concat(l2_losses, 0), 0, name="l2_loss")
-      sigmoid_loss = tf.reduce_mean(tf.concat(sigmoid_losses, 0), 0,
+      l2_loss = tf.reduce_mean(tf.concat(axis=0, values=l2_losses), 0,
+                               name="l2_loss")
+      sigmoid_loss = tf.reduce_mean(tf.concat(axis=0, values=sigmoid_losses), 0,
                                     name="sigmoid_loss")
       self.loss = l2_loss + sigmoid_loss
       average = tf.train.ExponentialMovingAverage(0.8, self.global_step)

+ 4 - 3
textsum/seq2seq_lib.py

@@ -42,8 +42,8 @@ def sequence_loss_by_example(inputs, targets, weights, loss_function,
   if len(targets) != len(inputs) or len(weights) != len(inputs):
     raise ValueError('Lengths of logits, weights, and targets must be the same '
                      '%d, %d, %d.' % (len(inputs), len(weights), len(targets)))
-  with tf.op_scope(inputs + targets + weights, name,
-                   'sequence_loss_by_example'):
+  with tf.name_scope(values=inputs + targets + weights, name=name,
+                     default_name='sequence_loss_by_example'):
     log_perp_list = []
     for inp, target, weight in zip(inputs, targets, weights):
       crossent = loss_function(inp, target)
@@ -77,7 +77,8 @@ def sampled_sequence_loss(inputs, targets, weights, loss_function,
   Raises:
     ValueError: If len(inputs) is different from len(targets) or len(weights).
   """
-  with tf.op_scope(inputs + targets + weights, name, 'sampled_sequence_loss'):
+  with tf.name_scope(values=inputs + targets + weights, name=name,
+                     default_name='sampled_sequence_loss'):
     cost = tf.reduce_sum(sequence_loss_by_example(
         inputs, targets, weights, loss_function,
         average_across_timesteps=average_across_timesteps))