network_units.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Basic network units used in assembling DRAGNN graphs."""
  16. from abc import ABCMeta
  17. from abc import abstractmethod
  18. import tensorflow as tf
  19. from tensorflow.python.ops import nn
  20. from tensorflow.python.ops import tensor_array_ops as ta
  21. from tensorflow.python.platform import tf_logging as logging
  22. from dragnn.python import dragnn_ops
  23. from syntaxnet.util import check
  24. from syntaxnet.util import registry
  25. def linked_embeddings_name(channel_id):
  26. """Returns the name of the linked embedding matrix for some channel ID."""
  27. return 'linked_embedding_matrix_%d' % channel_id
  28. def fixed_embeddings_name(channel_id):
  29. """Returns the name of the fixed embedding matrix for some channel ID."""
  30. return 'fixed_embedding_matrix_%d' % channel_id
  31. class StoredActivations(object):
  32. """Wrapper around stored activation vectors.
  33. Because activations are produced and consumed in different layouts by bulk
  34. vs. dynamic components, this class provides a simple common
  35. interface/conversion API. It can be constructed from either a TensorArray
  36. (dynamic) or a Tensor (bulk), and the resulting object to use for lookups is
  37. either bulk_tensor (for bulk components) or dynamic_tensor (for dynamic
  38. components).
  39. """
  40. def __init__(self, tensor=None, array=None, stride=None, dim=None):
  41. """Creates ops for converting the input to either format.
  42. If 'tensor' is used, then a conversion from [stride * steps, dim] to
  43. [steps + 1, stride, dim] is performed for dynamic_tensor reads.
  44. If 'array' is used, then a conversion from [steps + 1, stride, dim] to
  45. [stride * steps, dim] is performed for bulk_tensor reads.
  46. Args:
  47. tensor: Bulk tensor input.
  48. array: TensorArray dynamic input.
  49. stride: stride of bulk tensor. Not used for dynamic.
  50. dim: dim of bulk tensor. Not used for dynamic.
  51. """
  52. if tensor is not None:
  53. check.IsNone(array, 'Cannot initialize from tensor and array')
  54. check.NotNone(stride, 'Stride is required for bulk tensor')
  55. check.NotNone(dim, 'Dim is required for bulk tensor')
  56. self._bulk_tensor = tensor
  57. with tf.name_scope('convert_to_dyn'):
  58. tensor = tf.reshape(tensor, [stride, -1, dim])
  59. tensor = tf.transpose(tensor, perm=[1, 0, 2])
  60. pad = tf.zeros([1, stride, dim], dtype=tensor.dtype)
  61. self._array_tensor = tf.concat([pad, tensor], 0)
  62. if array is not None:
  63. check.IsNone(tensor, 'Cannot initialize from both tensor and array')
  64. with tf.name_scope('convert_to_bulk'):
  65. self._bulk_tensor = convert_network_state_tensorarray(array)
  66. with tf.name_scope('convert_to_dyn'):
  67. self._array_tensor = array.stack()
  68. @property
  69. def bulk_tensor(self):
  70. return self._bulk_tensor
  71. @property
  72. def dynamic_tensor(self):
  73. return self._array_tensor
  74. class NamedTensor(object):
  75. """Container for a tensor with associated name and dimension attributes."""
  76. def __init__(self, tensor, name, dim=None):
  77. """Inits NamedTensor with tensor, name and optional dim."""
  78. self.tensor = tensor
  79. self.name = name
  80. self.dim = dim
  81. def add_embeddings(channel_id, feature_spec, seed=None):
  82. """Adds a variable for the embedding of a given fixed feature.
  83. Supports pre-trained or randomly initialized embeddings In both cases, extra
  84. vector is reserved for out-of-vocabulary words, so the embedding matrix has
  85. the size of [feature_spec.vocabulary_size + 1, feature_spec.embedding_dim].
  86. Args:
  87. channel_id: Numeric id of the fixed feature channel
  88. feature_spec: Feature spec protobuf of type FixedFeatureChannel
  89. seed: used for random initializer
  90. Returns:
  91. tf.Variable object corresponding to the embedding for that feature.
  92. Raises:
  93. RuntimeError: if more the pretrained embeddings are specified in resources
  94. containing more than one part.
  95. """
  96. check.Gt(feature_spec.embedding_dim, 0,
  97. 'Embeddings requested for non-embedded feature: %s' % feature_spec)
  98. name = fixed_embeddings_name(channel_id)
  99. shape = [feature_spec.vocabulary_size + 1, feature_spec.embedding_dim]
  100. if feature_spec.HasField('pretrained_embedding_matrix'):
  101. if len(feature_spec.pretrained_embedding_matrix.part) > 1:
  102. raise RuntimeError('pretrained_embedding_matrix resource contains '
  103. 'more than one part:\n%s',
  104. str(feature_spec.pretrained_embedding_matrix))
  105. if len(feature_spec.vocab.part) > 1:
  106. raise RuntimeError('vocab resource contains more than one part:\n%s',
  107. str(feature_spec.vocab))
  108. seed1, seed2 = tf.get_seed(seed)
  109. embeddings = dragnn_ops.dragnn_embedding_initializer(
  110. embedding_input=feature_spec.pretrained_embedding_matrix.part[0]
  111. .file_pattern,
  112. vocab=feature_spec.vocab.part[0].file_pattern,
  113. scaling_coefficient=1.0,
  114. seed=seed1,
  115. seed2=seed2)
  116. return tf.get_variable(name, initializer=tf.reshape(embeddings, shape))
  117. else:
  118. return tf.get_variable(
  119. name,
  120. shape,
  121. initializer=tf.random_normal_initializer(
  122. stddev=1.0 / feature_spec.embedding_dim**.5, seed=seed))
  123. def embedding_lookup(embedding_matrix, indices, ids, weights, size):
  124. """Performs a weighted embedding lookup.
  125. Args:
  126. embedding_matrix: float Tensor from which to do the lookup.
  127. indices: int Tensor for the output rows of the looked up vectors.
  128. ids: int Tensor vectors to look up in the embedding_matrix.
  129. weights: float Tensor weights to apply to the looked up vectors.
  130. size: int number of output rows. Needed since some output rows may be
  131. empty.
  132. Returns:
  133. Weighted embedding vectors.
  134. """
  135. embeddings = tf.nn.embedding_lookup([embedding_matrix], ids)
  136. # TODO(googleuser): allow skipping weights.
  137. broadcast_weights_shape = tf.concat([tf.shape(weights), [1]], 0)
  138. embeddings *= tf.reshape(weights, broadcast_weights_shape)
  139. embeddings = tf.unsorted_segment_sum(embeddings, indices, size)
  140. return embeddings
  141. def fixed_feature_lookup(component, state, channel_id, stride):
  142. """Looks up fixed features and passes them through embeddings.
  143. Embedding vectors may be scaled by weights if the features specify it.
  144. Args:
  145. component: Component object in which to look up the fixed features.
  146. state: MasterState object for the live nlp_saft::dragnn::MasterState.
  147. channel_id: int id of the fixed feature to look up.
  148. stride: int Tensor of current batch * beam size.
  149. Returns:
  150. NamedTensor object containing the embedding vectors.
  151. """
  152. feature_spec = component.spec.fixed_feature[channel_id]
  153. check.Gt(feature_spec.embedding_dim, 0,
  154. 'Embeddings requested for non-embedded feature: %s' % feature_spec)
  155. embedding_matrix = component.get_variable(fixed_embeddings_name(channel_id))
  156. with tf.op_scope([embedding_matrix], 'fixed_embedding_' + feature_spec.name):
  157. indices, ids, weights = dragnn_ops.extract_fixed_features(
  158. state.handle, component=component.name, channel_id=channel_id)
  159. size = stride * feature_spec.size
  160. embeddings = embedding_lookup(embedding_matrix, indices, ids, weights, size)
  161. dim = feature_spec.size * feature_spec.embedding_dim
  162. return NamedTensor(
  163. tf.reshape(embeddings, [-1, dim]), feature_spec.name, dim=dim)
  164. def get_input_tensor(fixed_embeddings, linked_embeddings):
  165. """Helper function for constructing an input tensor from all the features.
  166. Args:
  167. fixed_embeddings: list of NamedTensor objects for fixed feature channels
  168. linked_embeddings: list of NamedTensor objects for linked feature channels
  169. Returns:
  170. a tensor of shape [N, D], where D is the total input dimension of the
  171. concatenated feature channels
  172. Raises:
  173. RuntimeError: if no features, fixed or linked, are configured.
  174. """
  175. embeddings = fixed_embeddings + linked_embeddings
  176. if not embeddings:
  177. raise RuntimeError('There needs to be at least one feature set defined.')
  178. # Concat_v2 takes care of optimizing away the concatenation
  179. # operation in the case when there is exactly one embedding input.
  180. return tf.concat([e.tensor for e in embeddings], 1)
  181. def get_input_tensor_with_stride(fixed_embeddings, linked_embeddings, stride):
  182. """Constructs an input tensor with a separate dimension for steps.
  183. Args:
  184. fixed_embeddings: list of NamedTensor objects for fixed feature channels
  185. linked_embeddings: list of NamedTensor objects for linked feature channels
  186. stride: int stride (i.e. beam * batch) to use to reshape the input
  187. Returns:
  188. a tensor of shape [stride, num_steps, D], where D is the total input
  189. dimension of the concatenated feature channels
  190. """
  191. input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
  192. shape = tf.shape(input_tensor)
  193. return tf.reshape(input_tensor, [stride, -1, shape[1]])
  194. def convert_network_state_tensorarray(tensorarray):
  195. """Converts a source TensorArray to a source Tensor.
  196. Performs a permutation between the steps * [stride, D] shape of a
  197. source TensorArray and the (flattened) [stride * steps, D] shape of
  198. a source Tensor.
  199. The TensorArrays used during recurrence have an additional zeroth step that
  200. needs to be removed.
  201. Args:
  202. tensorarray: TensorArray object to be converted.
  203. Returns:
  204. Tensor object after conversion.
  205. """
  206. tensor = tensorarray.stack() # Results in a [steps, stride, D] tensor.
  207. tensor = tf.slice(tensor, [1, 0, 0], [-1, -1, -1]) # Lop off the 0th step.
  208. tensor = tf.transpose(tensor, [1, 0, 2]) # Switch steps and stride.
  209. return tf.reshape(tensor, [-1, tf.shape(tensor)[2]])
  210. def pass_through_embedding_matrix(act_block, embedding_matrix, step_idx):
  211. """Passes the activations through the embedding_matrix.
  212. Takes care to handle out of bounds lookups.
  213. Args:
  214. act_block: matrix of activations.
  215. embedding_matrix: matrix of weights.
  216. step_idx: vector containing step indices, with -1 indicating out of bounds.
  217. Returns:
  218. the embedded activations.
  219. """
  220. # Indicator vector for out of bounds lookups.
  221. step_idx_mask = tf.expand_dims(tf.equal(step_idx, -1), -1)
  222. # Pad the last column of the activation vectors with the indicator.
  223. act_block = tf.concat([act_block, tf.to_float(step_idx_mask)], 1)
  224. return tf.matmul(act_block, embedding_matrix)
  225. def lookup_named_tensor(name, named_tensors):
  226. """Retrieves a NamedTensor by name.
  227. Args:
  228. name: Name of the tensor to retrieve.
  229. named_tensors: List of NamedTensor objects to search.
  230. Returns:
  231. The NamedTensor in |named_tensors| with the |name|.
  232. Raises:
  233. KeyError: If the |name| is not found among the |named_tensors|.
  234. """
  235. for named_tensor in named_tensors:
  236. if named_tensor.name == name:
  237. return named_tensor
  238. raise KeyError('Name "%s" not found in named tensors: %s' %
  239. (name, named_tensors))
def activation_lookup_recurrent(component, state, channel_id, source_array,
                                source_layer_size, stride):
  """Looks up linked-feature activations from a TensorArray.

  If the linked feature's embedding_dim is set to -1, the feature vectors are
  not passed through (i.e. multiplied by) an embedding matrix.

  Args:
    component: Component object in which to look up the linked features.
    state: MasterState object for the live nlp_saft::dragnn::MasterState.
    channel_id: int id of the linked feature to look up.
    source_array: TensorArray from which to fetch feature vectors, expected to
      have size [steps + 1] elements of shape [stride, D] each.
    source_layer_size: int length of feature vectors before embedding.
    stride: int Tensor of current batch * beam size.

  Returns:
    NamedTensor object containing the embedding vectors, reshaped to
    [-1, dim] where dim is feature_spec.size * feature_spec.embedding_dim, or
    feature_spec.size * source_layer_size when embedding_dim is -1.
  """
  feature_spec = component.spec.linked_feature[channel_id]

  with tf.name_scope('activation_lookup_recurrent_%s' % feature_spec.name):
    # Linked features are returned as a pair of tensors, one indexing into
    # steps, and one indexing within the activation tensor (beam x batch)
    # stored for a step.
    step_idx, idx = dragnn_ops.extract_link_features(
        state.handle, component=component.name, channel_id=channel_id)

    # We take the [steps, batch*beam, ...] tensor array, gather and concat
    # the steps we might need into a [some_steps*batch*beam, ...] tensor,
    # and flatten 'idx' to dereference this new tensor.
    #
    # The first element of each tensor array is reserved for an
    # initialization variable, so we offset all step indices by +1.
    #
    # TODO(googleuser): It would be great to not have to extract
    # the steps in their entirety, forcing a copy of much of the
    # TensorArray at each step. Better would be to support a
    # TensorArray.gather_nd to pick the specific elements directly.
    # TODO(googleuser): In the interim, a small optimization would
    # be to use tf.unique instead of tf.range.
    step_min = tf.reduce_min(step_idx)
    # Array elements step_min + 1 .. max(step_idx) + 1 inclusive (tf.range's
    # end is exclusive). If step_idx contains -1 (out of bounds), this starts
    # at element 0, the reserved initialization element.
    ta_range = tf.range(step_min + 1, tf.reduce_max(step_idx) + 2)
    act_block = source_array.gather(ta_range)
    # Merge the gathered-steps axis with the next axis. Assuming [stride, D]
    # elements as documented, this gives [num_gathered * stride, D].
    act_block = tf.reshape(act_block,
                           tf.concat([[-1], tf.shape(act_block)[2:]], 0))
    # Row offset of (step, idx) within the merged block, relative to the first
    # gathered step.
    flat_idx = (step_idx - step_min) * stride + idx
    act_block = tf.gather(act_block, flat_idx)
    act_block = tf.reshape(act_block, [-1, source_layer_size])

    if feature_spec.embedding_dim != -1:
      embedding_matrix = component.get_variable(
          linked_embeddings_name(channel_id))
      act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
                                                step_idx)
      dim = feature_spec.size * feature_spec.embedding_dim
    else:
      # If embedding_dim is -1, just output concatenation of activations.
      dim = feature_spec.size * source_layer_size
    return NamedTensor(
        tf.reshape(act_block, [-1, dim]), feature_spec.name, dim=dim)
  295. def activation_lookup_other(component, state, channel_id, source_tensor,
  296. source_layer_size):
  297. """Looks up activations from tensors.
  298. If the linked feature's embedding_dim is set to -1, the feature vectors are
  299. not passed through (i.e. multiplied by) an embedding matrix.
  300. Args:
  301. component: Component object in which to look up the fixed features.
  302. state: MasterState object for the live nlp_saft::dragnn::MasterState.
  303. channel_id: int id of the fixed feature to look up.
  304. source_tensor: Tensor from which to fetch feature vectors. Expected to have
  305. have shape [steps + 1, stride, D].
  306. source_layer_size: int length of feature vectors before embedding (D). It
  307. would in principle be possible to get this dimension dynamically from
  308. the second dimension of source_tensor. However, having it statically is
  309. more convenient.
  310. Returns:
  311. NamedTensor object containing the embedding vectors.
  312. """
  313. feature_spec = component.spec.linked_feature[channel_id]
  314. with tf.name_scope('activation_lookup_other_%s' % feature_spec.name):
  315. # Linked features are returned as a pair of tensors, one indexing into
  316. # steps, and one indexing within the stride (beam x batch) of each step.
  317. step_idx, idx = dragnn_ops.extract_link_features(
  318. state.handle, component=component.name, channel_id=channel_id)
  319. # The first element of each tensor array is reserved for an
  320. # initialization variable, so we offset all step indices by +1.
  321. indices = tf.stack([step_idx + 1, idx], axis=1)
  322. act_block = tf.gather_nd(source_tensor, indices)
  323. act_block = tf.reshape(act_block, [-1, source_layer_size])
  324. if feature_spec.embedding_dim != -1:
  325. embedding_matrix = component.get_variable(
  326. linked_embeddings_name(channel_id))
  327. act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
  328. step_idx)
  329. dim = feature_spec.size * feature_spec.embedding_dim
  330. else:
  331. # If embedding_dim is -1, just output concatenation of activations.
  332. dim = feature_spec.size * source_layer_size
  333. return NamedTensor(
  334. tf.reshape(act_block, [-1, dim]), feature_spec.name, dim=dim)
  335. class LayerNorm(object):
  336. """Utility to add layer normalization to any tensor.
  337. Layer normalization implementation is based on:
  338. https://arxiv.org/abs/1607.06450. "Layer Normalization"
  339. Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
  340. This object will construct additional variables that need to be optimized, and
  341. these variables can be accessed via params().
  342. Attributes:
  343. params: List of additional parameters to be trained.
  344. """
  345. def __init__(self, component, name, shape, dtype):
  346. """Construct variables to normalize an input of given shape.
  347. Arguments:
  348. component: ComponentBuilder handle.
  349. name: Human readable name to organize the variables.
  350. shape: Shape of the layer to be normalized.
  351. dtype: Type of the layer to be normalized.
  352. """
  353. self._name = name
  354. self._shape = shape
  355. self._component = component
  356. beta = tf.get_variable(
  357. 'beta_%s' % name,
  358. shape=shape,
  359. dtype=dtype,
  360. initializer=tf.zeros_initializer())
  361. gamma = tf.get_variable(
  362. 'gamma_%s' % name,
  363. shape=shape,
  364. dtype=dtype,
  365. initializer=tf.ones_initializer())
  366. self._params = [beta, gamma]
  367. @property
  368. def params(self):
  369. return self._params
  370. def normalize(self, inputs):
  371. """Apply normalization to input.
  372. The shape must match the declared shape in the constructor.
  373. [This is copied from tf.contrib.rnn.LayerNormBasicLSTMCell.]
  374. Args:
  375. inputs: Input tensor
  376. Returns:
  377. Normalized version of input tensor.
  378. Raises:
  379. ValueError: if inputs has undefined rank.
  380. """
  381. inputs_shape = inputs.get_shape()
  382. inputs_rank = inputs_shape.ndims
  383. if inputs_rank is None:
  384. raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  385. axis = range(1, inputs_rank)
  386. beta = self._component.get_variable('beta_%s' % self._name)
  387. gamma = self._component.get_variable('gamma_%s' % self._name)
  388. with tf.variable_scope('layer_norm_%s' % self._name):
  389. # Calculate the moments on the last axis (layer activations).
  390. mean, variance = nn.moments(inputs, axis, keep_dims=True)
  391. # Compute layer normalization using the batch_normalization function.
  392. variance_epsilon = 1E-12
  393. outputs = nn.batch_normalization(
  394. inputs, mean, variance, beta, gamma, variance_epsilon)
  395. outputs.set_shape(inputs_shape)
  396. return outputs
  397. class Layer(object):
  398. """A layer in a feed-forward network.
  399. Attributes:
  400. component: ComponentBuilderBase that produces this layer.
  401. name: Name of this layer.
  402. dim: Dimension of this layer, or negative if dynamic.
  403. """
  404. def __init__(self, component, name, dim):
  405. check.NotNone(dim, 'Dimension is required')
  406. self.component = component
  407. self.name = name
  408. self.dim = dim
  409. def __str__(self):
  410. return 'Layer: %s/%s[%d]' % (self.component.name, self.name, self.dim)
  411. def create_array(self, stride):
  412. """Creates a new tensor array to store this layer's activations.
  413. Arguments:
  414. stride: Possibly dynamic batch * beam size with which to initialize the
  415. tensor array
  416. Returns:
  417. TensorArray object
  418. """
  419. check.Gt(self.dim, 0, 'Cannot create array when dimension is dynamic')
  420. tensor_array = ta.TensorArray(dtype=tf.float32,
  421. size=0,
  422. dynamic_size=True,
  423. clear_after_read=False,
  424. infer_shape=False,
  425. name='%s_array' % self.name)
  426. # Start each array with all zeros. Special values will still be learned via
  427. # the extra embedding dimension stored for each linked feature channel.
  428. initial_value = tf.zeros([stride, self.dim])
  429. return tensor_array.write(0, initial_value)
  430. def get_attrs_with_defaults(parameters, defaults):
  431. """Populates a dictionary with run-time attributes.
  432. Given defaults, populates any overrides from 'parameters' with their
  433. corresponding converted values. 'defaults' should be typed. This is useful
  434. for specifying NetworkUnit-specific configuration options.
  435. Args:
  436. parameters: a <string, string> map.
  437. defaults: a <string, value> typed set of default values.
  438. Returns:
  439. dictionary populated with any overrides.
  440. Raises:
  441. RuntimeError: if a key in parameters is not present in defaults.
  442. """
  443. attrs = defaults
  444. for key, value in parameters.iteritems():
  445. check.In(key, defaults, 'Unknown attribute: %s' % key)
  446. if isinstance(defaults[key], bool):
  447. attrs[key] = value.lower() == 'true'
  448. else:
  449. attrs[key] = type(defaults[key])(value)
  450. return attrs
  451. def maybe_apply_dropout(inputs, keep_prob, per_sequence, stride=None):
  452. """Applies dropout, if so configured, to an input tensor.
  453. The input may be rank 2 or 3 depending on whether the stride (i.e., batch
  454. size) has been incorporated into the shape.
  455. Args:
  456. inputs: [stride * num_steps, dim] or [stride, num_steps, dim] input tensor.
  457. keep_prob: Scalar probability of keeping each input element. If >= 1.0, no
  458. dropout is performed.
  459. per_sequence: If true, sample the dropout mask once per sequence, instead of
  460. once per step. Requires |stride| when true.
  461. stride: Scalar batch size. Optional if |per_sequence| is false.
  462. Returns:
  463. [stride * num_steps, dim] or [stride, num_steps, dim] tensor, matching the
  464. shape of |inputs|, containing the masked or original inputs, depending on
  465. whether dropout was actually performed.
  466. """
  467. check.Ge(inputs.get_shape().ndims, 2, 'inputs must be rank 2 or 3')
  468. check.Le(inputs.get_shape().ndims, 3, 'inputs must be rank 2 or 3')
  469. flat = (inputs.get_shape().ndims == 2)
  470. if keep_prob >= 1.0:
  471. return inputs
  472. if not per_sequence:
  473. return tf.nn.dropout(inputs, keep_prob)
  474. check.NotNone(stride, 'per-sequence dropout requires stride')
  475. dim = inputs.get_shape().as_list()[-1]
  476. check.NotNone(dim, 'inputs must have static activation dimension, but have '
  477. 'static shape %s' % inputs.get_shape().as_list())
  478. # If needed, restore the batch dimension to separate the sequences.
  479. inputs_sxnxd = tf.reshape(inputs, [stride, -1, dim]) if flat else inputs
  480. # Replace |num_steps| with 1 in |noise_shape|, so the dropout mask broadcasts
  481. # to all steps for a particular sequence.
  482. noise_shape = [stride, 1, dim]
  483. masked_sxnxd = tf.nn.dropout(inputs_sxnxd, keep_prob, noise_shape)
  484. # If needed, flatten out the batch dimension in the return value.
  485. return tf.reshape(masked_sxnxd, [-1, dim]) if flat else masked_sxnxd
  486. @registry.RegisteredClass
  487. class NetworkUnitInterface(object):
  488. """Base class to implement NN specifications.
  489. This class contains the required functionality to build a network inside of a
  490. DRAGNN graph: (1) initializing TF variables during __init__(), and (2)
  491. creating particular instances from extracted features in create().
  492. Attributes:
  493. params (list): List of tf.Variable objects representing trainable
  494. parameters.
  495. layers (list): List of Layer objects to track network layers that should
  496. be written to Tensors during training and inference.
  497. """
  498. __metaclass__ = ABCMeta # required for @abstractmethod
  def __init__(self, component, init_layers=None, init_context_layers=None):
    """Initializes parameters for embedding matrices.

    The subclass may provide optional lists of initial layers and context layers
    to allow this base class constructor to use accessors like get_layer_size(),
    which is required for networks that may be used self-recurrently.

    Besides the embedding matrices, this computes the concatenated input
    dimension (-1 if any input is dynamically sized) and, when the component
    spec names an attention_component, allocates the attention parameters used
    by attention().

    Args:
      component: parent ComponentBuilderBase object.
      init_layers: optional initial layers.
      init_context_layers: optional initial context layers.
    """
    self._component = component
    self._params = []
    self._layers = init_layers if init_layers else []
    self._regularized_weights = []
    self._context_layers = init_context_layers if init_context_layers else []
    self._fixed_feature_dims = {}  # mapping from name to dimension
    self._linked_feature_dims = {}  # mapping from name to dimension

    # Allocate parameters for all embedding channels. Note that for both Fixed
    # and Linked embedding matrices, we store an additional +1 embedding that's
    # used when the index is out of scope.
    for channel_id, spec in enumerate(component.spec.fixed_feature):
      check.NotIn(spec.name, self._fixed_feature_dims,
                  'Duplicate fixed feature')
      check.Gt(spec.size, 0, 'Invalid fixed feature size')
      if spec.embedding_dim > 0:
        fixed_dim = spec.embedding_dim
        self._params.append(add_embeddings(channel_id, spec))
      else:
        fixed_dim = 1  # assume feature ID extraction; only one ID per step
      self._fixed_feature_dims[spec.name] = spec.size * fixed_dim

    for channel_id, spec in enumerate(component.spec.linked_feature):
      check.NotIn(spec.name, self._linked_feature_dims,
                  'Duplicate linked feature')
      check.Gt(spec.size, 0, 'Invalid linked feature size')
      if spec.source_component == component.name:
        # Self-recurrent link: this is why subclasses may need their layer
        # sizes available before calling this constructor.
        source_array_dim = self.get_layer_size(spec.source_layer)
      else:
        source = component.master.lookup_component[spec.source_component]
        source_array_dim = source.network.get_layer_size(spec.source_layer)

      if spec.embedding_dim != -1:
        check.Gt(source_array_dim, 0,
                 'Cannot embed linked feature with dynamic dimension')
        self._params.append(
            tf.get_variable(
                linked_embeddings_name(channel_id),
                [source_array_dim + 1, spec.embedding_dim],
                initializer=tf.random_normal_initializer(
                    stddev=1 / spec.embedding_dim**.5)))

        self._linked_feature_dims[spec.name] = spec.size * spec.embedding_dim
      else:
        # If embedding_dim is -1, linked features are not embedded.
        self._linked_feature_dims[spec.name] = spec.size * source_array_dim

    # Compute the cumulative dimension of all inputs. If any input has dynamic
    # dimension, then the result is -1. The dict values are materialized as
    # lists so that the concatenation also works under Python 3.
    input_dims = (list(self._fixed_feature_dims.values()) +
                  list(self._linked_feature_dims.values()))
    if any(x < 0 for x in input_dims):
      self._concatenated_input_dim = -1
    else:
      self._concatenated_input_dim = sum(input_dims)
    tf.logging.info('component %s concat_input_dim %s', component.name,
                    self._concatenated_input_dim)

    # Allocate attention parameters.
    if self._component.spec.attention_component:
      attention_source_component = self._component.master.lookup_component[
          self._component.spec.attention_component]
      attention_unit_params = (
          attention_source_component.spec.network_unit.parameters)
      # Parse every size (as before) but keep a real list so [-1] also works
      # under Python 3, where map() returns an iterator.
      attention_hidden_layer_sizes = [
          int(size)
          for size in attention_unit_params['hidden_layer_sizes'].split(',')
      ]
      attention_hidden_layer_size = attention_hidden_layer_sizes[-1]

      hidden_layer_sizes = [
          int(size) for size in component.spec.network_unit.parameters[
              'hidden_layer_sizes'].split(',')
      ]
      # The attention function is built on the last layer of hidden embeddings.
      hidden_layer_size = hidden_layer_sizes[-1]

      # Projects attention source rows into the hidden_layer_size space.
      self._params.append(
          tf.get_variable(
              'attention_weights_pm_0',
              [attention_hidden_layer_size, hidden_layer_size],
              initializer=tf.random_normal_initializer(stddev=1e-4)))

      # Projects this unit's last hidden layer into the same space.
      self._params.append(
          tf.get_variable(
              'attention_weights_hm_0', [hidden_layer_size, hidden_layer_size],
              initializer=tf.random_normal_initializer(stddev=1e-4)))

      self._params.append(
          tf.get_variable(
              'attention_bias_0', [1, hidden_layer_size],
              initializer=tf.zeros_initializer()))

      self._params.append(
          tf.get_variable(
              'attention_bias_1', [1, hidden_layer_size],
              initializer=tf.zeros_initializer()))

      # Maps the attended summary onto the component's action space.
      self._params.append(
          tf.get_variable(
              'attention_weights_pu',
              [attention_hidden_layer_size, component.num_actions],
              initializer=tf.random_normal_initializer(stddev=1e-4)))
  @abstractmethod
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Builds this unit's tensors from its input features and context.

    Every concrete network unit must override this method; the base version is
    an empty placeholder.

    Args:
      fixed_embeddings: list of NamedTensor objects
      linked_embeddings: list of NamedTensor objects
      context_tensor_arrays: optional list of TensorArray objects used for
        implicit recurrence.
      attention_tensor: optional Tensor used for attention.
      during_training: whether to build the network for training (as opposed
        to inference).
      stride: int scalar tensor containing the stride required for
        bulk computation.

    Returns:
      A list of tensors, one for each entry in the list of layers.
    """
  @property
  def layers(self):
    """List of Layer objects tracked by this network unit."""
    tracked_layers = self._layers
    return tracked_layers
  @property
  def params(self):
    """List of trainable tf.Variable objects owned by this network unit."""
    trainable = self._params
    return trainable
  @property
  def regularized_weights(self):
    """List of weights selected for regularization by this network unit."""
    weights = self._regularized_weights
    return weights
  @property
  def context_layers(self):
    """List of Layer objects used as recurrent context by this network unit."""
    recurrent_layers = self._context_layers
    return recurrent_layers
  def get_layer_index(self, layer_name):
    """Returns the position of the named layer within self.layers.

    Args:
      layer_name: string name of the layer to find.

    Returns:
      Index of the first layer whose name equals |layer_name|.

    Raises:
      ValueError: if no layer has that name.
    """
    layer_names = [layer.name for layer in self.layers]
    return layer_names.index(layer_name)
  def get_layer_size(self, layer_name):
    """Looks up the dimension of a named layer of this network.

    Args:
      layer_name: string name of the layer to look up.

    Returns:
      The `dim` of the first layer in self.layers with that name.

    Raises:
      KeyError: if no layer in self.layers has that name.
    """
    not_found = object()
    size = next(
        (layer.dim for layer in self.layers if layer.name == layer_name),
        not_found)
    if size is not_found:
      raise KeyError('Layer {} not found in component {}'.format(
          layer_name, self._component.name))
    return size
  def get_logits(self, network_tensors):
    """Selects the logits tensor from the outputs of create().

    The base interface does not require a 'logits' tensor, so units that
    produce one must override this method.

    Args:
      network_tensors: list of tensors as output by create().

    Raises:
      NotImplementedError: always, in this base implementation.
    """
    del network_tensors  # Unused by the base implementation.
    raise NotImplementedError()
  def get_l2_regularized_weights(self):
    """Returns the weights that need to be regularized.

    Reads the `regularized_weights` property rather than the private list, so
    an override of that property is respected.
    """
    weights_to_regularize = self.regularized_weights
    return weights_to_regularize
  def attention(self, last_layer, attention_tensor):
    """Compute the attention term for the network unit.

    Projects both inputs into a shared space, scores each row of
    |attention_tensor| against |last_layer|, softmax-normalizes the scores,
    and maps the score-weighted sum of the rows through
    'attention_weights_pu'. FeedForwardNetwork and LSTMNetwork add this term
    to their logits when spec.attention_component is set, which is also the
    only case in which __init__ allocates the attention variables used here.

    Args:
      last_layer: activations of this unit's last hidden layer. It is
        multiplied by 'attention_weights_hm_0', which __init__ allocates as
        [hidden_layer_size, hidden_layer_size]; presumably shaped
        [batch, hidden_layer_size] -- TODO confirm.
      attention_tensor: activations of the attention source component. It is
        multiplied by 'attention_weights_pm_0', which __init__ allocates as
        [attention_hidden_layer_size, hidden_layer_size]; presumably shaped
        [num_rows, attention_hidden_layer_size] -- TODO confirm.

    Returns:
      The product of the attended summary with 'attention_weights_pu'
      ([attention_hidden_layer_size, num_actions]); [1, num_actions] assuming
      rank-2 inputs as above.
    """
    h_tensor = attention_tensor

    # Compute the attentions.
    # Using feed-forward net to map the two inputs into the same dimension
    focus_tensor = tf.nn.tanh(
        tf.matmul(
            h_tensor,
            self._component.get_variable('attention_weights_pm_0'),
            name='h_x_pm') + self._component.get_variable('attention_bias_0'))

    context_tensor = tf.nn.tanh(
        tf.matmul(
            last_layer,
            self._component.get_variable('attention_weights_hm_0'),
            name='l_x_hm') + self._component.get_variable('attention_bias_1'))

    # The tf.multiply in the following expression broadcasts along the 0 dim:
    # each focus row is dotted with the context (reduce_sum over dim 1).
    # NOTE(review): broadcasting requires one side to have a single row (or
    # both to match); presumably last_layer has one row here -- confirm.
    z_vec = tf.reduce_sum(tf.multiply(focus_tensor, context_tensor), 1)
    # Softmax over the scores, laid out as a single row.
    p_vec = tf.nn.softmax(tf.reshape(z_vec, [1, -1]))

    # The tf.multiply in the following expression broadcasts along the 1 dim:
    # each row of h_tensor is scaled by its probability, the rows are summed,
    # and the result is re-wrapped as a single row vector.
    r_vec = tf.expand_dims(
        tf.reduce_sum(
            tf.multiply(
                h_tensor, tf.reshape(p_vec, [-1, 1]), name='time_together2'),
            0),
        0)
    # Map the attended summary onto the action space.
    return tf.matmul(
        r_vec,
        self._component.get_variable('attention_weights_pu'),
        name='time_together3')
  686. class IdentityNetwork(NetworkUnitInterface):
  687. """A network that returns concatenated input embeddings and activations."""
  688. def __init__(self, component):
  689. super(IdentityNetwork, self).__init__(component)
  690. self._layers = [
  691. Layer(
  692. component,
  693. name='input_embeddings',
  694. dim=self._concatenated_input_dim)
  695. ]
  696. def create(self,
  697. fixed_embeddings,
  698. linked_embeddings,
  699. context_tensor_arrays,
  700. attention_tensor,
  701. during_training,
  702. stride=None):
  703. return [get_input_tensor(fixed_embeddings, linked_embeddings)]
  704. def get_layer_size(self, layer_name):
  705. # Note that get_layer_size is called by super.__init__ before any layers are
  706. # constructed if and only if there are recurrent links.
  707. assert hasattr(self,
  708. '_layers'), 'IdentityNetwork cannot have recurrent links'
  709. return super(IdentityNetwork, self).get_layer_size(layer_name)
  710. def get_logits(self, network_tensors):
  711. return network_tensors[-1]
  712. def get_context_layers(self):
  713. return []
  714. class FeedForwardNetwork(NetworkUnitInterface):
  715. """Implementation of C&M style feedforward network.
  716. Supports dropout and optional layer normalization.
  717. Layers:
  718. layer_<i>: Activations for i'th hidden layer (0-origin).
  719. last_layer: Activations for the last hidden layer. This is a convenience
  720. alias for "layer_<n-1>", where n is the number of hidden layers.
  721. logits: Logits associated with component actions.
  722. """
  723. def __init__(self, component):
  724. """Initializes parameters required to run this network.
  725. Args:
  726. component: parent ComponentBuilderBase object.
  727. Parameters used to construct the network:
  728. hidden_layer_sizes: comma-separated list of ints, indicating the
  729. number of hidden units in each hidden layer.
  730. layer_norm_input (False): Whether or not to apply layer normalization
  731. on the concatenated input to the network.
  732. layer_norm_hidden (False): Whether or not to apply layer normalization
  733. to the first set of hidden layer activations.
  734. nonlinearity ('relu'): Name of function from module "tf.nn" to apply to
  735. each hidden layer; e.g., "relu" or "elu".
  736. dropout_keep_prob (-1.0): The probability that an input is not dropped.
  737. If >= 1.0, disables dropout. If < 0.0, uses the global |dropout_rate|
  738. hyperparameter.
  739. dropout_per_sequence (False): If true, sample the dropout mask once per
  740. sequence, instead of once per step. See Gal and Ghahramani
  741. (https://arxiv.org/abs/1512.05287).
  742. dropout_all_layers (False): If true, apply dropout to the input of all
  743. hidden layers, instead of just applying it to the network input.
  744. Hyperparameters used:
  745. dropout_rate: The probability that an input is not dropped. Only used
  746. when the |dropout_keep_prob| parameter is negative.
  747. """
  748. self._attrs = get_attrs_with_defaults(
  749. component.spec.network_unit.parameters, defaults={
  750. 'hidden_layer_sizes': '',
  751. 'layer_norm_input': False,
  752. 'layer_norm_hidden': False,
  753. 'nonlinearity': 'relu',
  754. 'dropout_keep_prob': -1.0,
  755. 'dropout_per_sequence': False,
  756. 'dropout_all_layers': False})
  757. # Initialize the hidden layer sizes before running the base initializer, as
  758. # the base initializer may need to know the size of of the hidden layer for
  759. # recurrent connections.
  760. self._hidden_layer_sizes = (
  761. map(int, self._attrs['hidden_layer_sizes'].split(','))
  762. if self._attrs['hidden_layer_sizes'] else [])
  763. super(FeedForwardNetwork, self).__init__(component)
  764. # Infer dropout rate from network parameters and grid hyperparameters.
  765. self._dropout_rate = self._attrs['dropout_keep_prob']
  766. if self._dropout_rate < 0.0:
  767. self._dropout_rate = component.master.hyperparams.dropout_rate
  768. # Add layer norm if specified.
  769. self._layer_norm_input = None
  770. self._layer_norm_hidden = None
  771. if self._attrs['layer_norm_input']:
  772. self._layer_norm_input = LayerNorm(self._component, 'concat_input',
  773. self._concatenated_input_dim,
  774. tf.float32)
  775. self._params.extend(self._layer_norm_input.params)
  776. if self._attrs['layer_norm_hidden']:
  777. self._layer_norm_hidden = LayerNorm(self._component, 'layer_0',
  778. self._hidden_layer_sizes[0],
  779. tf.float32)
  780. self._params.extend(self._layer_norm_hidden.params)
  781. # Extract nonlinearity from |tf.nn|.
  782. self._nonlinearity = getattr(tf.nn, self._attrs['nonlinearity'])
  783. # TODO(googleuser): add initializer stddevs as part of the network unit's
  784. # configuration.
  785. self._weights = []
  786. last_layer_dim = self._concatenated_input_dim
  787. # Initialize variables for the parameters, and add Layer objects for
  788. # cross-component bookkeeping.
  789. for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
  790. weights = tf.get_variable(
  791. 'weights_%d' % index, [last_layer_dim, hidden_layer_size],
  792. initializer=tf.random_normal_initializer(stddev=1e-4))
  793. self._params.append(weights)
  794. if index > 0 or self._layer_norm_hidden is None:
  795. self._params.append(
  796. tf.get_variable(
  797. 'bias_%d' % index, [hidden_layer_size],
  798. initializer=tf.constant_initializer(
  799. 0.2, dtype=tf.float32)))
  800. self._weights.append(weights)
  801. self._layers.append(
  802. Layer(
  803. component, name='layer_%d' % index, dim=hidden_layer_size))
  804. last_layer_dim = hidden_layer_size
  805. # Add a convenience alias for the last hidden layer, if any.
  806. if self._hidden_layer_sizes:
  807. self._layers.append(Layer(component, 'last_layer', last_layer_dim))
  808. # By default, regularize only the weights.
  809. self._regularized_weights.extend(self._weights)
  810. if component.num_actions:
  811. self._params.append(
  812. tf.get_variable(
  813. 'weights_softmax', [last_layer_dim, component.num_actions],
  814. initializer=tf.random_normal_initializer(stddev=1e-4)))
  815. self._params.append(
  816. tf.get_variable(
  817. 'bias_softmax', [component.num_actions],
  818. initializer=tf.zeros_initializer()))
  819. self._layers.append(
  820. Layer(
  821. component, name='logits', dim=component.num_actions))
  822. def create(self,
  823. fixed_embeddings,
  824. linked_embeddings,
  825. context_tensor_arrays,
  826. attention_tensor,
  827. during_training,
  828. stride=None):
  829. """See base class."""
  830. input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
  831. if during_training:
  832. input_tensor.set_shape([None, self._concatenated_input_dim])
  833. input_tensor = self._maybe_apply_dropout(input_tensor, stride)
  834. if self._layer_norm_input:
  835. input_tensor = self._layer_norm_input.normalize(input_tensor)
  836. tensors = []
  837. last_layer = input_tensor
  838. for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
  839. acts = tf.matmul(last_layer,
  840. self._component.get_variable('weights_%d' % index))
  841. # Note that the first layer was already handled before this loop.
  842. # TODO(googleuser): Refactor this loop so dropout and layer normalization
  843. # are applied consistently.
  844. if during_training and self._attrs['dropout_all_layers'] and index > 0:
  845. acts.set_shape([None, hidden_layer_size])
  846. acts = self._maybe_apply_dropout(acts, stride)
  847. # Don't add a bias term if we're going to apply layer norm, since layer
  848. # norm includes a bias already.
  849. if index == 0 and self._layer_norm_hidden:
  850. acts = self._layer_norm_hidden.normalize(acts)
  851. else:
  852. acts = tf.nn.bias_add(acts,
  853. self._component.get_variable('bias_%d' % index))
  854. last_layer = self._nonlinearity(acts)
  855. tensors.append(last_layer)
  856. # Add a convenience alias for the last hidden layer, if any.
  857. if self._hidden_layer_sizes:
  858. tensors.append(last_layer)
  859. if self._layers[-1].name == 'logits':
  860. logits = tf.matmul(
  861. last_layer, self._component.get_variable(
  862. 'weights_softmax')) + self._component.get_variable('bias_softmax')
  863. if self._component.spec.attention_component:
  864. logits += self.attention(last_layer, attention_tensor)
  865. logits = tf.identity(logits, name=self._layers[-1].name)
  866. tensors.append(logits)
  867. return tensors
  868. def get_layer_size(self, layer_name):
  869. if layer_name == 'logits':
  870. return self._component.num_actions
  871. if layer_name == 'last_layer':
  872. return self._hidden_layer_sizes[-1]
  873. if not layer_name.startswith('layer_'):
  874. logging.fatal(
  875. 'Invalid layer name: "%s" Can only retrieve from "logits", '
  876. '"last_layer", and "layer_*".',
  877. layer_name)
  878. # NOTE(danielandor): Since get_layer_size is called before the
  879. # model has been built, we compute the layer size directly from
  880. # the hyperparameters rather than from self._layers.
  881. layer_index = int(layer_name.split('_')[1])
  882. return self._hidden_layer_sizes[layer_index]
  883. def get_logits(self, network_tensors):
  884. return network_tensors[-1]
  885. def _maybe_apply_dropout(self, inputs, stride):
  886. return maybe_apply_dropout(inputs, self._dropout_rate,
  887. self._attrs['dropout_per_sequence'], stride)
  888. class LSTMNetwork(NetworkUnitInterface):
  889. """Implementation of action LSTM style network."""
  890. def __init__(self, component):
  891. assert component.num_actions > 0, 'Component num actions must be positive.'
  892. network_unit_spec = component.spec.network_unit
  893. self._hidden_layer_sizes = (
  894. int)(network_unit_spec.parameters['hidden_layer_sizes'])
  895. self._input_dropout_rate = component.master.hyperparams.dropout_rate
  896. self._recurrent_dropout_rate = (
  897. component.master.hyperparams.recurrent_dropout_rate)
  898. if self._recurrent_dropout_rate < 0.0:
  899. self._recurrent_dropout_rate = component.master.hyperparams.dropout_rate
  900. super(LSTMNetwork, self).__init__(component)
  901. layer_input_dim = self._concatenated_input_dim
  902. self._context_layers = []
  903. # TODO(googleuser): should we choose different initilizer,
  904. # e.g. truncated_normal_initializer?
  905. self._x2i = tf.get_variable(
  906. 'x2i', [layer_input_dim, self._hidden_layer_sizes],
  907. initializer=tf.random_normal_initializer(stddev=1e-4))
  908. self._h2i = tf.get_variable(
  909. 'h2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
  910. initializer=tf.random_normal_initializer(stddev=1e-4))
  911. self._c2i = tf.get_variable(
  912. 'c2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
  913. initializer=tf.random_normal_initializer(stddev=1e-4))
  914. self._bi = tf.get_variable(
  915. 'bi', [self._hidden_layer_sizes],
  916. initializer=tf.random_normal_initializer(stddev=1e-4))
  917. self._x2o = tf.get_variable(
  918. 'x2o', [layer_input_dim, self._hidden_layer_sizes],
  919. initializer=tf.random_normal_initializer(stddev=1e-4))
  920. self._h2o = tf.get_variable(
  921. 'h2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
  922. initializer=tf.random_normal_initializer(stddev=1e-4))
  923. self._c2o = tf.get_variable(
  924. 'c2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
  925. initializer=tf.random_normal_initializer(stddev=1e-4))
  926. self._bo = tf.get_variable(
  927. 'bo', [self._hidden_layer_sizes],
  928. initializer=tf.random_normal_initializer(stddev=1e-4))
  929. self._x2c = tf.get_variable(
  930. 'x2c', [layer_input_dim, self._hidden_layer_sizes],
  931. initializer=tf.random_normal_initializer(stddev=1e-4))
  932. self._h2c = tf.get_variable(
  933. 'h2c', [self._hidden_layer_sizes, self._hidden_layer_sizes],
  934. initializer=tf.random_normal_initializer(stddev=1e-4))
  935. self._bc = tf.get_variable(
  936. 'bc', [self._hidden_layer_sizes],
  937. initializer=tf.random_normal_initializer(stddev=1e-4))
  938. self._params.extend([
  939. self._x2i, self._h2i, self._c2i, self._bi, self._x2o, self._h2o,
  940. self._c2o, self._bo, self._x2c, self._h2c, self._bc])
  941. lstm_h_layer = Layer(component, name='lstm_h', dim=self._hidden_layer_sizes)
  942. lstm_c_layer = Layer(component, name='lstm_c', dim=self._hidden_layer_sizes)
  943. self._context_layers.append(lstm_h_layer)
  944. self._context_layers.append(lstm_c_layer)
  945. self._layers.extend(self._context_layers)
  946. self._layers.append(
  947. Layer(
  948. component, name='layer_0', dim=self._hidden_layer_sizes))
  949. self.params.append(tf.get_variable(
  950. 'weights_softmax', [self._hidden_layer_sizes, component.num_actions],
  951. initializer=tf.random_normal_initializer(stddev=1e-4)))
  952. self.params.append(
  953. tf.get_variable(
  954. 'bias_softmax', [component.num_actions],
  955. initializer=tf.zeros_initializer()))
  956. self._layers.append(
  957. Layer(
  958. component, name='logits', dim=component.num_actions))
  959. def create(self,
  960. fixed_embeddings,
  961. linked_embeddings,
  962. context_tensor_arrays,
  963. attention_tensor,
  964. during_training,
  965. stride=None):
  966. """See base class."""
  967. input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
  968. # context_tensor_arrays[0] is lstm_h
  969. # context_tensor_arrays[1] is lstm_c
  970. assert len(context_tensor_arrays) == 2
  971. length = context_tensor_arrays[0].size()
  972. # Get the (possibly averaged) parameters to execute the network.
  973. x2i = self._component.get_variable('x2i')
  974. h2i = self._component.get_variable('h2i')
  975. c2i = self._component.get_variable('c2i')
  976. bi = self._component.get_variable('bi')
  977. x2o = self._component.get_variable('x2o')
  978. h2o = self._component.get_variable('h2o')
  979. c2o = self._component.get_variable('c2o')
  980. bo = self._component.get_variable('bo')
  981. x2c = self._component.get_variable('x2c')
  982. h2c = self._component.get_variable('h2c')
  983. bc = self._component.get_variable('bc')
  984. # i_h_tm1, i_c_tm1 = h_{t-1}, c_{t-1}
  985. i_h_tm1 = context_tensor_arrays[0].read(length - 1)
  986. i_c_tm1 = context_tensor_arrays[1].read(length - 1)
  987. # apply dropout according to http://arxiv.org/pdf/1409.2329v5.pdf
  988. if during_training and self._input_dropout_rate < 1:
  989. input_tensor = tf.nn.dropout(input_tensor, self._input_dropout_rate)
  990. # input -- i_t = sigmoid(affine(x_t, h_{t-1}, c_{t-1}))
  991. i_ait = tf.matmul(input_tensor, x2i) + tf.matmul(i_h_tm1, h2i) + tf.matmul(
  992. i_c_tm1, c2i) + bi
  993. i_it = tf.sigmoid(i_ait)
  994. # forget -- f_t = 1 - i_t
  995. i_ft = tf.ones([1, 1]) - i_it
  996. # write memory cell -- tanh(affine(x_t, h_{t-1}))
  997. i_awt = tf.matmul(input_tensor, x2c) + tf.matmul(i_h_tm1, h2c) + bc
  998. i_wt = tf.tanh(i_awt)
  999. # c_t = f_t \odot c_{t-1} + i_t \odot tanh(affine(x_t, h_{t-1}))
  1000. ct = tf.add(
  1001. tf.multiply(i_it, i_wt), tf.multiply(i_ft, i_c_tm1), name='lstm_c')
  1002. # output -- o_t = sigmoid(affine(x_t, h_{t-1}, c_t))
  1003. i_aot = tf.matmul(input_tensor, x2o) + tf.matmul(ct, c2o) + tf.matmul(
  1004. i_h_tm1, h2o) + bo
  1005. i_ot = tf.sigmoid(i_aot)
  1006. # ht = o_t \odot tanh(ct)
  1007. ph_t = tf.tanh(ct)
  1008. ht = tf.multiply(i_ot, ph_t, name='lstm_h')
  1009. if during_training and self._recurrent_dropout_rate < 1:
  1010. ht = tf.nn.dropout(
  1011. ht, self._recurrent_dropout_rate, name='lstm_h_dropout')
  1012. h = tf.identity(ht, name='layer_0')
  1013. logits = tf.nn.xw_plus_b(ht, tf.get_variable('weights_softmax'),
  1014. tf.get_variable('bias_softmax'))
  1015. if self._component.spec.attention_component:
  1016. logits += self.attention(ht, attention_tensor)
  1017. logits = tf.identity(logits, name='logits')
  1018. # tensors will be consistent with the layers:
  1019. # [lstm_h, lstm_c, layer_0, logits]
  1020. tensors = [ht, ct, h, logits]
  1021. return tensors
  1022. def get_layer_size(self, layer_name):
  1023. assert layer_name == 'layer_0', 'Can only retrieve from first hidden layer.'
  1024. return self._hidden_layer_sizes
  1025. def get_logits(self, network_tensors):
  1026. return network_tensors[self.get_layer_index('logits')]
  1027. class ConvNetwork(NetworkUnitInterface):
  1028. """Implementation of a convolutional feed forward network."""
  1029. def __init__(self, component):
  1030. """Initializes kernels and biases for this convolutional net.
  1031. Args:
  1032. component: parent ComponentBuilderBase object.
  1033. Parameters used to construct the network:
  1034. widths: comma separated list of ints, number of steps input to the
  1035. convolutional kernel at every layer.
  1036. depths: comma separated list of ints, number of channels input to the
  1037. convolutional kernel at every layer.
  1038. output_embedding_dim: int, number of output channels for the convolutional
  1039. kernel of the last layer, which receives no ReLU activation and
  1040. therefore can be used in a softmax output. If zero, this final
  1041. layer is disabled entirely.
  1042. nonlinearity ('relu'): Name of function from module "tf.nn" to apply to
  1043. each hidden layer; e.g., "relu" or "elu".
  1044. dropout_keep_prob (-1.0): The probability that an input is not dropped.
  1045. If >= 1.0, disables dropout. If < 0.0, uses the global |dropout_rate|
  1046. hyperparameter.
  1047. dropout_per_sequence (False): If true, sample the dropout mask once per
  1048. sequence, instead of once per step. See Gal and Ghahramani
  1049. (https://arxiv.org/abs/1512.05287).
  1050. Hyperparameters used:
  1051. dropout_rate: The probability that an input is not dropped. Only used
  1052. when the |dropout_keep_prob| parameter is negative.
  1053. """
  1054. super(ConvNetwork, self).__init__(component)
  1055. self._attrs = get_attrs_with_defaults(
  1056. component.spec.network_unit.parameters, defaults={
  1057. 'widths': '',
  1058. 'depths': '',
  1059. 'output_embedding_dim': 0,
  1060. 'nonlinearity': 'relu',
  1061. 'dropout_keep_prob': -1.0,
  1062. 'dropout_per_sequence': False})
  1063. self._weights = []
  1064. self._biases = []
  1065. self._widths = map(int, self._attrs['widths'].split(','))
  1066. self._depths = map(int, self._attrs['depths'].split(','))
  1067. self._output_dim = self._attrs['output_embedding_dim']
  1068. if self._output_dim:
  1069. self._depths.append(self._output_dim)
  1070. self.kernel_shapes = []
  1071. for i in range(len(self._depths) - 1):
  1072. self.kernel_shapes.append(
  1073. [1, self._widths[i], self._depths[i], self._depths[i + 1]])
  1074. for i in range(len(self._depths) - 1):
  1075. with tf.variable_scope('conv%d' % i):
  1076. self._weights.append(
  1077. tf.get_variable(
  1078. 'weights',
  1079. self.kernel_shapes[i],
  1080. initializer=tf.random_normal_initializer(stddev=1e-4),
  1081. dtype=tf.float32))
  1082. bias_init = 0.0 if (i == len(self._widths) - 1) else 0.2
  1083. self._biases.append(
  1084. tf.get_variable(
  1085. 'biases',
  1086. self.kernel_shapes[i][-1],
  1087. initializer=tf.constant_initializer(bias_init),
  1088. dtype=tf.float32))
  1089. # Extract nonlinearity from |tf.nn|.
  1090. self._nonlinearity = getattr(tf.nn, self._attrs['nonlinearity'])
  1091. # Infer dropout rate from network parameters and grid hyperparameters.
  1092. self._dropout_rate = self._attrs['dropout_keep_prob']
  1093. if self._dropout_rate < 0.0:
  1094. self._dropout_rate = component.master.hyperparams.dropout_rate
  1095. self._params.extend(self._weights + self._biases)
  1096. self._layers.append(
  1097. Layer(
  1098. component, name='conv_output', dim=self._depths[-1]))
  1099. self._regularized_weights.extend(self._weights[:-1] if self._output_dim else
  1100. self._weights)
  1101. def create(self,
  1102. fixed_embeddings,
  1103. linked_embeddings,
  1104. context_tensor_arrays,
  1105. attention_tensor,
  1106. during_training,
  1107. stride=None):
  1108. """Requires |stride|; otherwise see base class."""
  1109. if stride is None:
  1110. raise RuntimeError("ConvNetwork needs 'stride' and must be called in the "
  1111. "bulk feature extractor component.")
  1112. input_tensor = get_input_tensor_with_stride(fixed_embeddings,
  1113. linked_embeddings, stride)
  1114. # TODO(googleuser): Add context and attention.
  1115. del context_tensor_arrays, attention_tensor
  1116. # On CPU, add a dimension so that the 'image' has shape
  1117. # [stride, 1, num_steps, D].
  1118. conv = tf.expand_dims(input_tensor, 1)
  1119. for i in range(len(self._depths) - 1):
  1120. with tf.variable_scope('conv%d' % i, reuse=True) as scope:
  1121. if during_training:
  1122. conv.set_shape([None, 1, None, self._depths[i]])
  1123. conv = self._maybe_apply_dropout(conv, stride)
  1124. conv = tf.nn.conv2d(
  1125. conv,
  1126. self._component.get_variable('weights'), [1, 1, 1, 1],
  1127. padding='SAME')
  1128. conv = tf.nn.bias_add(conv, self._component.get_variable('biases'))
  1129. if i < (len(self._weights) - 1) or not self._output_dim:
  1130. conv = self._nonlinearity(conv, name=scope.name)
  1131. return [
  1132. tf.reshape(
  1133. conv, [-1, self._depths[-1]], name='reshape_activations')
  1134. ]
  1135. def _maybe_apply_dropout(self, inputs, stride):
  1136. # The |inputs| are rank 4 (one 1xN "image" per sequence). Squeeze out and
  1137. # restore the singleton image height, so dropout is applied to the normal
  1138. # rank 3 batched input tensor.
  1139. inputs = tf.squeeze(inputs, [1])
  1140. inputs = maybe_apply_dropout(inputs, self._dropout_rate,
  1141. self._attrs['dropout_per_sequence'], stride)
  1142. inputs = tf.expand_dims(inputs, 1)
  1143. return inputs
  1144. class PairwiseConvNetwork(NetworkUnitInterface):
  1145. """Implementation of a pairwise 2D convolutional feed forward network.
  1146. For a sequence of N tokens, all N^2 pairs of concatenated input features are
  1147. constructed. If each input vector is of length D, then the sequence is
  1148. represented by an image of dimensions [N, N] with 2*D channels per pixel.
  1149. I.e. pixel [i, j] has a representation that is the concatenation of the
  1150. representations of the tokens at i and at j.
  1151. To use this network for graph edge scoring, for instance by using the "heads"
  1152. transition system, the output layer needs to have dimensions [N, N] and only
  1153. a single channel. The network takes care of outputting an [N, N] sized layer,
  1154. but the user needs to ensure that the output depth equals 1.
  1155. TODO(googleuser): Like Dozat and Manning, we will need an
  1156. additional network to label the edges, and the ability to read head
  1157. and modifier representations from different inputs.
  1158. """
  1159. def __init__(self, component):
  1160. """Initializes kernels and biases for this convolutional net.
  1161. Parameters used to construct the network:
  1162. depths: comma separated list of ints, number of channels input to the
  1163. convolutional kernel at every layer.
  1164. widths: comma separated list of ints, number of steps input to the
  1165. convolutional kernel at every layer.
  1166. relu_layers: comma separate list of ints, the id of layers after which
  1167. to apply a relu activation. *By default, all but the final layer will
  1168. have a relu activation applied.*
  1169. To generate a network with M layers, both 'depths' and 'widths' must be of
  1170. length M. The input depth of the first layer is inferred from the total
  1171. concatenated size of the input features.
  1172. Args:
  1173. component: parent ComponentBuilderBase object.
  1174. Raises:
  1175. RuntimeError: if the number of depths and weights are not equal.
  1176. ValueError: if the final depth is not equal to 1.
  1177. """
  1178. parameters = component.spec.network_unit.parameters
  1179. super(PairwiseConvNetwork, self).__init__(component)
  1180. # Each input pixel will comprise the concatenation of two tokens, so the
  1181. # input depth is double that for a single token.
  1182. self._depths = [self._concatenated_input_dim * 2]
  1183. self._depths.extend(map(int, parameters['depths'].split(',')))
  1184. self._widths = map(int, parameters['widths'].split(','))
  1185. self._num_layers = len(self._widths)
  1186. if len(self._depths) != self._num_layers + 1:
  1187. raise RuntimeError('Unmatched depths/weights %s/%s' %
  1188. (parameters['depths'], parameters['weights']))
  1189. if self._depths[-1] != 1:
  1190. raise ValueError('Final depth is not equal to 1 in %s' %
  1191. parameters['depths'])
  1192. self._kernel_shapes = []
  1193. for i, width in enumerate(self._widths):
  1194. self._kernel_shapes.append(
  1195. [width, width, self._depths[i], self._depths[i + 1]])
  1196. if parameters['relu_layers']:
  1197. self._relu_layers = set(map(int, parameters['relu_layers'].split(',')))
  1198. else:
  1199. self._relu_layers = set(range(self._num_layers - 1))
  1200. self._weights = []
  1201. self._biases = []
  1202. for i, kernel_shape in enumerate(self._kernel_shapes):
  1203. with tf.variable_scope('conv%d' % i):
  1204. self._weights.append(
  1205. tf.get_variable(
  1206. 'weights',
  1207. kernel_shape,
  1208. initializer=tf.random_normal_initializer(stddev=1e-4),
  1209. dtype=tf.float32))
  1210. bias_init = 0.0 if i in self._relu_layers else 0.2
  1211. self._biases.append(
  1212. tf.get_variable(
  1213. 'biases',
  1214. kernel_shape[-1],
  1215. initializer=tf.constant_initializer(bias_init),
  1216. dtype=tf.float32))
  1217. self._params.extend(self._weights + self._biases)
  1218. self._layers.append(Layer(component, name='conv_output', dim=-1))
  1219. self._regularized_weights.extend(self._weights[:-1])
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class.

    Builds a pairwise "image" from the input sequence and runs the
    convolutional stack over it. Pixel (i, j) of the image is the
    concatenation of the input vectors of token j and token i. Layer k uses
    the 'weights' and 'biases' variables under variable scope 'conv<k>'.
    A ReLU is applied after every layer whose index is in
    |self._relu_layers|.

    Args:
      fixed_embeddings: Fixed feature embeddings; combined with
        |linked_embeddings| by get_input_tensor_with_stride().
      linked_embeddings: Linked feature embeddings.
      context_tensor_arrays: Unused.
      attention_tensor: Unused.
      during_training: Unused.
      stride: Passed to get_input_tensor_with_stride() to un-flatten the
        inputs; must not be None.

    Returns:
      A single-element list holding the final activations reshaped to
      [-1, num_steps].

    Raises:
      ValueError: if |stride| is None.
    """
    # TODO(googleuser): Normalize the arguments to create(). 'stride'
    # is unused by the recurrent network units, while 'context_tensor_arrays'
    # and 'attention_tensor' are unused by bulk network units. b/33587044
    if stride is None:
      raise ValueError("PairwiseConvNetwork needs 'stride'")
    # Rank-3 input (expand_dims + 4-element tile multiples below require it),
    # presumably [stride, num_steps, input_dim] -- TODO confirm.
    input_tensor = get_input_tensor_with_stride(fixed_embeddings,
                                                linked_embeddings, stride)
    # TODO(googleuser): Add dropout.
    del context_tensor_arrays, attention_tensor, during_training  # Unused.
    num_steps = tf.shape(input_tensor)[1]

    # After tiling, arg1[b, i, j] == input[b, j] and arg2[b, i, j] ==
    # input[b, i]. Concatenating on the last axis therefore makes each pixel
    # (i, j) hold the pair (token j, token i).
    arg1 = tf.expand_dims(input_tensor, 1)
    arg1 = tf.tile(arg1, tf.stack([1, num_steps, 1, 1]))
    arg2 = tf.expand_dims(input_tensor, 2)
    arg2 = tf.tile(arg2, tf.stack([1, 1, num_steps, 1]))
    conv = tf.concat([arg1, arg2], 3)

    # range() rather than xrange() so this runs on both Python 2 and 3.
    for i in range(self._num_layers):
      with tf.variable_scope('conv%d' % i, reuse=True) as scope:
        conv = tf.nn.conv2d(
            conv,
            self._component.get_variable('weights'), [1, 1, 1, 1],
            padding='SAME')
        conv = tf.nn.bias_add(conv, self._component.get_variable('biases'))
        if i in self._relu_layers:
          conv = tf.nn.relu(conv, name=scope.name)

    # __init__ requires the final depth to be 1, so this flattens the image
    # to one row of num_steps scores per (batch, step).
    return [tf.reshape(conv, [-1, num_steps], name='reshape_activations')]
  1253. class ExportFixedFeaturesNetwork(NetworkUnitInterface):
  1254. """A network that exports fixed features as layers.
  1255. Each fixed feature embedding is output as a layer whose name and dimension are
  1256. set to the name and dimension of the corresponding fixed feature.
  1257. """
  1258. def __init__(self, component):
  1259. """Initializes exported layers."""
  1260. super(ExportFixedFeaturesNetwork, self).__init__(component)
  1261. for feature_spec in component.spec.fixed_feature:
  1262. name = feature_spec.name
  1263. dim = self._fixed_feature_dims[name]
  1264. self._layers.append(Layer(component, name, dim))
  1265. def create(self,
  1266. fixed_embeddings,
  1267. linked_embeddings,
  1268. context_tensor_arrays,
  1269. attention_tensor,
  1270. during_training,
  1271. stride=None):
  1272. """See base class."""
  1273. check.Eq(len(self.layers), len(fixed_embeddings))
  1274. for index in range(len(fixed_embeddings)):
  1275. check.Eq(self.layers[index].name, fixed_embeddings[index].name)
  1276. return [fixed_embedding.tensor for fixed_embedding in fixed_embeddings]
  1277. class SplitNetwork(NetworkUnitInterface):
  1278. """Network unit that splits its input into slices of equal dimension.
  1279. Parameters:
  1280. num_slices: The number of slices to split the input into, S. The input must
  1281. have static dimension D, where D % S == 0.
  1282. Features:
  1283. All inputs are concatenated before being split.
  1284. Layers:
  1285. slice_0: [B * N, D / S] The first slice of the input.
  1286. slice_1: [B * N, D / S] The second slice of the input.
  1287. ...
  1288. """
  1289. def __init__(self, component):
  1290. """Initializes weights and layers.
  1291. Args:
  1292. component: Parent ComponentBuilderBase object.
  1293. """
  1294. super(SplitNetwork, self).__init__(component)
  1295. parameters = component.spec.network_unit.parameters
  1296. self._num_slices = int(parameters['num_slices'])
  1297. check.Gt(self._num_slices, 0, 'Invalid number of slices.')
  1298. check.Eq(self._concatenated_input_dim % self._num_slices, 0,
  1299. 'Input dimension %s does not evenly divide into %s slices' %
  1300. (self._concatenated_input_dim, self._num_slices))
  1301. self._slice_dim = int(self._concatenated_input_dim / self._num_slices)
  1302. for slice_index in xrange(self._num_slices):
  1303. self._layers.append(
  1304. Layer(self, 'slice_%s' % slice_index, self._slice_dim))
  1305. def create(self,
  1306. fixed_embeddings,
  1307. linked_embeddings,
  1308. context_tensor_arrays,
  1309. attention_tensor,
  1310. during_training,
  1311. stride=None):
  1312. input_bnxd = get_input_tensor(fixed_embeddings, linked_embeddings)
  1313. return tf.split(input_bnxd, self._num_slices, axis=1)