# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for building and training NN models."""

from __future__ import division

import math

import numpy
import tensorflow as tf


class LayerParameters(object):
  """Class that defines a non-conv layer."""

  def __init__(self):
    self.name = ""
    self.num_units = 0
    self.with_bias = False
    self.relu = False
    self.gradient_l2norm_bound = 0.0
    self.bias_gradient_l2norm_bound = 0.0
    self.trainable = True
    self.weight_decay = 0.0


class ConvParameters(object):
  """Class that defines a conv layer."""

  def __init__(self):
    self.patch_size = 5
    self.stride = 1
    self.in_channels = 1
    self.out_channels = 0
    self.with_bias = True
    self.relu = True
    self.max_pool = True
    self.max_pool_size = 2
    self.max_pool_stride = 2
    self.trainable = False
    self.in_size = 28
    self.name = ""
    self.num_outputs = 0
    self.bias_stddev = 0.1


# Parameters for a layered neural network.
class NetworkParameters(object):
  """Class that defines the overall model structure."""

  def __init__(self):
    self.input_size = 0
    self.projection_type = 'NONE'  # NONE, RANDOM, PCA
    self.projection_dimensions = 0
    self.default_gradient_l2norm_bound = 0.0
    self.layer_parameters = []  # List of LayerParameters
    self.conv_parameters = []  # List of ConvParameters
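

# Illustrative configuration sketch; not part of the original module. The
# sizes and names below (784-dimensional MNIST-style inputs feeding a single
# 10-unit linear layer) are assumptions chosen for demonstration.
def _ExampleNetworkParameters():
  params = NetworkParameters()
  params.input_size = 784
  params.default_gradient_l2norm_bound = 4.0
  logits = LayerParameters()
  logits.name = "logits"
  logits.num_units = 10
  logits.gradient_l2norm_bound = 4.0
  params.layer_parameters.append(logits)
  return params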


def GetTensorOpName(x):
  """Get the name of the op that created a tensor.

  Useful for naming related tensors, as ':' in the name field of an op is
  not permitted.

  Args:
    x: the input tensor.
  Returns:
    the name of the op.
  """
  t = x.name.rsplit(":", 1)
  if len(t) == 1:
    return x.name
  else:
    return t[0]
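

# Illustrative example; the tensor name below is hypothetical. TF1 tensor
# names carry a ":<output_index>" suffix, so GetTensorOpName strips it and
# returns the bare op name, which can then seed names for related tensors.
#
#   x = tf.constant(0.0, name="layer0_input")
#   GetTensorOpName(x)  # x.name == "layer0_input:0" -> "layer0_input"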


def BuildNetwork(inputs, network_parameters):
  """Build a network using the given parameters.

  Args:
    inputs: a Tensor of floats containing the input data.
    network_parameters: NetworkParameters object
      that describes the parameters for the network.
  Returns:
    outputs, projection, training_parameters: where outputs (a tensor) is
      the output of the network, projection is the non-trainable projection
      variable (or None if no projection is used), and training_parameters
      (a dictionary that maps the name of each variable to a dictionary of
      parameters) is the parameters used during training.
  """
  training_parameters = {}
  num_inputs = network_parameters.input_size
  outputs = inputs
  projection = None

  # First apply convolutions, if needed.
  for conv_param in network_parameters.conv_parameters:
    outputs = tf.reshape(
        outputs,
        [-1, conv_param.in_size, conv_param.in_size,
         conv_param.in_channels])
    conv_weights_name = "%s_conv_weight" % (conv_param.name)
    conv_bias_name = "%s_conv_bias" % (conv_param.name)
    conv_std_dev = 1.0 / (conv_param.patch_size
                          * math.sqrt(conv_param.in_channels))
    conv_weights = tf.Variable(
        tf.truncated_normal([conv_param.patch_size,
                             conv_param.patch_size,
                             conv_param.in_channels,
                             conv_param.out_channels],
                            stddev=conv_std_dev),
        trainable=conv_param.trainable,
        name=conv_weights_name)
    conv_bias = tf.Variable(
        tf.truncated_normal([conv_param.out_channels],
                            stddev=conv_param.bias_stddev),
        trainable=conv_param.trainable,
        name=conv_bias_name)
    training_parameters[conv_weights_name] = {}
    training_parameters[conv_bias_name] = {}
    conv = tf.nn.conv2d(outputs, conv_weights,
                        strides=[1, conv_param.stride,
                                 conv_param.stride, 1],
                        padding="SAME")
    relud = tf.nn.relu(conv + conv_bias)
    mpd = tf.nn.max_pool(relud, ksize=[1,
                                       conv_param.max_pool_size,
                                       conv_param.max_pool_size, 1],
                         strides=[1, conv_param.max_pool_stride,
                                  conv_param.max_pool_stride, 1],
                         padding="SAME")
    outputs = mpd
    num_inputs = conv_param.num_outputs
    # This should equal
    # in_size * in_size * out_channels / (stride * max_pool_stride)**2,
    # since the conv stride and the pooling stride each shrink both spatial
    # dimensions.

  # Once all the convs are done, reshape to make it flat.
  outputs = tf.reshape(outputs, [-1, num_inputs])

  # Now project, if needed.
  if network_parameters.projection_type != "NONE":
    projection = tf.Variable(
        tf.truncated_normal(
            [num_inputs, network_parameters.projection_dimensions],
            stddev=1.0 / math.sqrt(num_inputs)),
        trainable=False, name="projection")
    num_inputs = network_parameters.projection_dimensions
    outputs = tf.matmul(outputs, projection)

  # Now apply any other layers.
  for layer_parameters in network_parameters.layer_parameters:
    num_units = layer_parameters.num_units
    hidden_weights_name = "%s_weight" % (layer_parameters.name)
    hidden_weights = tf.Variable(
        tf.truncated_normal([num_inputs, num_units],
                            stddev=1.0 / math.sqrt(num_inputs)),
        name=hidden_weights_name, trainable=layer_parameters.trainable)
    training_parameters[hidden_weights_name] = {}
    if layer_parameters.gradient_l2norm_bound:
      training_parameters[hidden_weights_name]["gradient_l2norm_bound"] = (
          layer_parameters.gradient_l2norm_bound)
    if layer_parameters.weight_decay:
      training_parameters[hidden_weights_name]["weight_decay"] = (
          layer_parameters.weight_decay)
    outputs = tf.matmul(outputs, hidden_weights)
    if layer_parameters.with_bias:
      hidden_biases_name = "%s_bias" % (layer_parameters.name)
      hidden_biases = tf.Variable(tf.zeros([num_units]),
                                  name=hidden_biases_name)
      training_parameters[hidden_biases_name] = {}
      if layer_parameters.bias_gradient_l2norm_bound:
        training_parameters[hidden_biases_name][
            "bias_gradient_l2norm_bound"] = (
                layer_parameters.bias_gradient_l2norm_bound)
      outputs += hidden_biases
    if layer_parameters.relu:
      outputs = tf.nn.relu(outputs)
    # num_inputs for the next layer is num_units in the current layer.
    num_inputs = num_units

  return outputs, projection, training_parameters
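

# Usage sketch with hypothetical placeholder shapes; not part of the original
# module. Builds the graph for the example parameters above; the returned
# training_parameters dictionary is what per-variable gradient clipping
# consults during training.
def _ExampleBuildNetwork():
  params = _ExampleNetworkParameters()
  images = tf.placeholder(tf.float32, [None, 784], name="images")
  logits, _, training_params = BuildNetwork(images, params)
  return logits, training_params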


def VaryRate(start, end, saturate_epochs, epoch):
  """Compute a linearly varying number.

  Decrease linearly from start to end until epoch saturate_epochs.

  Args:
    start: the initial number.
    end: the end number.
    saturate_epochs: after this we do not reduce the number; if less than
      or equal to zero, just return start.
    epoch: the current learning epoch.
  Returns:
    the calculated number.
  """
  if saturate_epochs <= 0:
    return start

  step = (start - end) / (saturate_epochs - 1)
  if epoch < saturate_epochs:
    return start - step * epoch
  else:
    return end
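

# Worked example with illustrative values: for start=0.1, end=0.01 and
# saturate_epochs=10, the per-epoch step is (0.1 - 0.01) / 9 = 0.01, so
#
#   VaryRate(0.1, 0.01, 10, 0)   # -> 0.1
#   VaryRate(0.1, 0.01, 10, 3)   # -> 0.07
#   VaryRate(0.1, 0.01, 10, 20)  # -> 0.01 (saturated at `end`)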


def BatchClipByL2norm(t, upper_bound, name=None):
  """Clip an array of tensors by L2 norm.

  Shrink each dimension-0 slice of tensor (for matrix it is each row) such
  that the l2 norm is at most upper_bound. Here we clip each row as it
  corresponds to each example in the batch.

  Args:
    t: the input tensor.
    upper_bound: the upper bound of the L2 norm.
    name: optional name.
  Returns:
    the clipped tensor.
  """
  assert upper_bound > 0
  with tf.name_scope(values=[t, upper_bound], name=name,
                     default_name="batch_clip_by_l2norm") as name:
    saved_shape = tf.shape(t)
    batch_size = tf.slice(saved_shape, [0], [1])
    t2 = tf.reshape(t, tf.concat(axis=0, values=[batch_size, [-1]]))
    upper_bound_inv = tf.fill(tf.slice(saved_shape, [0], [1]),
                              tf.constant(1.0 / upper_bound))
    # Add a small number to avoid dividing by 0.
    l2norm_inv = tf.rsqrt(tf.reduce_sum(t2 * t2, [1]) + 0.000001)
    scale = tf.minimum(l2norm_inv, upper_bound_inv) * upper_bound
    clipped_t = tf.matmul(tf.diag(scale), t2)
    clipped_t = tf.reshape(clipped_t, saved_shape, name=name)
  return clipped_t
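

# Worked example with illustrative values; not part of the original module.
# With upper_bound=4.0, the row [3.0, 4.0] has L2 norm 5.0 and is rescaled by
# 4/5 to [2.4, 3.2], while the row [0.3, 0.4] (norm 0.5) passes through
# unchanged.
def _ExampleBatchClip():
  grads = tf.constant([[3.0, 4.0], [0.3, 0.4]])
  return BatchClipByL2norm(grads, 4.0)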


def SoftThreshold(t, threshold_ratio, name=None):
  """Soft-threshold a tensor by the mean value.

  Soft-threshold each dimension-0 vector (for matrix it is each column) by
  the mean of absolute value multiplied by the threshold_ratio factor. Here
  we soft-threshold each column as it corresponds to each unit in a layer.

  Args:
    t: the input tensor.
    threshold_ratio: the threshold ratio.
    name: the optional name for the returned tensor.
  Returns:
    the thresholded tensor, where each entry is soft-thresholded by
    threshold_ratio times the mean of the absolute value of each column.
  """
  assert threshold_ratio >= 0
  with tf.name_scope(values=[t, threshold_ratio], name=name,
                     default_name="soft_thresholding") as name:
    saved_shape = tf.shape(t)
    t2 = tf.reshape(t, tf.concat(axis=0,
                                 values=[tf.slice(saved_shape, [0], [1]),
                                         [-1]]))
    t_abs = tf.abs(t2)
    t_x = tf.sign(t2) * tf.nn.relu(t_abs -
                                   (tf.reduce_mean(t_abs, [0],
                                                   keep_dims=True) *
                                    threshold_ratio))
    return tf.reshape(t_x, saved_shape, name=name)
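

# Worked example with illustrative values; not part of the original module.
# For t = [[1., -2.], [3., 4.]] and threshold_ratio=0.5, the per-column means
# of |t| are [2., 3.], giving thresholds [1.0, 1.5]; each entry is shrunk
# toward zero by its column's threshold, yielding [[0., -0.5], [2., 2.5]].
def _ExampleSoftThreshold():
  t = tf.constant([[1.0, -2.0], [3.0, 4.0]])
  return SoftThreshold(t, 0.5)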


def AddGaussianNoise(t, sigma, name=None):
  """Add i.i.d. Gaussian noise (0, sigma^2) to every entry of t.

  Args:
    t: the input tensor.
    sigma: the stddev of the Gaussian noise.
    name: optional name.
  Returns:
    the noisy tensor.
  """
  with tf.name_scope(values=[t, sigma], name=name,
                     default_name="add_gaussian_noise") as name:
    noisy_t = t + tf.random_normal(tf.shape(t), stddev=sigma)
  return noisy_t
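

# Usage sketch of a DP-SGD-style noising step; the names per_example_grads,
# C (the clipping bound) and sigma (the noise multiplier) are assumptions for
# illustration, not part of the original module.
#
#   clipped = BatchClipByL2norm(per_example_grads, C)
#   noisy_sum = AddGaussianNoise(tf.reduce_sum(clipped, 0), sigma * C)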


def GenerateBinomialTable(m):
  """Generate binomial table.

  Args:
    m: the size of the table.
  Returns:
    A two dimensional array T where T[i][j] = (i choose j),
    for 0 <= i, j <= m.
  """
  table = numpy.zeros((m + 1, m + 1), dtype=numpy.float64)
  for i in range(m + 1):
    table[i, 0] = 1
  for i in range(1, m + 1):
    for j in range(1, m + 1):
      v = table[i - 1, j] + table[i - 1, j - 1]
      assert not math.isnan(v) and not math.isinf(v)
      table[i, j] = v
  return tf.convert_to_tensor(table)
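

# Worked example with an illustrative size: GenerateBinomialTable(4) holds
# Pascal's triangle with zeros above the diagonal, e.g. T[4][2] = 6:
#
#   [[1., 0., 0., 0., 0.],
#    [1., 1., 0., 0., 0.],
#    [1., 2., 1., 0., 0.],
#    [1., 3., 3., 1., 0.],
#    [1., 4., 6., 4., 1.]]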