wrapped_units.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
"""Network units wrapping TensorFlow's tf.contrib.rnn cells.

Please put all wrapping logic for tf.contrib.rnn in this module; this will help
collect common subroutines that prove useful.
"""
  5. import abc
  6. import tensorflow as tf
  7. from dragnn.python import network_units as dragnn
  8. from syntaxnet.util import check
  9. class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
  10. """Base class for wrapped LSTM networks.
  11. This LSTM network unit supports multiple layers with layer normalization.
  12. Because it is imported from tf.contrib.rnn, we need to capture the created
  13. variables during initialization time.
  14. Layers:
  15. ...subclass-specific layers...
  16. last_layer: Alias for the activations of the last hidden layer.
  17. logits: Logits associated with component actions.
  18. """
  19. def __init__(self, component):
  20. """Initializes the LSTM base class.
  21. Parameters used:
  22. hidden_layer_sizes: Comma-delimited number of hidden units for each layer.
  23. input_dropout_rate (-1.0): Input dropout rate for each layer. If < 0.0,
  24. use the global |dropout_rate| hyperparameter.
  25. recurrent_dropout_rate (0.8): Recurrent dropout rate. If < 0.0, use the
  26. global |recurrent_dropout_rate| hyperparameter.
  27. layer_norm (True): Whether or not to use layer norm.
  28. Hyperparameters used:
  29. dropout_rate: Input dropout rate.
  30. recurrent_dropout_rate: Recurrent dropout rate.
  31. Args:
  32. component: parent ComponentBuilderBase object.
  33. """
  34. self._attrs = dragnn.get_attrs_with_defaults(
  35. component.spec.network_unit.parameters,
  36. defaults={
  37. 'layer_norm': True,
  38. 'input_dropout_rate': -1.0,
  39. 'recurrent_dropout_rate': 0.8,
  40. 'hidden_layer_sizes': '256',
  41. })
  42. self._hidden_layer_sizes = map(int,
  43. self._attrs['hidden_layer_sizes'].split(','))
  44. self._input_dropout_rate = self._attrs['input_dropout_rate']
  45. if self._input_dropout_rate < 0.0:
  46. self._input_dropout_rate = component.master.hyperparams.dropout_rate
  47. self._recurrent_dropout_rate = self._attrs['recurrent_dropout_rate']
  48. if self._recurrent_dropout_rate < 0.0:
  49. self._recurrent_dropout_rate = (
  50. component.master.hyperparams.recurrent_dropout_rate)
  51. if self._recurrent_dropout_rate < 0.0:
  52. self._recurrent_dropout_rate = component.master.hyperparams.dropout_rate
  53. tf.logging.info('[%s] input_dropout_rate=%s recurrent_dropout_rate=%s',
  54. component.name, self._input_dropout_rate,
  55. self._recurrent_dropout_rate)
  56. layers, context_layers = self.create_hidden_layers(component,
  57. self._hidden_layer_sizes)
  58. last_layer_dim = layers[-1].dim
  59. layers.append(
  60. dragnn.Layer(component, name='last_layer', dim=last_layer_dim))
  61. layers.append(
  62. dragnn.Layer(component, name='logits', dim=component.num_actions))
  63. # Provide initial layers and context layers, so the base class constructor
  64. # can safely use accessors like get_layer_size().
  65. super(BaseLSTMNetwork, self).__init__(
  66. component, init_layers=layers, init_context_layers=context_layers)
  67. # Allocate parameters for the softmax.
  68. self._params.append(
  69. tf.get_variable(
  70. 'weights_softmax', [last_layer_dim, component.num_actions],
  71. initializer=tf.random_normal_initializer(
  72. stddev=1e-4, seed=self._seed)))
  73. self._params.append(
  74. tf.get_variable(
  75. 'bias_softmax', [component.num_actions],
  76. initializer=tf.zeros_initializer()))
  77. def get_logits(self, network_tensors):
  78. """Returns the logits for prediction."""
  79. return network_tensors[self.get_layer_index('logits')]
  80. @abc.abstractmethod
  81. def create_hidden_layers(self, component, hidden_layer_sizes):
  82. """Creates hidden network layers.
  83. Args:
  84. component: Parent ComponentBuilderBase object.
  85. hidden_layer_sizes: List of requested hidden layer activation sizes.
  86. Returns:
  87. layers: List of layers created by this network.
  88. context_layers: List of context layers created by this network.
  89. """
  90. pass
  91. def _append_base_layers(self, hidden_layers):
  92. """Appends layers defined by the base class to the |hidden_layers|."""
  93. last_layer = hidden_layers[-1]
  94. # TODO(googleuser): Uncomment the version that uses component.get_variable()
  95. # and delete the uses of tf.get_variable().
  96. # logits = tf.nn.xw_plus_b(last_layer,
  97. # self._component.get_variable('weights_softmax'),
  98. # self._component.get_variable('bias_softmax'))
  99. logits = tf.nn.xw_plus_b(last_layer,
  100. tf.get_variable('weights_softmax'),
  101. tf.get_variable('bias_softmax'))
  102. return hidden_layers + [last_layer, logits]
  103. def _create_cell(self, num_units, during_training):
  104. """Creates a single LSTM cell, possibly with dropout.
  105. Requires that BaseLSTMNetwork.__init__() was called.
  106. Args:
  107. num_units: Number of hidden units in the cell.
  108. during_training: Whether to create a cell for training (vs inference).
  109. Returns:
  110. A RNNCell of the requested size, possibly with dropout.
  111. """
  112. # No dropout in inference mode.
  113. if not during_training:
  114. return tf.contrib.rnn.LayerNormBasicLSTMCell(
  115. num_units, layer_norm=self._attrs['layer_norm'], reuse=True)
  116. # Otherwise, apply dropout to inputs and recurrences.
  117. cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
  118. num_units,
  119. dropout_keep_prob=self._recurrent_dropout_rate,
  120. layer_norm=self._attrs['layer_norm'])
  121. cell = tf.contrib.rnn.DropoutWrapper(
  122. cell, input_keep_prob=self._input_dropout_rate)
  123. return cell
  124. def _create_train_cells(self):
  125. """Creates a list of LSTM cells for training."""
  126. return [
  127. self._create_cell(num_units, during_training=True)
  128. for num_units in self._hidden_layer_sizes
  129. ]
  130. def _create_inference_cells(self):
  131. """Creates a list of LSTM cells for inference."""
  132. return [
  133. self._create_cell(num_units, during_training=False)
  134. for num_units in self._hidden_layer_sizes
  135. ]
  136. def _capture_variables_as_params(self, function):
  137. """Captures variables created by a function in |self._params|.
  138. Args:
  139. function: Function whose variables should be captured. The function
  140. should take one argument, its enclosing variable scope.
  141. """
  142. created_vars = {}
  143. def _custom_getter(getter, *args, **kwargs):
  144. """Calls the real getter and captures its result in |created_vars|."""
  145. real_variable = getter(*args, **kwargs)
  146. created_vars[real_variable.name] = real_variable
  147. return real_variable
  148. with tf.variable_scope(
  149. 'cell', reuse=None, custom_getter=_custom_getter) as scope:
  150. function(scope)
  151. self._params.extend(created_vars.values())
  152. def _apply_with_captured_variables(self, function):
  153. """Applies a function using previously-captured variables.
  154. Args:
  155. function: Function to apply using captured variables. The function
  156. should take one argument, its enclosing variable scope.
  157. Returns:
  158. Results of function application.
  159. """
  160. def _custom_getter(getter, *args, **kwargs):
  161. """Retrieves the normal or moving-average variables."""
  162. return self._component.get_variable(var_params=getter(*args, **kwargs))
  163. with tf.variable_scope(
  164. 'cell', reuse=True, custom_getter=_custom_getter) as scope:
  165. return function(scope)
  166. class LayerNormBasicLSTMNetwork(BaseLSTMNetwork):
  167. """Wrapper around tf.contrib.rnn.LayerNormBasicLSTMCell.
  168. Features:
  169. All inputs are concatenated.
  170. Subclass-specific layers:
  171. state_c_<n>: Cell states for the <n>'th LSTM layer (0-origin).
  172. state_h_<n>: Hidden states for the <n>'th LSTM layer (0-origin).
  173. """
  174. def __init__(self, component):
  175. """Sets up context and output layers, as well as a final softmax."""
  176. super(LayerNormBasicLSTMNetwork, self).__init__(component)
  177. # Wrap lists of training and inference sub-cells into multi-layer RNN cells.
  178. # Note that a |MultiRNNCell| state is a tuple of per-layer sub-states.
  179. self._train_cell = tf.contrib.rnn.MultiRNNCell(self._create_train_cells())
  180. self._inference_cell = tf.contrib.rnn.MultiRNNCell(
  181. self._create_inference_cells())
  182. def _cell_closure(scope):
  183. """Applies the LSTM cell to placeholder inputs and state."""
  184. placeholder_inputs = tf.placeholder(
  185. dtype=tf.float32, shape=(1, self._concatenated_input_dim))
  186. placeholder_substates = []
  187. for num_units in self._hidden_layer_sizes:
  188. placeholder_substate = tf.contrib.rnn.LSTMStateTuple(
  189. tf.placeholder(dtype=tf.float32, shape=(1, num_units)),
  190. tf.placeholder(dtype=tf.float32, shape=(1, num_units)))
  191. placeholder_substates.append(placeholder_substate)
  192. placeholder_state = tuple(placeholder_substates)
  193. self._train_cell(
  194. inputs=placeholder_inputs, state=placeholder_state, scope=scope)
  195. self._capture_variables_as_params(_cell_closure)
  196. def create_hidden_layers(self, component, hidden_layer_sizes):
  197. """See base class."""
  198. # Construct the layer meta info for the DRAGNN builder. Note that the order
  199. # of h and c are reversed compared to the vanilla DRAGNN LSTM cell, as
  200. # this is the standard in tf.contrib.rnn.
  201. #
  202. # NB: The h activations of the last LSTM must be the last layer, in order
  203. # for _append_base_layers() to work.
  204. layers = []
  205. for index, num_units in enumerate(hidden_layer_sizes):
  206. layers.append(
  207. dragnn.Layer(component, name='state_c_%d' % index, dim=num_units))
  208. layers.append(
  209. dragnn.Layer(component, name='state_h_%d' % index, dim=num_units))
  210. context_layers = list(layers) # copy |layers|, don't alias it
  211. return layers, context_layers
  212. def create(self,
  213. fixed_embeddings,
  214. linked_embeddings,
  215. context_tensor_arrays,
  216. attention_tensor,
  217. during_training,
  218. stride=None):
  219. """See base class."""
  220. # NB: This cell pulls the lstm's h and c vectors from context_tensor_arrays
  221. # instead of through linked features.
  222. check.Eq(
  223. len(context_tensor_arrays), 2 * len(self._hidden_layer_sizes),
  224. 'require two context tensors per hidden layer')
  225. # Rearrange the context tensors into a tuple of LSTM sub-states.
  226. length = context_tensor_arrays[0].size()
  227. substates = []
  228. for index, num_units in enumerate(self._hidden_layer_sizes):
  229. state_c = context_tensor_arrays[2 * index].read(length - 1)
  230. state_h = context_tensor_arrays[2 * index + 1].read(length - 1)
  231. # Fix shapes that for some reason are not set properly for an unknown
  232. # reason. TODO(googleuser): Why are the shapes not set?
  233. state_c.set_shape([tf.Dimension(None), num_units])
  234. state_h.set_shape([tf.Dimension(None), num_units])
  235. substates.append(tf.contrib.rnn.LSTMStateTuple(state_c, state_h))
  236. state = tuple(substates)
  237. input_tensor = dragnn.get_input_tensor(fixed_embeddings, linked_embeddings)
  238. cell = self._train_cell if during_training else self._inference_cell
  239. def _cell_closure(scope):
  240. """Applies the LSTM cell to the current inputs and state."""
  241. return cell(input_tensor, state, scope)
  242. unused_h, state = self._apply_with_captured_variables(_cell_closure)
  243. # Return tensors to be put into the tensor arrays / used to compute
  244. # objective.
  245. output_tensors = []
  246. for new_substate in state:
  247. new_c, new_h = new_substate
  248. output_tensors.append(new_c)
  249. output_tensors.append(new_h)
  250. return self._append_base_layers(output_tensors)
  251. class BulkBiLSTMNetwork(BaseLSTMNetwork):
  252. """Bulk wrapper around tf.contrib.rnn.stack_bidirectional_dynamic_rnn().
  253. Features:
  254. lengths: [stride, 1] sequence lengths per batch item.
  255. All other features are concatenated into input activations.
  256. Subclass-specific layers:
  257. outputs: [stride * num_steps, self._output_dim] bi-LSTM activations.
  258. """
  259. def __init__(self, component):
  260. super(BulkBiLSTMNetwork, self).__init__(component)
  261. check.In('lengths', self._linked_feature_dims,
  262. 'Missing required linked feature')
  263. check.Eq(self._linked_feature_dims['lengths'], 1,
  264. 'Wrong dimension for "lengths" feature')
  265. self._input_dim = self._concatenated_input_dim - 1 # exclude 'lengths'
  266. self._output_dim = self.get_layer_size('outputs')
  267. tf.logging.info('[%s] Bulk bi-LSTM with input_dim=%d output_dim=%d',
  268. component.name, self._input_dim, self._output_dim)
  269. # Create one training and inference cell per layer and direction.
  270. self._train_cells_forward = self._create_train_cells()
  271. self._train_cells_backward = self._create_train_cells()
  272. self._inference_cells_forward = self._create_inference_cells()
  273. self._inference_cells_backward = self._create_inference_cells()
  274. def _bilstm_closure(scope):
  275. """Applies the bi-LSTM to placeholder inputs and lengths."""
  276. # Use singleton |stride| and |steps| because their values don't affect the
  277. # weight variables.
  278. stride, steps = 1, 1
  279. placeholder_inputs = tf.placeholder(
  280. dtype=tf.float32, shape=[stride, steps, self._input_dim])
  281. placeholder_lengths = tf.placeholder(dtype=tf.int64, shape=[stride])
  282. # Omit the initial states and sequence lengths for simplicity; they don't
  283. # affect the weight variables.
  284. tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
  285. self._train_cells_forward,
  286. self._train_cells_backward,
  287. placeholder_inputs,
  288. dtype=tf.float32,
  289. sequence_length=placeholder_lengths,
  290. scope=scope)
  291. self._capture_variables_as_params(_bilstm_closure)
  292. # Allocate parameters for the initial states. Note that an LSTM state is a
  293. # tuple of two substates (c, h), so there are 4 variables per layer.
  294. for index, num_units in enumerate(self._hidden_layer_sizes):
  295. for direction in ['forward', 'backward']:
  296. for substate in ['c', 'h']:
  297. self._params.append(
  298. tf.get_variable(
  299. 'initial_state_%s_%s_%d' % (direction, substate, index),
  300. [1, num_units], # leading 1 for later batch-wise tiling
  301. dtype=tf.float32,
  302. initializer=tf.constant_initializer(0.0)))
  303. def create_hidden_layers(self, component, hidden_layer_sizes):
  304. """See base class."""
  305. dim = 2 * hidden_layer_sizes[-1]
  306. return [dragnn.Layer(component, name='outputs', dim=dim)], []
  307. def create(self,
  308. fixed_embeddings,
  309. linked_embeddings,
  310. context_tensor_arrays,
  311. attention_tensor,
  312. during_training,
  313. stride=None):
  314. """Requires |stride|; otherwise see base class."""
  315. check.NotNone(stride,
  316. 'BulkBiLSTMNetwork requires "stride" and must be called '
  317. 'in the bulk feature extractor component.')
  318. # Flatten the lengths into a vector.
  319. lengths = dragnn.lookup_named_tensor('lengths', linked_embeddings)
  320. lengths_s = tf.squeeze(lengths.tensor, [1])
  321. # Collect all other inputs into a batched tensor.
  322. linked_embeddings = [
  323. named_tensor for named_tensor in linked_embeddings
  324. if named_tensor.name != 'lengths'
  325. ]
  326. inputs_sxnxd = dragnn.get_input_tensor_with_stride(
  327. fixed_embeddings, linked_embeddings, stride)
  328. # Since get_input_tensor_with_stride() concatenates the input embeddings, it
  329. # obscures the static activation dimension, which the RNN library requires.
  330. # Restore it using set_shape(). Note that set_shape() merges into the known
  331. # shape, so only specify the activation dimension.
  332. inputs_sxnxd.set_shape(
  333. [tf.Dimension(None), tf.Dimension(None), self._input_dim])
  334. initial_states_forward, initial_states_backward = (
  335. self._create_initial_states(stride))
  336. if during_training:
  337. cells_forward = self._train_cells_forward
  338. cells_backward = self._train_cells_backward
  339. else:
  340. cells_forward = self._inference_cells_forward
  341. cells_backward = self._inference_cells_backward
  342. def _bilstm_closure(scope):
  343. """Applies the bi-LSTM to the current inputs."""
  344. outputs_sxnxd, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
  345. cells_forward,
  346. cells_backward,
  347. inputs_sxnxd,
  348. initial_states_fw=initial_states_forward,
  349. initial_states_bw=initial_states_backward,
  350. sequence_length=lengths_s,
  351. scope=scope)
  352. return outputs_sxnxd
  353. # Layer outputs are not batched; flatten out the batch dimension.
  354. outputs_sxnxd = self._apply_with_captured_variables(_bilstm_closure)
  355. outputs_snxd = tf.reshape(outputs_sxnxd, [-1, self._output_dim])
  356. return self._append_base_layers([outputs_snxd])
  357. def _create_initial_states(self, stride):
  358. """Returns stacked and batched initial states for the bi-LSTM."""
  359. initial_states_forward = []
  360. initial_states_backward = []
  361. for index in range(len(self._hidden_layer_sizes)):
  362. # Retrieve the initial states for this layer.
  363. states_sxd = []
  364. for direction in ['forward', 'backward']:
  365. for substate in ['c', 'h']:
  366. state_1xd = self._component.get_variable('initial_state_%s_%s_%d' %
  367. (direction, substate, index))
  368. state_sxd = tf.tile(state_1xd, [stride, 1]) # tile across the batch
  369. states_sxd.append(state_sxd)
  370. # Assemble and append forward and backward LSTM states.
  371. initial_states_forward.append(
  372. tf.contrib.rnn.LSTMStateTuple(states_sxd[0], states_sxd[1]))
  373. initial_states_backward.append(
  374. tf.contrib.rnn.LSTMStateTuple(states_sxd[2], states_sxd[3]))
  375. return initial_states_forward, initial_states_backward