# shapes.py
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Shape manipulation functions.
  16. rotate_dimensions: prepares for a rotating transpose by returning a rotated
  17. list of dimension indices.
  18. transposing_reshape: allows a dimension to be factorized, with one of the pieces
  19. transferred to another dimension, or to transpose factors within a single
  20. dimension.
  21. tensor_dim: gets a shape dimension as a constant integer if known otherwise a
  22. runtime usable tensor value.
  23. tensor_shape: returns the full shape of a tensor as the tensor_dim.
  24. """
  25. import tensorflow as tf
  26. def rotate_dimensions(num_dims, src_dim, dest_dim):
  27. """Returns a list of dimension indices that will rotate src_dim to dest_dim.
  28. src_dim is moved to dest_dim, with all intervening dimensions shifted towards
  29. the hole left by src_dim. Eg:
  30. num_dims = 4, src_dim=3, dest_dim=1
  31. Returned list=[0, 3, 1, 2]
  32. For a tensor with dims=[5, 4, 3, 2] a transpose would yield [5, 2, 4, 3].
  33. Args:
  34. num_dims: The number of dimensions to handle.
  35. src_dim: The dimension to move.
  36. dest_dim: The dimension to move src_dim to.
  37. Returns:
  38. A list of rotated dimension indices.
  39. """
  40. # List of dimensions for transpose.
  41. dim_list = range(num_dims)
  42. # Shuffle src_dim to dest_dim by swapping to shuffle up the other dims.
  43. step = 1 if dest_dim > src_dim else -1
  44. for x in xrange(src_dim, dest_dim, step):
  45. dim_list[x], dim_list[x + step] = dim_list[x + step], dim_list[x]
  46. return dim_list
  47. def transposing_reshape(tensor,
  48. src_dim,
  49. part_a,
  50. part_b,
  51. dest_dim_a,
  52. dest_dim_b,
  53. name=None):
  54. """Splits src_dim and sends one of the pieces to another dim.
  55. Terminology:
  56. A matrix is often described as 'row-major' or 'column-major', which doesn't
  57. help if you can't remember which is the row index and which is the column,
  58. even if you know what 'major' means, so here is a simpler explanation of it:
  59. When TF stores a tensor of size [d0, d1, d2, d3] indexed by [i0, i1, i2, i3],
  60. the memory address of an element is calculated using:
  61. ((i0 * d1 + i1) * d2 + i2) * d3 + i3, so, d0 is the MOST SIGNIFICANT dimension
  62. and d3 the LEAST SIGNIFICANT, just like in the decimal number 1234, 1 is the
  63. most significant digit and 4 the least significant. In both cases the most
  64. significant is multiplied by the largest number to determine its 'value'.
  65. Furthermore, if we reshape the tensor to [d0'=d0, d1'=d1 x d2, d2'=d3], then
  66. the MOST SIGNIFICANT part of d1' is d1 and the LEAST SIGNIFICANT part of d1'
  67. is d2.
  68. Action:
  69. transposing_reshape splits src_dim into factors [part_a, part_b], and sends
  70. the most significant part (of size part_a) to be the most significant part of
  71. dest_dim_a*(Exception: see NOTE 2), and the least significant part (of size
  72. part_b) to be the most significant part of dest_dim_b.
  73. This is basically a combination of reshape, rotating transpose, reshape.
  74. NOTE1: At least one of dest_dim_a and dest_dim_b must equal src_dim, ie one of
  75. the parts always stays put, so src_dim is never totally destroyed and the
  76. output number of dimensions is always the same as the input.
  77. NOTE2: If dest_dim_a == dest_dim_b == src_dim, then parts a and b are simply
  78. transposed within src_dim to become part_b x part_a, so the most significant
  79. part becomes the least significant part and vice versa. Thus if you really
  80. wanted to make one of the parts the least significant side of the destiantion,
  81. the destination dimension can be internally transposed with a second call to
  82. transposing_reshape.
  83. NOTE3: One of part_a and part_b may be -1 to allow src_dim to be of unknown
  84. size with one known-size factor. Otherwise part_a * part_b must equal the size
  85. of src_dim.
  86. NOTE4: The reshape preserves as many known-at-graph-build-time dimension sizes
  87. as are available.
  88. Example:
  89. Input dims=[5, 2, 6, 2]
  90. tensor=[[[[0, 1][2, 3][4, 5][6, 7][8, 9][10, 11]]
  91. [[12, 13][14, 15][16, 17][18, 19][20, 21][22, 23]]
  92. [[[24, 25]...
  93. src_dim=2, part_a=2, part_b=3, dest_dim_a=3, dest_dim_b=2
  94. output dims =[5, 2, 3, 4]
  95. output tensor=[[[[0, 1, 6, 7][2, 3, 8, 9][4, 5, 10, 11]]
  96. [[12, 13, 18, 19][14, 15, 20, 21][16, 17, 22, 23]]]
  97. [[[24, 26, 28]...
  98. Example2:
  99. Input dims=[phrases, words, letters]=[2, 6, x]
  100. tensor=[[[the][cat][sat][on][the][mat]]
  101. [[a][stitch][in][time][saves][nine]]]
  102. We can factorize the 6 words into 3x2 = [[the][cat]][[sat][on]][[the][mat]]
  103. or 2x3=[[the][cat][sat]][[on][the][mat]] and
  104. src_dim=1, part_a=3, part_b=2, dest_dim_a=1, dest_dim_b=1
  105. would yield:
  106. [[[the][sat][the][cat][on][mat]]
  107. [[a][in][saves][stitch][time][nine]]], but
  108. src_dim=1, part_a=2, part_b=3, dest_dim_a=1, dest_dim_b=1
  109. would yield:
  110. [[[the][on][cat][the][sat][mat]]
  111. [[a][time][stitch][saves][in][nine]]], and
  112. src_dim=1, part_a=2, part_b=3, dest_dim_a=0, dest_dim_b=1
  113. would yield:
  114. [[[the][cat][sat]]
  115. [[a][stitch][in]]
  116. [[on][the][mat]]
  117. [[time][saves][nine]]]
  118. Now remember that the words above represent any least-significant subset of
  119. the input dimensions.
  120. Args:
  121. tensor: A tensor to reshape.
  122. src_dim: The dimension to split.
  123. part_a: The first factor of the split.
  124. part_b: The second factor of the split.
  125. dest_dim_a: The dimension to move part_a of src_dim to.
  126. dest_dim_b: The dimension to move part_b of src_dim to.
  127. name: Optional base name for all the ops.
  128. Returns:
  129. Reshaped tensor.
  130. Raises:
  131. ValueError: If the args are invalid.
  132. """
  133. if dest_dim_a != src_dim and dest_dim_b != src_dim:
  134. raise ValueError(
  135. 'At least one of dest_dim_a, dest_dim_b must equal src_dim!')
  136. if part_a == 0 or part_b == 0:
  137. raise ValueError('Zero not allowed for part_a or part_b!')
  138. if part_a < 0 and part_b < 0:
  139. raise ValueError('At least one of part_a and part_b must be positive!')
  140. if not name:
  141. name = 'transposing_reshape'
  142. prev_shape = tensor_shape(tensor)
  143. expanded = tf.reshape(
  144. tensor,
  145. prev_shape[:src_dim] + [part_a, part_b] + prev_shape[src_dim + 1:],
  146. name=name + '_reshape_in')
  147. dest = dest_dim_b
  148. if dest_dim_a != src_dim:
  149. # We are just moving part_a to dest_dim_a.
  150. dest = dest_dim_a
  151. else:
  152. # We are moving part_b to dest_dim_b.
  153. src_dim += 1
  154. dim_list = rotate_dimensions(len(expanded.get_shape()), src_dim, dest)
  155. expanded = tf.transpose(expanded, dim_list, name=name + '_rot_transpose')
  156. # Reshape identity except dest,dest+1, which get merged.
  157. ex_shape = tensor_shape(expanded)
  158. combined = ex_shape[dest] * ex_shape[dest + 1]
  159. return tf.reshape(
  160. expanded,
  161. ex_shape[:dest] + [combined] + ex_shape[dest + 2:],
  162. name=name + '_reshape_out')
  163. def tensor_dim(tensor, dim):
  164. """Returns int dimension if known at a graph build time else a tensor.
  165. If the size of the dim of tensor is known at graph building time, then that
  166. known value is returned, otherwise (instead of None), a Tensor that will give
  167. the size of the dimension when the graph is run. The return value will be
  168. accepted by tf.reshape in multiple (or even all) dimensions, even when the
  169. sizes are not known at graph building time, unlike -1, which can only be used
  170. in one dimension. It is a bad idea to use tf.shape all the time, as some ops
  171. demand a known (at graph build time) size. This function therefore returns
  172. the best available, most useful dimension size.
  173. Args:
  174. tensor: Input tensor.
  175. dim: Dimension to find the size of.
  176. Returns:
  177. An integer if shape is known at build time, otherwise a tensor of int32.
  178. """
  179. result = tensor.get_shape().as_list()[dim]
  180. if result is None:
  181. result = tf.shape(tensor)[dim]
  182. return result
  183. def tensor_shape(tensor):
  184. """Returns a heterogeneous list of tensor_dim for the tensor.
  185. See tensor_dim for a more detailed explanation.
  186. Args:
  187. tensor: Input tensor.
  188. Returns:
  189. A heterogeneous list of integers and int32 tensors.
  190. """
  191. result = []
  192. for d in xrange(len(tensor.get_shape())):
  193. result.append(tensor_dim(tensor, d))
  194. return result