# tf_utils.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
  16. import tensorflow as tf
  17. import numpy as np
  18. def conv2d(x, n_filters,
  19. k_h=5, k_w=5,
  20. stride_h=2, stride_w=2,
  21. stddev=0.02,
  22. activation=lambda x: x,
  23. bias=True,
  24. padding='SAME',
  25. name="Conv2D"):
  26. """2D Convolution with options for kernel size, stride, and init deviation.
  27. Parameters
  28. ----------
  29. x : Tensor
  30. Input tensor to convolve.
  31. n_filters : int
  32. Number of filters to apply.
  33. k_h : int, optional
  34. Kernel height.
  35. k_w : int, optional
  36. Kernel width.
  37. stride_h : int, optional
  38. Stride in rows.
  39. stride_w : int, optional
  40. Stride in cols.
  41. stddev : float, optional
  42. Initialization's standard deviation.
  43. activation : arguments, optional
  44. Function which applies a nonlinearity
  45. padding : str, optional
  46. 'SAME' or 'VALID'
  47. name : str, optional
  48. Variable scope to use.
  49. Returns
  50. -------
  51. x : Tensor
  52. Convolved input.
  53. """
  54. with tf.variable_scope(name):
  55. w = tf.get_variable(
  56. 'w', [k_h, k_w, x.get_shape()[-1], n_filters],
  57. initializer=tf.truncated_normal_initializer(stddev=stddev))
  58. conv = tf.nn.conv2d(
  59. x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
  60. if bias:
  61. b = tf.get_variable(
  62. 'b', [n_filters],
  63. initializer=tf.truncated_normal_initializer(stddev=stddev))
  64. conv = conv + b
  65. return conv
  66. def linear(x, n_units, scope=None, stddev=0.02,
  67. activation=lambda x: x):
  68. """Fully-connected network.
  69. Parameters
  70. ----------
  71. x : Tensor
  72. Input tensor to the network.
  73. n_units : int
  74. Number of units to connect to.
  75. scope : str, optional
  76. Variable scope to use.
  77. stddev : float, optional
  78. Initialization's standard deviation.
  79. activation : arguments, optional
  80. Function which applies a nonlinearity
  81. Returns
  82. -------
  83. x : Tensor
  84. Fully-connected output.
  85. """
  86. shape = x.get_shape().as_list()
  87. with tf.variable_scope(scope or "Linear"):
  88. matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
  89. tf.random_normal_initializer(stddev=stddev))
  90. return activation(tf.matmul(x, matrix))
# %%
  92. def weight_variable(shape):
  93. '''Helper function to create a weight variable initialized with
  94. a normal distribution
  95. Parameters
  96. ----------
  97. shape : list
  98. Size of weight variable
  99. '''
  100. #initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
  101. initial = tf.zeros(shape)
  102. return tf.Variable(initial)
# %%
  104. def bias_variable(shape):
  105. '''Helper function to create a bias variable initialized with
  106. a constant value.
  107. Parameters
  108. ----------
  109. shape : list
  110. Size of weight variable
  111. '''
  112. initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
  113. return tf.Variable(initial)
# %%
  115. def dense_to_one_hot(labels, n_classes=2):
  116. """Convert class labels from scalars to one-hot vectors."""
  117. labels = np.array(labels)
  118. n_labels = labels.shape[0]
  119. index_offset = np.arange(n_labels) * n_classes
  120. labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
  121. labels_one_hot.flat[index_offset + labels.ravel()] = 1
  122. return labels_one_hot