network_units.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603
  1. """Basic network units used in assembling DRAGNN graphs."""
  2. from abc import ABCMeta
  3. from abc import abstractmethod
  4. import tensorflow as tf
  5. from tensorflow.python.ops import nn
  6. from tensorflow.python.ops import tensor_array_ops as ta
  7. from tensorflow.python.platform import tf_logging as logging
  8. from dragnn.python import dragnn_ops
  9. from syntaxnet.util import check
  10. from syntaxnet.util import registry
  11. def linked_embeddings_name(channel_id):
  12. """Returns the name of the linked embedding matrix for some channel ID."""
  13. return 'linked_embedding_matrix_%d' % channel_id
  14. def fixed_embeddings_name(channel_id):
  15. """Returns the name of the fixed embedding matrix for some channel ID."""
  16. return 'fixed_embedding_matrix_%d' % channel_id
  17. class StoredActivations(object):
  18. """Wrapper around stored activation vectors.
  19. Because activations are produced and consumed in different layouts by bulk
  20. vs. dynamic components, this class provides a simple common
  21. interface/conversion API. It can be constructed from either a TensorArray
  22. (dynamic) or a Tensor (bulk), and the resulting object to use for lookups is
  23. either bulk_tensor (for bulk components) or dynamic_tensor (for dynamic
  24. components).
  25. """
  26. def __init__(self, tensor=None, array=None, stride=None, dim=None):
  27. """Creates ops for converting the input to either format.
  28. If 'tensor' is used, then a conversion from [stride * steps, dim] to
  29. [steps + 1, stride, dim] is performed for dynamic_tensor reads.
  30. If 'array' is used, then a conversion from [steps + 1, stride, dim] to
  31. [stride * steps, dim] is performed for bulk_tensor reads.
  32. Args:
  33. tensor: Bulk tensor input.
  34. array: TensorArray dynamic input.
  35. stride: stride of bulk tensor. Not used for dynamic.
  36. dim: dim of bulk tensor. Not used for dynamic.
  37. """
  38. if tensor is not None:
  39. check.IsNone(array, 'Cannot initialize from tensor and array')
  40. check.NotNone(stride, 'Stride is required for bulk tensor')
  41. check.NotNone(dim, 'Dim is required for bulk tensor')
  42. self._bulk_tensor = tensor
  43. with tf.name_scope('convert_to_dyn'):
  44. tensor = tf.reshape(tensor, [stride, -1, dim])
  45. tensor = tf.transpose(tensor, perm=[1, 0, 2])
  46. pad = tf.zeros([1, stride, dim], dtype=tensor.dtype)
  47. self._array_tensor = tf.concat([pad, tensor], 0)
  48. if array is not None:
  49. check.IsNone(tensor, 'Cannot initialize from both tensor and array')
  50. with tf.name_scope('convert_to_bulk'):
  51. self._bulk_tensor = convert_network_state_tensorarray(array)
  52. with tf.name_scope('convert_to_dyn'):
  53. self._array_tensor = array.stack()
  54. @property
  55. def bulk_tensor(self):
  56. return self._bulk_tensor
  57. @property
  58. def dynamic_tensor(self):
  59. return self._array_tensor
  60. class NamedTensor(object):
  61. """Container for a tensor with associated name and dimension attributes."""
  62. def __init__(self, tensor, name, dim=None):
  63. """Inits NamedTensor with tensor, name and optional dim."""
  64. self.tensor = tensor
  65. self.name = name
  66. self.dim = dim
  67. def add_embeddings(channel_id, feature_spec, seed):
  68. """Adds a variable for the embedding of a given fixed feature.
  69. Supports pre-trained or randomly initialized embeddings In both cases, extra
  70. vector is reserved for out-of-vocabulary words, so the embedding matrix has
  71. the size of [feature_spec.vocabulary_size + 1, feature_spec.embedding_dim].
  72. Args:
  73. channel_id: Numeric id of the fixed feature channel
  74. feature_spec: Feature spec protobuf of type FixedFeatureChannel
  75. seed: used for random initializer
  76. Returns:
  77. tf.Variable object corresponding to the embedding for that feature.
  78. Raises:
  79. RuntimeError: if more the pretrained embeddings are specified in resources
  80. containing more than one part.
  81. """
  82. check.Gt(feature_spec.embedding_dim, 0,
  83. 'Embeddings requested for non-embedded feature: %s' % feature_spec)
  84. name = fixed_embeddings_name(channel_id)
  85. shape = [feature_spec.vocabulary_size + 1, feature_spec.embedding_dim]
  86. if feature_spec.HasField('pretrained_embedding_matrix'):
  87. if len(feature_spec.pretrained_embedding_matrix.part) > 1:
  88. raise RuntimeError('pretrained_embedding_matrix resource contains '
  89. 'more than one part:\n%s',
  90. str(feature_spec.pretrained_embedding_matrix))
  91. if len(feature_spec.vocab.part) > 1:
  92. raise RuntimeError('vocab resource contains more than one part:\n%s',
  93. str(feature_spec.vocab))
  94. embeddings = dragnn_ops.dragnn_embedding_initializer(
  95. embedding_input=feature_spec.pretrained_embedding_matrix.part[0]
  96. .file_pattern,
  97. vocab=feature_spec.vocab.part[0].file_pattern,
  98. scaling_coefficient=1.0)
  99. return tf.get_variable(name, initializer=tf.reshape(embeddings, shape))
  100. else:
  101. return tf.get_variable(
  102. name,
  103. shape,
  104. initializer=tf.random_normal_initializer(
  105. stddev=1.0 / feature_spec.embedding_dim**.5, seed=seed))
  106. def embedding_lookup(embedding_matrix, indices, ids, weights, size):
  107. """Performs a weighted embedding lookup.
  108. Args:
  109. embedding_matrix: float Tensor from which to do the lookup.
  110. indices: int Tensor for the output rows of the looked up vectors.
  111. ids: int Tensor vectors to look up in the embedding_matrix.
  112. weights: float Tensor weights to apply to the looked up vectors.
  113. size: int number of output rows. Needed since some output rows may be
  114. empty.
  115. Returns:
  116. Weighted embedding vectors.
  117. """
  118. embeddings = tf.nn.embedding_lookup([embedding_matrix], ids)
  119. # TODO(googleuser): allow skipping weights.
  120. broadcast_weights_shape = tf.concat([tf.shape(weights), [1]], 0)
  121. embeddings *= tf.reshape(weights, broadcast_weights_shape)
  122. embeddings = tf.unsorted_segment_sum(embeddings, indices, size)
  123. return embeddings
  124. def fixed_feature_lookup(component, state, channel_id, stride):
  125. """Looks up fixed features and passes them through embeddings.
  126. Embedding vectors may be scaled by weights if the features specify it.
  127. Args:
  128. component: Component object in which to look up the fixed features.
  129. state: MasterState object for the live nlp_saft::dragnn::MasterState.
  130. channel_id: int id of the fixed feature to look up.
  131. stride: int Tensor of current batch * beam size.
  132. Returns:
  133. NamedTensor object containing the embedding vectors.
  134. """
  135. feature_spec = component.spec.fixed_feature[channel_id]
  136. check.Gt(feature_spec.embedding_dim, 0,
  137. 'Embeddings requested for non-embedded feature: %s' % feature_spec)
  138. embedding_matrix = component.get_variable(fixed_embeddings_name(channel_id))
  139. with tf.op_scope([embedding_matrix], 'fixed_embedding_' + feature_spec.name):
  140. indices, ids, weights = dragnn_ops.extract_fixed_features(
  141. state.handle, component=component.name, channel_id=channel_id)
  142. size = stride * feature_spec.size
  143. embeddings = embedding_lookup(embedding_matrix, indices, ids, weights, size)
  144. dim = feature_spec.size * feature_spec.embedding_dim
  145. return NamedTensor(
  146. tf.reshape(embeddings, [-1, dim]), feature_spec.name, dim=dim)
  147. def get_input_tensor(fixed_embeddings, linked_embeddings):
  148. """Helper function for constructing an input tensor from all the features.
  149. Args:
  150. fixed_embeddings: list of NamedTensor objects for fixed feature channels
  151. linked_embeddings: list of NamedTensor objects for linked feature channels
  152. Returns:
  153. a tensor of shape [N, D], where D is the total input dimension of the
  154. concatenated feature channels
  155. Raises:
  156. RuntimeError: if no features, fixed or linked, are configured.
  157. """
  158. embeddings = fixed_embeddings + linked_embeddings
  159. if not embeddings:
  160. raise RuntimeError('There needs to be at least one feature set defined.')
  161. # Concat_v2 takes care of optimizing away the concatenation
  162. # operation in the case when there is exactly one embedding input.
  163. return tf.concat([e.tensor for e in embeddings], 1)
  164. def get_input_tensor_with_stride(fixed_embeddings, linked_embeddings, stride):
  165. """Constructs an input tensor with a separate dimension for steps.
  166. Args:
  167. fixed_embeddings: list of NamedTensor objects for fixed feature channels
  168. linked_embeddings: list of NamedTensor objects for linked feature channels
  169. stride: int stride (i.e. beam * batch) to use to reshape the input
  170. Returns:
  171. a tensor of shape [stride, num_steps, D], where D is the total input
  172. dimension of the concatenated feature channels
  173. """
  174. input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
  175. shape = tf.shape(input_tensor)
  176. return tf.reshape(input_tensor, [stride, -1, shape[1]])
  177. def convert_network_state_tensorarray(tensorarray):
  178. """Converts a source TensorArray to a source Tensor.
  179. Performs a permutation between the steps * [stride, D] shape of a
  180. source TensorArray and the (flattened) [stride * steps, D] shape of
  181. a source Tensor.
  182. The TensorArrays used during recurrence have an additional zeroth step that
  183. needs to be removed.
  184. Args:
  185. tensorarray: TensorArray object to be converted.
  186. Returns:
  187. Tensor object after conversion.
  188. """
  189. tensor = tensorarray.stack() # Results in a [steps, stride, D] tensor.
  190. tensor = tf.slice(tensor, [1, 0, 0], [-1, -1, -1]) # Lop off the 0th step.
  191. tensor = tf.transpose(tensor, [1, 0, 2]) # Switch steps and stride.
  192. return tf.reshape(tensor, [-1, tf.shape(tensor)[2]])
  193. def pass_through_embedding_matrix(act_block, embedding_matrix, step_idx):
  194. """Passes the activations through the embedding_matrix.
  195. Takes care to handle out of bounds lookups.
  196. Args:
  197. act_block: matrix of activations.
  198. embedding_matrix: matrix of weights.
  199. step_idx: vector containing step indices, with -1 indicating out of bounds.
  200. Returns:
  201. the embedded activations.
  202. """
  203. # Indicator vector for out of bounds lookups.
  204. step_idx_mask = tf.expand_dims(tf.equal(step_idx, -1), -1)
  205. # Pad the last column of the activation vectors with the indicator.
  206. act_block = tf.concat([act_block, tf.to_float(step_idx_mask)], 1)
  207. return tf.matmul(act_block, embedding_matrix)
  208. def lookup_named_tensor(name, named_tensors):
  209. """Retrieves a NamedTensor by name.
  210. Args:
  211. name: Name of the tensor to retrieve.
  212. named_tensors: List of NamedTensor objects to search.
  213. Returns:
  214. The NamedTensor in |named_tensors| with the |name|.
  215. Raises:
  216. KeyError: If the |name| is not found among the |named_tensors|.
  217. """
  218. for named_tensor in named_tensors:
  219. if named_tensor.name == name:
  220. return named_tensor
  221. raise KeyError('Name "%s" not found in named tensors: %s' %
  222. (name, named_tensors))
  223. def activation_lookup_recurrent(component, state, channel_id, source_array,
  224. source_layer_size, stride):
  225. """Looks up activations from tensor arrays.
  226. If the linked feature's embedding_dim is set to -1, the feature vectors are
  227. not passed through (i.e. multiplied by) an embedding matrix.
  228. Args:
  229. component: Component object in which to look up the fixed features.
  230. state: MasterState object for the live nlp_saft::dragnn::MasterState.
  231. channel_id: int id of the fixed feature to look up.
  232. source_array: TensorArray from which to fetch feature vectors, expected to
  233. have size [steps + 1] elements of shape [stride, D] each.
  234. source_layer_size: int length of feature vectors before embedding.
  235. stride: int Tensor of current batch * beam size.
  236. Returns:
  237. NamedTensor object containing the embedding vectors.
  238. """
  239. feature_spec = component.spec.linked_feature[channel_id]
  240. with tf.name_scope('activation_lookup_recurrent_%s' % feature_spec.name):
  241. # Linked features are returned as a pair of tensors, one indexing into
  242. # steps, and one indexing within the activation tensor (beam x batch)
  243. # stored for a step.
  244. step_idx, idx = dragnn_ops.extract_link_features(
  245. state.handle, component=component.name, channel_id=channel_id)
  246. # We take the [steps, batch*beam, ...] tensor array, gather and concat
  247. # the steps we might need into a [some_steps*batch*beam, ...] tensor,
  248. # and flatten 'idx' to dereference this new tensor.
  249. #
  250. # The first element of each tensor array is reserved for an
  251. # initialization variable, so we offset all step indices by +1.
  252. #
  253. # TODO(googleuser): It would be great to not have to extract
  254. # the steps in their entirety, forcing a copy of much of the
  255. # TensorArray at each step. Better would be to support a
  256. # TensorArray.gather_nd to pick the specific elements directly.
  257. # TODO(googleuser): In the interim, a small optimization would
  258. # be to use tf.unique instead of tf.range.
  259. step_min = tf.reduce_min(step_idx)
  260. ta_range = tf.range(step_min + 1, tf.reduce_max(step_idx) + 2)
  261. act_block = source_array.gather(ta_range)
  262. act_block = tf.reshape(act_block,
  263. tf.concat([[-1], tf.shape(act_block)[2:]], 0))
  264. flat_idx = (step_idx - step_min) * stride + idx
  265. act_block = tf.gather(act_block, flat_idx)
  266. act_block = tf.reshape(act_block, [-1, source_layer_size])
  267. if feature_spec.embedding_dim != -1:
  268. embedding_matrix = component.get_variable(
  269. linked_embeddings_name(channel_id))
  270. act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
  271. step_idx)
  272. dim = feature_spec.size * feature_spec.embedding_dim
  273. else:
  274. # If embedding_dim is -1, just output concatenation of activations.
  275. dim = feature_spec.size * source_layer_size
  276. return NamedTensor(
  277. tf.reshape(act_block, [-1, dim]), feature_spec.name, dim=dim)
  278. def activation_lookup_other(component, state, channel_id, source_tensor,
  279. source_layer_size):
  280. """Looks up activations from tensors.
  281. If the linked feature's embedding_dim is set to -1, the feature vectors are
  282. not passed through (i.e. multiplied by) an embedding matrix.
  283. Args:
  284. component: Component object in which to look up the fixed features.
  285. state: MasterState object for the live nlp_saft::dragnn::MasterState.
  286. channel_id: int id of the fixed feature to look up.
  287. source_tensor: Tensor from which to fetch feature vectors. Expected to have
  288. have shape [steps + 1, stride, D].
  289. source_layer_size: int length of feature vectors before embedding (D). It
  290. would in principle be possible to get this dimension dynamically from
  291. the second dimension of source_tensor. However, having it statically is
  292. more convenient.
  293. Returns:
  294. NamedTensor object containing the embedding vectors.
  295. """
  296. feature_spec = component.spec.linked_feature[channel_id]
  297. with tf.name_scope('activation_lookup_other_%s' % feature_spec.name):
  298. # Linked features are returned as a pair of tensors, one indexing into
  299. # steps, and one indexing within the stride (beam x batch) of each step.
  300. step_idx, idx = dragnn_ops.extract_link_features(
  301. state.handle, component=component.name, channel_id=channel_id)
  302. # The first element of each tensor array is reserved for an
  303. # initialization variable, so we offset all step indices by +1.
  304. indices = tf.stack([step_idx + 1, idx], axis=1)
  305. act_block = tf.gather_nd(source_tensor, indices)
  306. act_block = tf.reshape(act_block, [-1, source_layer_size])
  307. if feature_spec.embedding_dim != -1:
  308. embedding_matrix = component.get_variable(
  309. linked_embeddings_name(channel_id))
  310. act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
  311. step_idx)
  312. dim = feature_spec.size * feature_spec.embedding_dim
  313. else:
  314. # If embedding_dim is -1, just output concatenation of activations.
  315. dim = feature_spec.size * source_layer_size
  316. return NamedTensor(
  317. tf.reshape(act_block, [-1, dim]), feature_spec.name, dim=dim)
  318. class LayerNorm(object):
  319. """Utility to add layer normalization to any tensor.
  320. Layer normalization implementation is based on:
  321. https://arxiv.org/abs/1607.06450. "Layer Normalization"
  322. Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
  323. This object will construct additional variables that need to be optimized, and
  324. these variables can be accessed via params().
  325. Attributes:
  326. params: List of additional parameters to be trained.
  327. """
  328. def __init__(self, component, name, shape, dtype):
  329. """Construct variables to normalize an input of given shape.
  330. Arguments:
  331. component: ComponentBuilder handle.
  332. name: Human readable name to organize the variables.
  333. shape: Shape of the layer to be normalized.
  334. dtype: Type of the layer to be normalized.
  335. """
  336. self._name = name
  337. self._shape = shape
  338. self._component = component
  339. beta = tf.get_variable(
  340. 'beta_%s' % name,
  341. shape=shape,
  342. dtype=dtype,
  343. initializer=tf.zeros_initializer())
  344. gamma = tf.get_variable(
  345. 'gamma_%s' % name,
  346. shape=shape,
  347. dtype=dtype,
  348. initializer=tf.ones_initializer())
  349. self._params = [beta, gamma]
  350. @property
  351. def params(self):
  352. return self._params
  353. def normalize(self, inputs):
  354. """Apply normalization to input.
  355. The shape must match the declared shape in the constructor.
  356. [This is copied from tf.contrib.rnn.LayerNormBasicLSTMCell.]
  357. Args:
  358. inputs: Input tensor
  359. Returns:
  360. Normalized version of input tensor.
  361. Raises:
  362. ValueError: if inputs has undefined rank.
  363. """
  364. inputs_shape = inputs.get_shape()
  365. inputs_rank = inputs_shape.ndims
  366. if inputs_rank is None:
  367. raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  368. axis = range(1, inputs_rank)
  369. beta = self._component.get_variable('beta_%s' % self._name)
  370. gamma = self._component.get_variable('gamma_%s' % self._name)
  371. with tf.variable_scope('layer_norm_%s' % self._name):
  372. # Calculate the moments on the last axis (layer activations).
  373. mean, variance = nn.moments(inputs, axis, keep_dims=True)
  374. # Compute layer normalization using the batch_normalization function.
  375. variance_epsilon = 1E-12
  376. outputs = nn.batch_normalization(
  377. inputs, mean, variance, beta, gamma, variance_epsilon)
  378. outputs.set_shape(inputs_shape)
  379. return outputs
  380. class Layer(object):
  381. """A layer in a feed-forward network.
  382. Attributes:
  383. component: ComponentBuilderBase that produces this layer.
  384. name: Name of this layer.
  385. dim: Dimension of this layer, or negative if dynamic.
  386. """
  387. def __init__(self, component, name, dim):
  388. check.NotNone(dim, 'Dimension is required')
  389. self.component = component
  390. self.name = name
  391. self.dim = dim
  392. def __str__(self):
  393. return 'Layer: %s/%s[%d]' % (self.component.name, self.name, self.dim)
  394. def create_array(self, stride):
  395. """Creates a new tensor array to store this layer's activations.
  396. Arguments:
  397. stride: Possibly dynamic batch * beam size with which to initialize the
  398. tensor array
  399. Returns:
  400. TensorArray object
  401. """
  402. check.Gt(self.dim, 0, 'Cannot create array when dimension is dynamic')
  403. tensor_array = ta.TensorArray(dtype=tf.float32,
  404. size=0,
  405. dynamic_size=True,
  406. clear_after_read=False,
  407. infer_shape=False,
  408. name='%s_array' % self.name)
  409. # Start each array with all zeros. Special values will still be learned via
  410. # the extra embedding dimension stored for each linked feature channel.
  411. initial_value = tf.zeros([stride, self.dim])
  412. return tensor_array.write(0, initial_value)
  413. def get_attrs_with_defaults(parameters, defaults):
  414. """Populates a dictionary with run-time attributes.
  415. Given defaults, populates any overrides from 'parameters' with their
  416. corresponding converted values. 'defaults' should be typed. This is useful
  417. for specifying NetworkUnit-specific configuration options.
  418. Args:
  419. parameters: a <string, string> map.
  420. defaults: a <string, value> typed set of default values.
  421. Returns:
  422. dictionary populated with any overrides.
  423. Raises:
  424. RuntimeError: if a key in parameters is not present in defaults.
  425. """
  426. attrs = defaults
  427. for key, value in parameters.iteritems():
  428. check.In(key, defaults, 'Unknown attribute: %s' % key)
  429. if isinstance(defaults[key], bool):
  430. attrs[key] = value.lower() == 'true'
  431. else:
  432. attrs[key] = type(defaults[key])(value)
  433. return attrs
  434. def maybe_apply_dropout(inputs, keep_prob, per_sequence, stride=None):
  435. """Applies dropout, if so configured, to an input tensor.
  436. The input may be rank 2 or 3 depending on whether the stride (i.e., batch
  437. size) has been incorporated into the shape.
  438. Args:
  439. inputs: [stride * num_steps, dim] or [stride, num_steps, dim] input tensor.
  440. keep_prob: Scalar probability of keeping each input element. If >= 1.0, no
  441. dropout is performed.
  442. per_sequence: If true, sample the dropout mask once per sequence, instead of
  443. once per step. Requires |stride| when true.
  444. stride: Scalar batch size. Optional if |per_sequence| is false.
  445. Returns:
  446. [stride * num_steps, dim] or [stride, num_steps, dim] tensor, matching the
  447. shape of |inputs|, containing the masked or original inputs, depending on
  448. whether dropout was actually performed.
  449. """
  450. check.Ge(inputs.get_shape().ndims, 2, 'inputs must be rank 2 or 3')
  451. check.Le(inputs.get_shape().ndims, 3, 'inputs must be rank 2 or 3')
  452. flat = (inputs.get_shape().ndims == 2)
  453. if keep_prob >= 1.0:
  454. return inputs
  455. if not per_sequence:
  456. return tf.nn.dropout(inputs, keep_prob)
  457. check.NotNone(stride, 'per-sequence dropout requires stride')
  458. dim = inputs.get_shape().as_list()[-1]
  459. check.NotNone(dim, 'inputs must have static activation dimension, but have '
  460. 'static shape %s' % inputs.get_shape().as_list())
  461. # If needed, restore the batch dimension to separate the sequences.
  462. inputs_sxnxd = tf.reshape(inputs, [stride, -1, dim]) if flat else inputs
  463. # Replace |num_steps| with 1 in |noise_shape|, so the dropout mask broadcasts
  464. # to all steps for a particular sequence.
  465. noise_shape = [stride, 1, dim]
  466. masked_sxnxd = tf.nn.dropout(inputs_sxnxd, keep_prob, noise_shape)
  467. # If needed, flatten out the batch dimension in the return value.
  468. return tf.reshape(masked_sxnxd, [-1, dim]) if flat else masked_sxnxd
  469. @registry.RegisteredClass
  470. class NetworkUnitInterface(object):
  471. """Base class to implement NN specifications.
  472. This class contains the required functionality to build a network inside of a
  473. DRAGNN graph: (1) initializing TF variables during __init__(), and (2)
  474. creating particular instances from extracted features in create().
  475. Attributes:
  476. params (list): List of tf.Variable objects representing trainable
  477. parameters.
  478. layers (list): List of Layer objects to track network layers that should
  479. be written to Tensors during training and inference.
  480. """
  481. __metaclass__ = ABCMeta # required for @abstractmethod
  482. def __init__(self, component, init_layers=None, init_context_layers=None):
  483. """Initializes parameters for embedding matrices.
  484. The subclass may provide optional lists of initial layers and context layers
  485. to allow this base class constructor to use accessors like get_layer_size(),
  486. which is required for networks that may be used self-recurrently.
  487. Args:
  488. component: parent ComponentBuilderBase object.
  489. init_layers: optional initial layers.
  490. init_context_layers: optional initial context layers.
  491. """
  492. self._seed = component.master.hyperparams.seed
  493. self._component = component
  494. self._params = []
  495. self._layers = init_layers if init_layers else []
  496. self._regularized_weights = []
  497. self._context_layers = init_context_layers if init_context_layers else []
  498. self._fixed_feature_dims = {} # mapping from name to dimension
  499. self._linked_feature_dims = {} # mapping from name to dimension
  500. # Allocate parameters for all embedding channels. Note that for both Fixed
  501. # and Linked embedding matrices, we store an additional +1 embedding that's
  502. # used when the index is out of scope.
  503. for channel_id, spec in enumerate(component.spec.fixed_feature):
  504. check.NotIn(spec.name, self._fixed_feature_dims,
  505. 'Duplicate fixed feature')
  506. check.Gt(spec.size, 0, 'Invalid fixed feature size')
  507. if spec.embedding_dim > 0:
  508. fixed_dim = spec.embedding_dim
  509. self._params.append(add_embeddings(channel_id, spec, self._seed))
  510. else:
  511. fixed_dim = 1 # assume feature ID extraction; only one ID per step
  512. self._fixed_feature_dims[spec.name] = spec.size * fixed_dim
  513. for channel_id, spec in enumerate(component.spec.linked_feature):
  514. check.NotIn(spec.name, self._linked_feature_dims,
  515. 'Duplicate linked feature')
  516. check.Gt(spec.size, 0, 'Invalid linked feature size')
  517. if spec.source_component == component.name:
  518. source_array_dim = self.get_layer_size(spec.source_layer)
  519. else:
  520. source = component.master.lookup_component[spec.source_component]
  521. source_array_dim = source.network.get_layer_size(spec.source_layer)
  522. if spec.embedding_dim != -1:
  523. check.Gt(source_array_dim, 0,
  524. 'Cannot embed linked feature with dynamic dimension')
  525. self._params.append(
  526. tf.get_variable(
  527. linked_embeddings_name(channel_id),
  528. [source_array_dim + 1, spec.embedding_dim],
  529. initializer=tf.random_normal_initializer(
  530. stddev=1 / spec.embedding_dim**.5, seed=self._seed)))
  531. self._linked_feature_dims[spec.name] = spec.size * spec.embedding_dim
  532. else:
  533. # If embedding_dim is -1, linked features are not embedded.
  534. self._linked_feature_dims[spec.name] = spec.size * source_array_dim
  535. # Compute the cumulative dimension of all inputs. If any input has dynamic
  536. # dimension, then the result is -1.
  537. input_dims = (self._fixed_feature_dims.values() +
  538. self._linked_feature_dims.values())
  539. if any(x < 0 for x in input_dims):
  540. self._concatenated_input_dim = -1
  541. else:
  542. self._concatenated_input_dim = sum(input_dims)
  543. tf.logging.info('component %s concat_input_dim %s', component.name,
  544. self._concatenated_input_dim)
  545. # Allocate attention parameters.
  546. if self._component.spec.attention_component:
  547. attention_source_component = self._component.master.lookup_component[
  548. self._component.spec.attention_component]
  549. attention_hidden_layer_sizes = map(
  550. int, attention_source_component.spec.network_unit.parameters[
  551. 'hidden_layer_sizes'].split(','))
  552. attention_hidden_layer_size = attention_hidden_layer_sizes[-1]
  553. hidden_layer_sizes = map(int, component.spec.network_unit.parameters[
  554. 'hidden_layer_sizes'].split(','))
  555. # The attention function is built on the last layer of hidden embeddings.
  556. hidden_layer_size = hidden_layer_sizes[-1]
  557. self._params.append(
  558. tf.get_variable(
  559. 'attention_weights_pm_0',
  560. [attention_hidden_layer_size, hidden_layer_size],
  561. initializer=tf.random_normal_initializer(
  562. stddev=1e-4, seed=self._seed)))
  563. self._params.append(
  564. tf.get_variable(
  565. 'attention_weights_hm_0', [hidden_layer_size, hidden_layer_size],
  566. initializer=tf.random_normal_initializer(
  567. stddev=1e-4, seed=self._seed)))
  568. self._params.append(
  569. tf.get_variable(
  570. 'attention_bias_0', [1, hidden_layer_size],
  571. initializer=tf.zeros_initializer()))
  572. self._params.append(
  573. tf.get_variable(
  574. 'attention_bias_1', [1, hidden_layer_size],
  575. initializer=tf.zeros_initializer()))
  576. self._params.append(
  577. tf.get_variable(
  578. 'attention_weights_pu',
  579. [attention_hidden_layer_size, component.num_actions],
  580. initializer=tf.random_normal_initializer(
  581. stddev=1e-4, seed=self._seed)))
  @abstractmethod
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Builds this network unit's output tensors from its inputs.

    Concrete network units must override this method; the base version has no
    body beyond this docstring and therefore returns None.

    Args:
      fixed_embeddings: list of NamedTensor objects for the fixed features.
      linked_embeddings: list of NamedTensor objects for the linked features.
      context_tensor_arrays: optional list of TensorArray objects used for
        implicit recurrence.
      attention_tensor: optional Tensor used for attention.
      during_training: whether to create a network for training (vs inference).
      stride: int scalar tensor containing the stride required for bulk
        computation.

    Returns:
      A list of tensors corresponding to the list of layers.
    """
  @property
  def layers(self):
    """The Layer bookkeeping objects describing this unit's outputs."""
    return self._layers
  @property
  def params(self):
    """All variables allocated by this unit, in allocation order."""
    return self._params
  @property
  def regularized_weights(self):
    """The subset of variables that should receive weight regularization."""
    return self._regularized_weights
  @property
  def context_layers(self):
    """Layers whose values are fed back to this unit as recurrent context."""
    return self._context_layers
  def get_layer_index(self, layer_name):
    """Returns the position of the layer named |layer_name| in |layers|.

    Raises:
      ValueError: if no layer has that name (raised by list.index).
    """
    layer_names = [layer.name for layer in self.layers]
    return layer_names.index(layer_name)
  def get_layer_size(self, layer_name):
    """Looks up the dimension of a named layer of this network.

    If several layers share the name, the first one in |layers| wins.

    Args:
      layer_name: string name of the layer to look up.

    Returns:
      The |dim| of the matching layer.

    Raises:
      KeyError: if no layer with |layer_name| exists.
    """
    matching_dims = [
        layer.dim for layer in self.layers if layer.name == layer_name
    ]
    if not matching_dims:
      raise KeyError('Layer {} not found in component {}'.format(
          layer_name, self._component.name))
    return matching_dims[0]
  def get_logits(self, network_tensors):
    """Extracts the logits tensor from the tensors produced by create().

    A network unit is not required to produce logits, so the base class does
    not implement this; subclasses that do produce logits override it.

    Args:
      network_tensors: list of tensors as returned by create().

    Raises:
      NotImplementedError: always, in the base class.
    """
    del network_tensors  # Unused by the base implementation.
    raise NotImplementedError()
  def get_l2_regularized_weights(self):
    """Returns the variables to apply L2 regularization to.

    This is simply the |regularized_weights| property of this unit.
    """
    return self.regularized_weights
  def attention(self, last_layer, attention_tensor):
    """Computes the attention term that is added to this unit's logits.

    Both inputs are first projected into a shared space by a tanh layer. Their
    row-wise dot products are then normalized with a softmax. The resulting
    weights pool |attention_tensor|, and the pooled vector is projected with
    'attention_weights_pu'.

    Args:
      last_layer: activations of this unit's last hidden layer; presumably
        [rows, hidden_layer_size] -- TODO confirm.
      attention_tensor: activations of the attention source component;
        presumably [rows, attention_hidden_layer_size] -- TODO confirm.

    Returns:
      The pooled vector multiplied by 'attention_weights_pu' (presumably
      [1, num_actions] -- TODO confirm).
    """
    get_var = self._component.get_variable

    # Use a one-layer feed-forward net to map both inputs into the same
    # dimension.
    projected_source = tf.nn.tanh(
        tf.matmul(attention_tensor, get_var('attention_weights_pm_0'),
                  name='h_x_pm') + get_var('attention_bias_0'))
    projected_hidden = tf.nn.tanh(
        tf.matmul(last_layer, get_var('attention_weights_hm_0'),
                  name='l_x_hm') + get_var('attention_bias_1'))

    # Row-wise dot products; tf.multiply broadcasts along the 0 dim here.
    scores = tf.reduce_sum(tf.multiply(projected_source, projected_hidden), 1)
    attention_weights = tf.nn.softmax(tf.reshape(scores, [1, -1]))

    # Weighted sum over the rows of the source; tf.multiply broadcasts along
    # the 1 dim here.
    weighted_rows = tf.multiply(
        attention_tensor,
        tf.reshape(attention_weights, [-1, 1]),
        name='time_together2')
    pooled = tf.expand_dims(tf.reduce_sum(weighted_rows, 0), 0)

    return tf.matmul(
        pooled, get_var('attention_weights_pu'), name='time_together3')
  class IdentityNetwork(NetworkUnitInterface):
    """A network that returns concatenated input embeddings and activations.

    Layers:
      input_embeddings: the concatenation of the fixed and linked embeddings.
    """

    def __init__(self, component):
      super(IdentityNetwork, self).__init__(component)
      self._layers = [
          Layer(
              component,
              name='input_embeddings',
              dim=self._concatenated_input_dim)
      ]

    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
      """See base class; returns only the concatenated input tensor."""
      return [get_input_tensor(fixed_embeddings, linked_embeddings)]

    def get_layer_size(self, layer_name):
      """See base class; recurrent links into this unit are not supported."""
      # Note that get_layer_size is called by super.__init__ before any layers
      # are constructed if and only if there are recurrent links. The base
      # initializer already sets |_layers| to an empty list at that point, so a
      # mere attribute-presence check can never fail. Check for a missing *or
      # empty* layer list so the intended error is actually raised.
      assert getattr(self, '_layers', None), (
          'IdentityNetwork cannot have recurrent links')
      return super(IdentityNetwork, self).get_layer_size(layer_name)

    def get_logits(self, network_tensors):
      """Returns the last tensor produced by create()."""
      return network_tensors[-1]

    def get_context_layers(self):
      """This unit has no recurrent context layers."""
      return []
  class FeedForwardNetwork(NetworkUnitInterface):
    """Implementation of C&M style feedforward network.

    Supports dropout and optional layer normalization.

    Layers:
      layer_<i>: Activations for i'th hidden layer (0-origin).
      last_layer: Activations for the last hidden layer. This is a convenience
        alias for "layer_<n-1>", where n is the number of hidden layers.
      logits: Logits associated with component actions.
    """

    def __init__(self, component):
      """Initializes parameters required to run this network.

      Args:
        component: parent ComponentBuilderBase object.

      Parameters used to construct the network:
        hidden_layer_sizes: comma-separated list of ints, indicating the
          number of hidden units in each hidden layer.
        layer_norm_input (False): Whether or not to apply layer normalization
          on the concatenated input to the network.
        layer_norm_hidden (False): Whether or not to apply layer normalization
          to the first set of hidden layer activations.
        nonlinearity ('relu'): Name of function from module "tf.nn" to apply to
          each hidden layer; e.g., "relu" or "elu".
        dropout_keep_prob (-1.0): The probability that an input is not dropped.
          If >= 1.0, disables dropout. If < 0.0, uses the global |dropout_rate|
          hyperparameter.
        dropout_per_sequence (False): If true, sample the dropout mask once per
          sequence, instead of once per step. See Gal and Ghahramani
          (https://arxiv.org/abs/1512.05287).
        dropout_all_layers (False): If true, also apply dropout inside the
          hidden stack instead of just to the network input (see the NOTE in
          create() for exactly which tensors are dropped).

      Hyperparameters used:
        dropout_rate: The probability that an input is not dropped. Only used
          when the |dropout_keep_prob| parameter is negative.
      """
      self._attrs = get_attrs_with_defaults(
          component.spec.network_unit.parameters, defaults={
              'hidden_layer_sizes': '',
              'layer_norm_input': False,
              'layer_norm_hidden': False,
              'nonlinearity': 'relu',
              'dropout_keep_prob': -1.0,
              'dropout_per_sequence': False,
              'dropout_all_layers': False})

      # Initialize the hidden layer sizes before running the base initializer,
      # as the base initializer may need to know the size of the hidden layers
      # for recurrent connections. This must be a real list, not a lazy map
      # object: it is truth-tested, indexed and iterated more than once, and a
      # Python 3 map object is always truthy and is exhausted after one pass.
      sizes_spec = self._attrs['hidden_layer_sizes']
      self._hidden_layer_sizes = (
          [int(size) for size in sizes_spec.split(',')] if sizes_spec else [])

      super(FeedForwardNetwork, self).__init__(component)

      # Infer dropout rate from network parameters and grid hyperparameters.
      self._dropout_rate = self._attrs['dropout_keep_prob']
      if self._dropout_rate < 0.0:
        self._dropout_rate = component.master.hyperparams.dropout_rate

      # Add layer norm if specified.
      self._layer_norm_input = None
      self._layer_norm_hidden = None
      if self._attrs['layer_norm_input']:
        self._layer_norm_input = LayerNorm(self._component, 'concat_input',
                                           self._concatenated_input_dim,
                                           tf.float32)
        self._params.extend(self._layer_norm_input.params)

      if self._attrs['layer_norm_hidden']:
        self._layer_norm_hidden = LayerNorm(self._component, 'layer_0',
                                            self._hidden_layer_sizes[0],
                                            tf.float32)
        self._params.extend(self._layer_norm_hidden.params)

      # Extract nonlinearity from |tf.nn|.
      self._nonlinearity = getattr(tf.nn, self._attrs['nonlinearity'])

      # TODO(googleuser): add initializer stddevs as part of the network unit's
      # configuration.
      self._weights = []
      last_layer_dim = self._concatenated_input_dim

      # Initialize variables for the parameters, and add Layer objects for
      # cross-component bookkeeping.
      for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
        weights = tf.get_variable(
            'weights_%d' % index, [last_layer_dim, hidden_layer_size],
            initializer=tf.random_normal_initializer(stddev=1e-4,
                                                     seed=self._seed))
        self._params.append(weights)
        # Layer norm includes its own bias, so layer 0 gets no separate bias
        # variable when hidden layer norm is enabled.
        if index > 0 or self._layer_norm_hidden is None:
          self._params.append(
              tf.get_variable(
                  'bias_%d' % index, [hidden_layer_size],
                  initializer=tf.constant_initializer(
                      0.2, dtype=tf.float32)))

        self._weights.append(weights)
        self._layers.append(
            Layer(
                component, name='layer_%d' % index, dim=hidden_layer_size))
        last_layer_dim = hidden_layer_size

      # Add a convenience alias for the last hidden layer, if any.
      if self._hidden_layer_sizes:
        self._layers.append(Layer(component, 'last_layer', last_layer_dim))

      # By default, regularize only the weights.
      self._regularized_weights.extend(self._weights)

      if component.num_actions:
        self._params.append(
            tf.get_variable(
                'weights_softmax', [last_layer_dim, component.num_actions],
                initializer=tf.random_normal_initializer(
                    stddev=1e-4, seed=self._seed)))
        self._params.append(
            tf.get_variable(
                'bias_softmax', [component.num_actions],
                initializer=tf.zeros_initializer()))
        self._layers.append(
            Layer(
                component, name='logits', dim=component.num_actions))

    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
      """See base class.

      Returns:
        One tensor per hidden layer, then the 'last_layer' alias (if there are
        hidden layers), then the logits (if the unit has actions).
      """
      input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
      if during_training:
        input_tensor.set_shape([None, self._concatenated_input_dim])
        input_tensor = self._maybe_apply_dropout(input_tensor, stride)

      if self._layer_norm_input is not None:
        input_tensor = self._layer_norm_input.normalize(input_tensor)

      tensors = []
      last_layer = input_tensor
      for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
        acts = tf.matmul(last_layer,
                         self._component.get_variable('weights_%d' % index))

        # Note that the first layer was already handled before this loop.
        # TODO(googleuser): Refactor this loop so dropout and layer
        # normalization are applied consistently.
        # NOTE(review): this drops out the pre-activations of layer |index|
        # (after its matmul), not that layer's input as the
        # |dropout_all_layers| description suggests. Left unchanged to keep
        # existing models' behavior; confirm intent.
        if during_training and self._attrs['dropout_all_layers'] and index > 0:
          acts.set_shape([None, hidden_layer_size])
          acts = self._maybe_apply_dropout(acts, stride)

        # Don't add a bias term if we're going to apply layer norm, since layer
        # norm includes a bias already.
        if index == 0 and self._layer_norm_hidden is not None:
          acts = self._layer_norm_hidden.normalize(acts)
        else:
          acts = tf.nn.bias_add(acts,
                                self._component.get_variable('bias_%d' % index))

        last_layer = self._nonlinearity(acts)
        tensors.append(last_layer)

      # Add a convenience alias for the last hidden layer, if any.
      if self._hidden_layer_sizes:
        tensors.append(last_layer)

      # Guard against a unit with neither hidden layers nor actions, whose
      # layer list is empty.
      if self._layers and self._layers[-1].name == 'logits':
        logits = tf.matmul(
            last_layer, self._component.get_variable(
                'weights_softmax')) + self._component.get_variable('bias_softmax')

        if self._component.spec.attention_component:
          logits += self.attention(last_layer, attention_tensor)

        logits = tf.identity(logits, name=self._layers[-1].name)
        tensors.append(logits)
      return tensors

    def get_layer_size(self, layer_name):
      """Gets the size of the given named layer from the configuration.

      Args:
        layer_name: 'logits', 'last_layer' or 'layer_<i>'.

      Returns:
        The dimension of that layer.
      """
      if layer_name == 'logits':
        return self._component.num_actions

      if layer_name == 'last_layer':
        return self._hidden_layer_sizes[-1]

      if not layer_name.startswith('layer_'):
        logging.fatal(
            'Invalid layer name: "%s" Can only retrieve from "logits", '
            '"last_layer", and "layer_*".',
            layer_name)

      # NOTE(danielandor): Since get_layer_size is called before the
      # model has been built, we compute the layer size directly from
      # the hyperparameters rather than from self._layers.
      layer_index = int(layer_name.split('_')[1])
      return self._hidden_layer_sizes[layer_index]

    def get_logits(self, network_tensors):
      """Returns the logits, which create() appends last."""
      return network_tensors[-1]

    def _maybe_apply_dropout(self, inputs, stride):
      """Applies dropout using this unit's rate and per-sequence setting."""
      return maybe_apply_dropout(inputs, self._dropout_rate,
                                 self._attrs['dropout_per_sequence'], stride)
  class LSTMNetwork(NetworkUnitInterface):
    """Implementation of action LSTM style network.

    Layers:
      lstm_h: hidden state h_t; registered as a context layer.
      lstm_c: memory cell c_t; registered as a context layer.
      layer_0: identity alias of lstm_h.
      logits: Logits associated with component actions.
    """

    def __init__(self, component):
      """Initializes the LSTM parameters and layer bookkeeping.

      Args:
        component: parent ComponentBuilderBase object; must have a positive
          number of actions.

      Parameters used to construct the network:
        hidden_layer_sizes: a single int, the size of the LSTM hidden state
          and memory cell.

      Hyperparameters used:
        dropout_rate: The probability that an input is not dropped; applied
          to the network input during training.
        recurrent_dropout_rate: The probability that a hidden-state unit is
          not dropped. If negative, |dropout_rate| is used instead.
      """
      assert component.num_actions > 0, 'Component num actions must be positive.'
      network_unit_spec = component.spec.network_unit
      # Despite the plural name, this is a single int: the LSTM state size.
      self._hidden_layer_sizes = int(
          network_unit_spec.parameters['hidden_layer_sizes'])

      self._input_dropout_rate = component.master.hyperparams.dropout_rate
      self._recurrent_dropout_rate = (
          component.master.hyperparams.recurrent_dropout_rate)
      if self._recurrent_dropout_rate < 0.0:
        self._recurrent_dropout_rate = component.master.hyperparams.dropout_rate

      super(LSTMNetwork, self).__init__(component)
      layer_input_dim = self._concatenated_input_dim
      hidden_dim = self._hidden_layer_sizes

      self._context_layers = []

      def _make_param(name, shape):
        # Every LSTM parameter uses the same small random-normal initializer.
        # TODO(googleuser): should we choose a different initializer,
        # e.g. truncated_normal_initializer?
        return tf.get_variable(
            name, shape,
            initializer=tf.random_normal_initializer(stddev=1e-4,
                                                     seed=self._seed))

      # Input gate parameters.
      self._x2i = _make_param('x2i', [layer_input_dim, hidden_dim])
      self._h2i = _make_param('h2i', [hidden_dim, hidden_dim])
      self._c2i = _make_param('c2i', [hidden_dim, hidden_dim])
      self._bi = _make_param('bi', [hidden_dim])

      # Output gate parameters.
      self._x2o = _make_param('x2o', [layer_input_dim, hidden_dim])
      self._h2o = _make_param('h2o', [hidden_dim, hidden_dim])
      self._c2o = _make_param('c2o', [hidden_dim, hidden_dim])
      self._bo = _make_param('bo', [hidden_dim])

      # Memory cell write parameters.
      self._x2c = _make_param('x2c', [layer_input_dim, hidden_dim])
      self._h2c = _make_param('h2c', [hidden_dim, hidden_dim])
      self._bc = _make_param('bc', [hidden_dim])

      self._params.extend([
          self._x2i, self._h2i, self._c2i, self._bi, self._x2o, self._h2o,
          self._c2o, self._bo, self._x2c, self._h2c, self._bc])

      lstm_h_layer = Layer(component, name='lstm_h', dim=hidden_dim)
      lstm_c_layer = Layer(component, name='lstm_c', dim=hidden_dim)

      self._context_layers.append(lstm_h_layer)
      self._context_layers.append(lstm_c_layer)

      self._layers.extend(self._context_layers)

      self._layers.append(
          Layer(
              component, name='layer_0', dim=hidden_dim))

      self._params.append(tf.get_variable(
          'weights_softmax', [hidden_dim, component.num_actions],
          initializer=tf.random_normal_initializer(stddev=1e-4,
                                                   seed=self._seed)))
      self._params.append(
          tf.get_variable(
              'bias_softmax', [component.num_actions],
              initializer=tf.zeros_initializer()))

      self._layers.append(
          Layer(
              component, name='logits', dim=component.num_actions))

    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
      """See base class.

      Returns:
        [lstm_h, lstm_c, layer_0, logits], matching the order of |layers|.
      """
      input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)

      # context_tensor_arrays[0] is lstm_h
      # context_tensor_arrays[1] is lstm_c
      assert len(context_tensor_arrays) == 2
      length = context_tensor_arrays[0].size()

      # Get the (possibly averaged) parameters to execute the network.
      x2i = self._component.get_variable('x2i')
      h2i = self._component.get_variable('h2i')
      c2i = self._component.get_variable('c2i')
      bi = self._component.get_variable('bi')
      x2o = self._component.get_variable('x2o')
      h2o = self._component.get_variable('h2o')
      c2o = self._component.get_variable('c2o')
      bo = self._component.get_variable('bo')
      x2c = self._component.get_variable('x2c')
      h2c = self._component.get_variable('h2c')
      bc = self._component.get_variable('bc')

      # i_h_tm1, i_c_tm1 = h_{t-1}, c_{t-1}
      i_h_tm1 = context_tensor_arrays[0].read(length - 1)
      i_c_tm1 = context_tensor_arrays[1].read(length - 1)

      # apply dropout according to http://arxiv.org/pdf/1409.2329v5.pdf
      if during_training and self._input_dropout_rate < 1:
        input_tensor = tf.nn.dropout(input_tensor, self._input_dropout_rate)

      # input -- i_t = sigmoid(affine(x_t, h_{t-1}, c_{t-1}))
      i_ait = tf.matmul(input_tensor, x2i) + tf.matmul(i_h_tm1, h2i) + tf.matmul(
          i_c_tm1, c2i) + bi
      i_it = tf.sigmoid(i_ait)

      # forget -- f_t = 1 - i_t
      i_ft = tf.ones([1, 1]) - i_it

      # write memory cell -- tanh(affine(x_t, h_{t-1}))
      i_awt = tf.matmul(input_tensor, x2c) + tf.matmul(i_h_tm1, h2c) + bc
      i_wt = tf.tanh(i_awt)

      # c_t = f_t \odot c_{t-1} + i_t \odot tanh(affine(x_t, h_{t-1}))
      ct = tf.add(
          tf.multiply(i_it, i_wt), tf.multiply(i_ft, i_c_tm1), name='lstm_c')

      # output -- o_t = sigmoid(affine(x_t, h_{t-1}, c_t))
      i_aot = tf.matmul(input_tensor, x2o) + tf.matmul(ct, c2o) + tf.matmul(
          i_h_tm1, h2o) + bo

      i_ot = tf.sigmoid(i_aot)

      # ht = o_t \odot tanh(ct)
      ph_t = tf.tanh(ct)
      ht = tf.multiply(i_ot, ph_t, name='lstm_h')

      if during_training and self._recurrent_dropout_rate < 1:
        ht = tf.nn.dropout(
            ht, self._recurrent_dropout_rate, name='lstm_h_dropout')

      h = tf.identity(ht, name='layer_0')

      # Fetch the softmax parameters through the component, like every other
      # parameter above, so that they go through the same
      # "(possibly averaged)" accessor.
      logits = tf.nn.xw_plus_b(ht,
                               self._component.get_variable('weights_softmax'),
                               self._component.get_variable('bias_softmax'))

      if self._component.spec.attention_component:
        logits += self.attention(ht, attention_tensor)

      logits = tf.identity(logits, name='logits')
      # tensors will be consistent with the layers:
      # [lstm_h, lstm_c, layer_0, logits]
      tensors = [ht, ct, h, logits]
      return tensors

    def get_layer_size(self, layer_name):
      """Gets the size of a named layer from the configuration.

      This does not consult |layers|, because the base initializer may call it
      for recurrent links before the layers have been created.

      Args:
        layer_name: one of 'lstm_h', 'lstm_c', 'layer_0' or 'logits'.

      Returns:
        The number of actions for 'logits', else the LSTM state size.
      """
      if layer_name == 'logits':
        return self._component.num_actions
      assert layer_name in ('lstm_h', 'lstm_c', 'layer_0'), (
          'Can only retrieve from "lstm_h", "lstm_c", "layer_0" or "logits".')
      return self._hidden_layer_sizes

    def get_logits(self, network_tensors):
      """Returns the tensor at the position of the 'logits' layer."""
      return network_tensors[self.get_layer_index('logits')]
  class ConvNetwork(NetworkUnitInterface):
    """Implementation of a convolutional feed forward network.

    Layers:
      conv_output: activations of the last convolution, reshaped so that the
        last dimension is the final depth.
    """

    def __init__(self, component):
      """Initializes kernels and biases for this convolutional net.

      Args:
        component: parent ComponentBuilderBase object.

      Parameters used to construct the network:
        widths: comma separated list of ints, number of steps input to the
          convolutional kernel at every layer.
        depths: comma separated list of ints, number of channels input to the
          convolutional kernel at every layer.
        output_embedding_dim: int, number of output channels for the
          convolutional kernel of the last layer, which receives no ReLU
          activation and therefore can be used in a softmax output. If zero,
          this final layer is disabled entirely.
        nonlinearity ('relu'): Name of function from module "tf.nn" to apply to
          each hidden layer; e.g., "relu" or "elu".
        dropout_keep_prob (-1.0): The probability that an input is not dropped.
          If >= 1.0, disables dropout. If < 0.0, uses the global |dropout_rate|
          hyperparameter.
        dropout_per_sequence (False): If true, sample the dropout mask once per
          sequence, instead of once per step. See Gal and Ghahramani
          (https://arxiv.org/abs/1512.05287).

      Hyperparameters used:
        dropout_rate: The probability that an input is not dropped. Only used
          when the |dropout_keep_prob| parameter is negative.
      """
      super(ConvNetwork, self).__init__(component)
      self._attrs = get_attrs_with_defaults(
          component.spec.network_unit.parameters, defaults={
              'widths': '',
              'depths': '',
              'output_embedding_dim': 0,
              'nonlinearity': 'relu',
              'dropout_keep_prob': -1.0,
              'dropout_per_sequence': False})

      self._weights = []
      self._biases = []
      # Materialize real lists: |_depths| is appended to, and both are indexed
      # and passed to len(), none of which a Python 3 map object supports.
      self._widths = [int(width) for width in self._attrs['widths'].split(',')]
      self._depths = [int(depth) for depth in self._attrs['depths'].split(',')]
      self._output_dim = self._attrs['output_embedding_dim']
      if self._output_dim:
        self._depths.append(self._output_dim)

      # Layer i maps depths[i] channels to depths[i + 1] channels with a
      # 1 x widths[i] kernel.
      self.kernel_shapes = []
      for i in range(len(self._depths) - 1):
        self.kernel_shapes.append(
            [1, self._widths[i], self._depths[i], self._depths[i + 1]])

      for i, kernel_shape in enumerate(self.kernel_shapes):
        with tf.variable_scope('conv%d' % i):
          self._weights.append(
              tf.get_variable(
                  'weights',
                  kernel_shape,
                  initializer=tf.random_normal_initializer(
                      stddev=1e-4, seed=self._seed),
                  dtype=tf.float32))
          # The layer at index len(widths) - 1 starts with a zero bias.
          bias_init = 0.0 if (i == len(self._widths) - 1) else 0.2
          self._biases.append(
              tf.get_variable(
                  'biases',
                  kernel_shape[-1],
                  initializer=tf.constant_initializer(bias_init),
                  dtype=tf.float32))

      # Extract nonlinearity from |tf.nn|.
      self._nonlinearity = getattr(tf.nn, self._attrs['nonlinearity'])

      # Infer dropout rate from network parameters and grid hyperparameters.
      self._dropout_rate = self._attrs['dropout_keep_prob']
      if self._dropout_rate < 0.0:
        self._dropout_rate = component.master.hyperparams.dropout_rate

      self._params.extend(self._weights + self._biases)
      self._layers.append(
          Layer(
              component, name='conv_output', dim=self._depths[-1]))
      # The output layer's kernel (if any) is not regularized.
      self._regularized_weights.extend(self._weights[:-1] if self._output_dim else
                                       self._weights)

    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
      """Requires |stride|; otherwise see base class."""
      if stride is None:
        raise RuntimeError("ConvNetwork needs 'stride' and must be called in the "
                           "bulk feature extractor component.")
      input_tensor = get_input_tensor_with_stride(fixed_embeddings,
                                                  linked_embeddings, stride)

      # TODO(googleuser): Add context and attention.
      del context_tensor_arrays, attention_tensor

      # On CPU, add a dimension so that the 'image' has shape
      # [stride, 1, num_steps, D].
      conv = tf.expand_dims(input_tensor, 1)
      for i in range(len(self._depths) - 1):
        with tf.variable_scope('conv%d' % i, reuse=True) as scope:
          if during_training:
            conv.set_shape([None, 1, None, self._depths[i]])
            conv = self._maybe_apply_dropout(conv, stride)
          conv = tf.nn.conv2d(
              conv,
              self._component.get_variable('weights'), [1, 1, 1, 1],
              padding='SAME')
          conv = tf.nn.bias_add(conv, self._component.get_variable('biases'))
          # The final layer stays linear when an output embedding is requested.
          if i < (len(self._weights) - 1) or not self._output_dim:
            conv = self._nonlinearity(conv, name=scope.name)
      return [
          tf.reshape(
              conv, [-1, self._depths[-1]], name='reshape_activations')
      ]

    def _maybe_apply_dropout(self, inputs, stride):
      """Applies dropout to a rank-4 [batch, 1, steps, depth] activation."""
      # The |inputs| are rank 4 (one 1xN "image" per sequence). Squeeze out and
      # restore the singleton image height, so dropout is applied to the normal
      # rank 3 batched input tensor.
      inputs = tf.squeeze(inputs, [1])
      inputs = maybe_apply_dropout(inputs, self._dropout_rate,
                                   self._attrs['dropout_per_sequence'], stride)
      inputs = tf.expand_dims(inputs, 1)
      return inputs
  1143. class PairwiseConvNetwork(NetworkUnitInterface):
  1144. """Implementation of a pairwise 2D convolutional feed forward network.
  1145. For a sequence of N tokens, all N^2 pairs of concatenated input features are
  1146. constructed. If each input vector is of length D, then the sequence is
  1147. represented by an image of dimensions [N, N] with 2*D channels per pixel.
  1148. I.e. pixel [i, j] has a representation that is the concatenation of the
  1149. representations of the tokens at i and at j.
  1150. To use this network for graph edge scoring, for instance by using the "heads"
  1151. transition system, the output layer needs to have dimensions [N, N] and only
  1152. a single channel. The network takes care of outputting an [N, N] sized layer,
  1153. but the user needs to ensure that the output depth equals 1.
  1154. TODO(googleuser): Like Dozat and Manning, we will need an
  1155. additional network to label the edges, and the ability to read head
  1156. and modifier representations from different inputs.
  1157. """
  1158. def __init__(self, component):
  1159. """Initializes kernels and biases for this convolutional net.
  1160. Parameters used to construct the network:
  1161. depths: comma separated list of ints, number of channels input to the
  1162. convolutional kernel at every layer.
  1163. widths: comma separated list of ints, number of steps input to the
  1164. convolutional kernel at every layer.
  1165. relu_layers: comma separate list of ints, the id of layers after which
  1166. to apply a relu activation. *By default, all but the final layer will
  1167. have a relu activation applied.*
  1168. To generate a network with M layers, both 'depths' and 'widths' must be of
  1169. length M. The input depth of the first layer is inferred from the total
  1170. concatenated size of the input features.
  1171. Args:
  1172. component: parent ComponentBuilderBase object.
  1173. Raises:
  1174. RuntimeError: if the number of depths and weights are not equal.
  1175. ValueError: if the final depth is not equal to 1.
  1176. """
  1177. parameters = component.spec.network_unit.parameters
  1178. super(PairwiseConvNetwork, self).__init__(component)
  1179. # Each input pixel will comprise the concatenation of two tokens, so the
  1180. # input depth is double that for a single token.
  1181. self._depths = [self._concatenated_input_dim * 2]
  1182. self._depths.extend(map(int, parameters['depths'].split(',')))
  1183. self._widths = map(int, parameters['widths'].split(','))
  1184. self._num_layers = len(self._widths)
  1185. if len(self._depths) != self._num_layers + 1:
  1186. raise RuntimeError('Unmatched depths/weights %s/%s' %
  1187. (parameters['depths'], parameters['weights']))
  1188. if self._depths[-1] != 1:
  1189. raise ValueError('Final depth is not equal to 1 in %s' %
  1190. parameters['depths'])
  1191. self._kernel_shapes = []
  1192. for i, width in enumerate(self._widths):
  1193. self._kernel_shapes.append(
  1194. [width, width, self._depths[i], self._depths[i + 1]])
  1195. if parameters['relu_layers']:
  1196. self._relu_layers = set(map(int, parameters['relu_layers'].split(',')))
  1197. else:
  1198. self._relu_layers = set(range(self._num_layers - 1))
  1199. self._weights = []
  1200. self._biases = []
  1201. for i, kernel_shape in enumerate(self._kernel_shapes):
  1202. with tf.variable_scope('conv%d' % i):
  1203. self._weights.append(
  1204. tf.get_variable(
  1205. 'weights',
  1206. kernel_shape,
  1207. initializer=tf.random_normal_initializer(
  1208. stddev=1e-4, seed=self._seed),
  1209. dtype=tf.float32))
  1210. bias_init = 0.0 if i in self._relu_layers else 0.2
  1211. self._biases.append(
  1212. tf.get_variable(
  1213. 'biases',
  1214. kernel_shape[-1],
  1215. initializer=tf.constant_initializer(bias_init),
  1216. dtype=tf.float32))
  1217. self._params.extend(self._weights + self._biases)
  1218. self._layers.append(Layer(component, name='conv_output', dim=-1))
  1219. self._regularized_weights.extend(self._weights[:-1])
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Requires |stride|; otherwise see base class.

    Forms a pairwise "image" in which pixel (i, j) is the concatenation of the
    input activations at steps j and i (along axis 1 of the input tensor).
    It then applies the stack of convolutions whose variables were created in
    __init__, with relu on the configured layers. Finally it flattens the
    single output channel into rows of num_steps values.

    Raises:
      ValueError: if |stride| is None.
    """
    # TODO(googleuser): Normalize the arguments to create(). 'stride'
    # is unused by the recurrent network units, while 'context_tensor_arrays'
    # and 'attention_tensor' are unused by bulk network units. b/33587044
    if stride is None:
      raise ValueError("PairwiseConvNetwork needs 'stride'")
    input_tensor = get_input_tensor_with_stride(fixed_embeddings,
                                                linked_embeddings, stride)

    # TODO(googleuser): Add dropout.
    del context_tensor_arrays, attention_tensor, during_training  # Unused.

    # The 4-element tile multiples below imply |input_tensor| is rank 3;
    # presumably [stride, num_steps, dim] -- TODO confirm against the helper.
    num_steps = tf.shape(input_tensor)[1]

    # After tiling: arg1[b, i, j] == input[b, j] and arg2[b, i, j] == input[b, i].
    arg1 = tf.tile(tf.expand_dims(input_tensor, 1),
                   tf.stack([1, num_steps, 1, 1]))
    arg2 = tf.tile(tf.expand_dims(input_tensor, 2),
                   tf.stack([1, 1, num_steps, 1]))
    conv = tf.concat([arg1, arg2], 3)

    for i in range(self._num_layers):
      # The 'weights'/'biases' variables were created under these same scope
      # names in __init__, so they are looked up here rather than re-created.
      with tf.variable_scope('conv%d' % i, reuse=True) as scope:
        # Unit strides with SAME padding keep the num_steps x num_steps grid.
        conv = tf.nn.conv2d(
            conv,
            self._component.get_variable('weights'), [1, 1, 1, 1],
            padding='SAME')
        conv = tf.nn.bias_add(conv, self._component.get_variable('biases'))
        if i in self._relu_layers:
          conv = tf.nn.relu(conv, name=scope.name)

    # __init__ rejects any final depth other than 1, so each output row holds
    # conv[b, i, :, 0] for one (b, i) pair.
    return [tf.reshape(conv, [-1, num_steps], name='reshape_activations')]
  1253. class ExportFixedFeaturesNetwork(NetworkUnitInterface):
  1254. """A network that exports fixed features as layers.
  1255. Each fixed feature embedding is output as a layer whose name and dimension are
  1256. set to the name and dimension of the corresponding fixed feature.
  1257. """
  1258. def __init__(self, component):
  1259. """Initializes exported layers."""
  1260. super(ExportFixedFeaturesNetwork, self).__init__(component)
  1261. for feature_spec in component.spec.fixed_feature:
  1262. name = feature_spec.name
  1263. dim = self._fixed_feature_dims[name]
  1264. self._layers.append(Layer(component, name, dim))
  1265. def create(self,
  1266. fixed_embeddings,
  1267. linked_embeddings,
  1268. context_tensor_arrays,
  1269. attention_tensor,
  1270. during_training,
  1271. stride=None):
  1272. """See base class."""
  1273. check.Eq(len(self.layers), len(fixed_embeddings))
  1274. for index in range(len(fixed_embeddings)):
  1275. check.Eq(self.layers[index].name, fixed_embeddings[index].name)
  1276. return [fixed_embedding.tensor for fixed_embedding in fixed_embeddings]
  1277. class SplitNetwork(NetworkUnitInterface):
  1278. """Network unit that splits its input into slices of equal dimension.
  1279. Parameters:
  1280. num_slices: The number of slices to split the input into, S. The input must
  1281. have static dimension D, where D % S == 0.
  1282. Features:
  1283. All inputs are concatenated before being split.
  1284. Layers:
  1285. slice_0: [B * N, D / S] The first slice of the input.
  1286. slice_1: [B * N, D / S] The second slice of the input.
  1287. ...
  1288. """
  1289. def __init__(self, component):
  1290. """Initializes weights and layers.
  1291. Args:
  1292. component: Parent ComponentBuilderBase object.
  1293. """
  1294. super(SplitNetwork, self).__init__(component)
  1295. parameters = component.spec.network_unit.parameters
  1296. self._num_slices = int(parameters['num_slices'])
  1297. check.Gt(self._num_slices, 0, 'Invalid number of slices.')
  1298. check.Eq(self._concatenated_input_dim % self._num_slices, 0,
  1299. 'Input dimension %s does not evenly divide into %s slices' %
  1300. (self._concatenated_input_dim, self._num_slices))
  1301. self._slice_dim = int(self._concatenated_input_dim / self._num_slices)
  1302. for slice_index in xrange(self._num_slices):
  1303. self._layers.append(
  1304. Layer(self, 'slice_%s' % slice_index, self._slice_dim))
  1305. def create(self,
  1306. fixed_embeddings,
  1307. linked_embeddings,
  1308. context_tensor_arrays,
  1309. attention_tensor,
  1310. during_training,
  1311. stride=None):
  1312. input_bnxd = get_input_tensor(fixed_embeddings, linked_embeddings)
  1313. return tf.split(input_bnxd, self._num_slices, axis=1)