vgsl_model.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """String network description language to define network layouts."""
  16. import re
  17. import time
  18. import decoder
  19. import errorcounter as ec
  20. import shapes
  21. import tensorflow as tf
  22. import vgsl_input
  23. import vgslspecs
  24. import tensorflow.contrib.slim as slim
  25. from tensorflow.core.framework import summary_pb2
  26. from tensorflow.python.platform import tf_logging as logging
  27. # Parameters for rate decay.
  28. # We divide the learning_rate_halflife by DECAY_STEPS_FACTOR and use DECAY_RATE
  29. # as the decay factor for the learning rate, ie we use the DECAY_STEPS_FACTORth
  30. # root of 2 as the decay rate every halflife/DECAY_STEPS_FACTOR to achieve the
  31. # desired halflife.
  32. DECAY_STEPS_FACTOR = 16
  33. DECAY_RATE = pow(0.5, 1.0 / DECAY_STEPS_FACTOR)
  34. def Train(train_dir,
  35. model_str,
  36. train_data,
  37. max_steps,
  38. master='',
  39. task=0,
  40. ps_tasks=0,
  41. initial_learning_rate=0.001,
  42. final_learning_rate=0.001,
  43. learning_rate_halflife=160000,
  44. optimizer_type='Adam',
  45. num_preprocess_threads=1,
  46. reader=None):
  47. """Testable trainer with no dependence on FLAGS.
  48. Args:
  49. train_dir: Directory to write checkpoints.
  50. model_str: Network specification string.
  51. train_data: Training data file pattern.
  52. max_steps: Number of training steps to run.
  53. master: Name of the TensorFlow master to use.
  54. task: Task id of this replica running the training. (0 will be master).
  55. ps_tasks: Number of tasks in ps job, or 0 if no ps job.
  56. initial_learning_rate: Learing rate at start of training.
  57. final_learning_rate: Asymptotic minimum learning rate.
  58. learning_rate_halflife: Number of steps over which to halve the difference
  59. between initial and final learning rate.
  60. optimizer_type: One of 'GradientDescent', 'AdaGrad', 'Momentum', 'Adam'.
  61. num_preprocess_threads: Number of input threads.
  62. reader: Function that returns an actual reader to read Examples from input
  63. files. If None, uses tf.TFRecordReader().
  64. """
  65. if master.startswith('local'):
  66. device = tf.ReplicaDeviceSetter(ps_tasks)
  67. else:
  68. device = '/cpu:0'
  69. with tf.Graph().as_default():
  70. with tf.device(device):
  71. model = InitNetwork(train_data, model_str, 'train', initial_learning_rate,
  72. final_learning_rate, learning_rate_halflife,
  73. optimizer_type, num_preprocess_threads, reader)
  74. # Create a Supervisor. It will take care of initialization, summaries,
  75. # checkpoints, and recovery.
  76. #
  77. # When multiple replicas of this program are running, the first one,
  78. # identified by --task=0 is the 'chief' supervisor. It is the only one
  79. # that takes case of initialization, etc.
  80. sv = tf.train.Supervisor(
  81. logdir=train_dir,
  82. is_chief=(task == 0),
  83. saver=model.saver,
  84. save_summaries_secs=10,
  85. save_model_secs=30,
  86. recovery_wait_secs=5)
  87. step = 0
  88. while step < max_steps:
  89. try:
  90. # Get an initialized, and possibly recovered session. Launch the
  91. # services: Checkpointing, Summaries, step counting.
  92. with sv.managed_session(master) as sess:
  93. while step < max_steps:
  94. _, step = model.TrainAStep(sess)
  95. if sv.coord.should_stop():
  96. break
  97. except tf.errors.AbortedError as e:
  98. logging.error('Received error:%s', e)
  99. continue
  100. def Eval(train_dir,
  101. eval_dir,
  102. model_str,
  103. eval_data,
  104. decoder_file,
  105. num_steps,
  106. graph_def_file=None,
  107. eval_interval_secs=0,
  108. reader=None):
  109. """Restores a model from a checkpoint and evaluates it.
  110. Args:
  111. train_dir: Directory to find checkpoints.
  112. eval_dir: Directory to write summary events.
  113. model_str: Network specification string.
  114. eval_data: Evaluation data file pattern.
  115. decoder_file: File to read to decode the labels.
  116. num_steps: Number of eval steps to run.
  117. graph_def_file: File to write graph definition to for freezing.
  118. eval_interval_secs: How often to run evaluations, or once if 0.
  119. reader: Function that returns an actual reader to read Examples from input
  120. files. If None, uses tf.TFRecordReader().
  121. Returns:
  122. (char error rate, word recall error rate, sequence error rate) as percent.
  123. Raises:
  124. ValueError: If unimplemented feature is used.
  125. """
  126. decode = None
  127. if decoder_file:
  128. decode = decoder.Decoder(decoder_file)
  129. # Run eval.
  130. rates = ec.ErrorRates(
  131. label_error=None,
  132. word_recall_error=None,
  133. word_precision_error=None,
  134. sequence_error=None)
  135. with tf.Graph().as_default():
  136. model = InitNetwork(eval_data, model_str, 'eval', reader=reader)
  137. sw = tf.summary.FileWriter(eval_dir)
  138. while True:
  139. sess = tf.Session('')
  140. if graph_def_file is not None:
  141. # Write the eval version of the graph to a file for freezing.
  142. if not tf.gfile.Exists(graph_def_file):
  143. with tf.gfile.FastGFile(graph_def_file, 'w') as f:
  144. f.write(
  145. sess.graph.as_graph_def(add_shapes=True).SerializeToString())
  146. ckpt = tf.train.get_checkpoint_state(train_dir)
  147. if ckpt and ckpt.model_checkpoint_path:
  148. step = model.Restore(ckpt.model_checkpoint_path, sess)
  149. if decode:
  150. rates = decode.SoftmaxEval(sess, model, num_steps)
  151. _AddRateToSummary('Label error rate', rates.label_error, step, sw)
  152. _AddRateToSummary('Word recall error rate', rates.word_recall_error,
  153. step, sw)
  154. _AddRateToSummary('Word precision error rate',
  155. rates.word_precision_error, step, sw)
  156. _AddRateToSummary('Sequence error rate', rates.sequence_error, step,
  157. sw)
  158. sw.flush()
  159. print 'Error rates=', rates
  160. else:
  161. raise ValueError('Non-softmax decoder evaluation not implemented!')
  162. if eval_interval_secs:
  163. time.sleep(eval_interval_secs)
  164. else:
  165. break
  166. return rates
  167. def InitNetwork(input_pattern,
  168. model_spec,
  169. mode='eval',
  170. initial_learning_rate=0.00005,
  171. final_learning_rate=0.00005,
  172. halflife=1600000,
  173. optimizer_type='Adam',
  174. num_preprocess_threads=1,
  175. reader=None):
  176. """Constructs a python tensor flow model defined by model_spec.
  177. Args:
  178. input_pattern: File pattern of the data in tfrecords of Example.
  179. model_spec: Concatenation of input spec, model spec and output spec.
  180. See Build below for input/output spec. For model spec, see vgslspecs.py
  181. mode: One of 'train', 'eval'
  182. initial_learning_rate: Initial learning rate for the network.
  183. final_learning_rate: Final learning rate for the network.
  184. halflife: Number of steps over which to halve the difference between
  185. initial and final learning rate for the network.
  186. optimizer_type: One of 'GradientDescent', 'AdaGrad', 'Momentum', 'Adam'.
  187. num_preprocess_threads: Number of threads to use for image processing.
  188. reader: Function that returns an actual reader to read Examples from input
  189. files. If None, uses tf.TFRecordReader().
  190. Eval tasks need only specify input_pattern and model_spec.
  191. Returns:
  192. A VGSLImageModel class.
  193. Raises:
  194. ValueError: if the model spec syntax is incorrect.
  195. """
  196. model = VGSLImageModel(mode, model_spec, initial_learning_rate,
  197. final_learning_rate, halflife)
  198. left_bracket = model_spec.find('[')
  199. right_bracket = model_spec.rfind(']')
  200. if left_bracket < 0 or right_bracket < 0:
  201. raise ValueError('Failed to find [] in model spec! ', model_spec)
  202. input_spec = model_spec[:left_bracket]
  203. layer_spec = model_spec[left_bracket:right_bracket + 1]
  204. output_spec = model_spec[right_bracket + 1:]
  205. model.Build(input_pattern, input_spec, layer_spec, output_spec,
  206. optimizer_type, num_preprocess_threads, reader)
  207. return model
  208. class VGSLImageModel(object):
  209. """Class that builds a tensor flow model for training or evaluation.
  210. """
  211. def __init__(self, mode, model_spec, initial_learning_rate,
  212. final_learning_rate, halflife):
  213. """Constructs a VGSLImageModel.
  214. Args:
  215. mode: One of "train", "eval"
  216. model_spec: Full model specification string, for reference only.
  217. initial_learning_rate: Initial learning rate for the network.
  218. final_learning_rate: Final learning rate for the network.
  219. halflife: Number of steps over which to halve the difference between
  220. initial and final learning rate for the network.
  221. """
  222. # The string that was used to build this model.
  223. self.model_spec = model_spec
  224. # The layers between input and output.
  225. self.layers = None
  226. # The train/eval mode.
  227. self.mode = mode
  228. # The initial learning rate.
  229. self.initial_learning_rate = initial_learning_rate
  230. self.final_learning_rate = final_learning_rate
  231. self.decay_steps = halflife / DECAY_STEPS_FACTOR
  232. self.decay_rate = DECAY_RATE
  233. # Tensor for the labels.
  234. self.labels = None
  235. self.sparse_labels = None
  236. # Debug data containing the truth text.
  237. self.truths = None
  238. # Tensor for loss
  239. self.loss = None
  240. # Train operation
  241. self.train_op = None
  242. # Tensor for the global step counter
  243. self.global_step = None
  244. # Tensor for the output predictions (usually softmax)
  245. self.output = None
  246. # True if we are using CTC training mode.
  247. self.using_ctc = False
  248. # Saver object to load or restore the variables.
  249. self.saver = None
  250. def Build(self, input_pattern, input_spec, model_spec, output_spec,
  251. optimizer_type, num_preprocess_threads, reader):
  252. """Builds the model from the separate input/layers/output spec strings.
  253. Args:
  254. input_pattern: File pattern of the data in tfrecords of TF Example format.
  255. input_spec: Specification of the input layer:
  256. batchsize,height,width,depth (4 comma-separated integers)
  257. Training will run with batches of batchsize images, but runtime can
  258. use any batch size.
  259. height and/or width can be 0 or -1, indicating variable size,
  260. otherwise all images must be the given size.
  261. depth must be 1 or 3 to indicate greyscale or color.
  262. NOTE 1-d image input, treating the y image dimension as depth, can
  263. be achieved using S1(1x0)1,3 as the first op in the model_spec, but
  264. the y-size of the input must then be fixed.
  265. model_spec: Model definition. See vgslspecs.py
  266. output_spec: Output layer definition:
  267. O(2|1|0)(l|s|c)n output layer with n classes.
  268. 2 (heatmap) Output is a 2-d vector map of the input (possibly at
  269. different scale).
  270. 1 (sequence) Output is a 1-d sequence of vector values.
  271. 0 (value) Output is a 0-d single vector value.
  272. l uses a logistic non-linearity on the output, allowing multiple
  273. hot elements in any output vector value.
  274. s uses a softmax non-linearity, with one-hot output in each value.
  275. c uses a softmax with CTC. Can only be used with s (sequence).
  276. NOTE Only O1s and O1c are currently supported.
  277. optimizer_type: One of 'GradientDescent', 'AdaGrad', 'Momentum', 'Adam'.
  278. num_preprocess_threads: Number of threads to use for image processing.
  279. reader: Function that returns an actual reader to read Examples from input
  280. files. If None, uses tf.TFRecordReader().
  281. """
  282. self.global_step = tf.Variable(0, name='global_step', trainable=False)
  283. shape = _ParseInputSpec(input_spec)
  284. out_dims, out_func, num_classes = _ParseOutputSpec(output_spec)
  285. self.using_ctc = out_func == 'c'
  286. images, heights, widths, labels, sparse, _ = vgsl_input.ImageInput(
  287. input_pattern, num_preprocess_threads, shape, self.using_ctc, reader)
  288. self.labels = labels
  289. self.sparse_labels = sparse
  290. self.layers = vgslspecs.VGSLSpecs(widths, heights, self.mode == 'train')
  291. last_layer = self.layers.Build(images, model_spec)
  292. self._AddOutputs(last_layer, out_dims, out_func, num_classes)
  293. if self.mode == 'train':
  294. self._AddOptimizer(optimizer_type)
  295. # For saving the model across training and evaluation
  296. self.saver = tf.train.Saver()
  297. def TrainAStep(self, sess):
  298. """Runs a training step in the session.
  299. Args:
  300. sess: Session in which to train the model.
  301. Returns:
  302. loss, global_step.
  303. """
  304. _, loss, step = sess.run([self.train_op, self.loss, self.global_step])
  305. return loss, step
  306. def Restore(self, checkpoint_path, sess):
  307. """Restores the model from the given checkpoint path into the session.
  308. Args:
  309. checkpoint_path: File pathname of the checkpoint.
  310. sess: Session in which to restore the model.
  311. Returns:
  312. global_step of the model.
  313. """
  314. self.saver.restore(sess, checkpoint_path)
  315. return tf.train.global_step(sess, self.global_step)
  316. def RunAStep(self, sess):
  317. """Runs a step for eval in the session.
  318. Args:
  319. sess: Session in which to run the model.
  320. Returns:
  321. output tensor result, labels tensor result.
  322. """
  323. return sess.run([self.output, self.labels])
  324. def _AddOutputs(self, prev_layer, out_dims, out_func, num_classes):
  325. """Adds the output layer and loss function.
  326. Args:
  327. prev_layer: Output of last layer of main network.
  328. out_dims: Number of output dimensions, 0, 1 or 2.
  329. out_func: Output non-linearity. 's' or 'c'=softmax, 'l'=logistic.
  330. num_classes: Number of outputs/size of last output dimension.
  331. """
  332. height_in = shapes.tensor_dim(prev_layer, dim=1)
  333. logits, outputs = self._AddOutputLayer(prev_layer, out_dims, out_func,
  334. num_classes)
  335. if self.mode == 'train':
  336. # Setup loss for training.
  337. self.loss = self._AddLossFunction(logits, height_in, out_dims, out_func)
  338. tf.summary.scalar('loss', self.loss)
  339. elif out_dims == 0:
  340. # Be sure the labels match the output, even in eval mode.
  341. self.labels = tf.slice(self.labels, [0, 0], [-1, 1])
  342. self.labels = tf.reshape(self.labels, [-1])
  343. logging.info('Final output=%s', outputs)
  344. logging.info('Labels tensor=%s', self.labels)
  345. self.output = outputs
  346. def _AddOutputLayer(self, prev_layer, out_dims, out_func, num_classes):
  347. """Add the fully-connected logits and SoftMax/Logistic output Layer.
  348. Args:
  349. prev_layer: Output of last layer of main network.
  350. out_dims: Number of output dimensions, 0, 1 or 2.
  351. out_func: Output non-linearity. 's' or 'c'=softmax, 'l'=logistic.
  352. num_classes: Number of outputs/size of last output dimension.
  353. Returns:
  354. logits: Pre-softmax/logistic fully-connected output shaped to out_dims.
  355. outputs: Post-softmax/logistic shaped to out_dims.
  356. Raises:
  357. ValueError: if syntax is incorrect.
  358. """
  359. # Reduce dimensionality appropriate to the output dimensions.
  360. batch_in = shapes.tensor_dim(prev_layer, dim=0)
  361. height_in = shapes.tensor_dim(prev_layer, dim=1)
  362. width_in = shapes.tensor_dim(prev_layer, dim=2)
  363. depth_in = shapes.tensor_dim(prev_layer, dim=3)
  364. if out_dims:
  365. # Combine any remaining height and width with batch and unpack after.
  366. shaped = tf.reshape(prev_layer, [-1, depth_in])
  367. else:
  368. # Everything except batch goes to depth, and therefore has to be known.
  369. shaped = tf.reshape(prev_layer, [-1, height_in * width_in * depth_in])
  370. logits = slim.fully_connected(shaped, num_classes, activation_fn=None)
  371. if out_func == 'l':
  372. raise ValueError('Logistic not yet supported!')
  373. else:
  374. output = tf.nn.softmax(logits)
  375. # Reshape to the dessired output.
  376. if out_dims == 2:
  377. output_shape = [batch_in, height_in, width_in, num_classes]
  378. elif out_dims == 1:
  379. output_shape = [batch_in, height_in * width_in, num_classes]
  380. else:
  381. output_shape = [batch_in, num_classes]
  382. output = tf.reshape(output, output_shape, name='Output')
  383. logits = tf.reshape(logits, output_shape)
  384. return logits, output
  385. def _AddLossFunction(self, logits, height_in, out_dims, out_func):
  386. """Add the appropriate loss function.
  387. Args:
  388. logits: Pre-softmax/logistic fully-connected output shaped to out_dims.
  389. height_in: Height of logits before going into the softmax layer.
  390. out_dims: Number of output dimensions, 0, 1 or 2.
  391. out_func: Output non-linearity. 's' or 'c'=softmax, 'l'=logistic.
  392. Returns:
  393. loss: That which is to be minimized.
  394. Raises:
  395. ValueError: if logistic is used.
  396. """
  397. if out_func == 'c':
  398. # Transpose batch to the middle.
  399. ctc_input = tf.transpose(logits, [1, 0, 2])
  400. # Compute the widths of each batch element from the input widths.
  401. widths = self.layers.GetLengths(dim=2, factor=height_in)
  402. cross_entropy = tf.nn.ctc_loss(ctc_input, self.sparse_labels, widths)
  403. elif out_func == 's':
  404. if out_dims == 2:
  405. self.labels = _PadLabels3d(logits, self.labels)
  406. elif out_dims == 1:
  407. self.labels = _PadLabels2d(
  408. shapes.tensor_dim(
  409. logits, dim=1), self.labels)
  410. else:
  411. self.labels = tf.slice(self.labels, [0, 0], [-1, 1])
  412. self.labels = tf.reshape(self.labels, [-1])
  413. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
  414. logits=logits, labels=self.labels, name='xent')
  415. else:
  416. # TODO(rays) Labels need an extra dimension for logistic, so different
  417. # padding functions are needed, as well as a different loss function.
  418. raise ValueError('Logistic not yet supported!')
  419. return tf.reduce_sum(cross_entropy)
  420. def _AddOptimizer(self, optimizer_type):
  421. """Adds an optimizer with learning rate decay to minimize self.loss.
  422. Args:
  423. optimizer_type: One of 'GradientDescent', 'AdaGrad', 'Momentum', 'Adam'.
  424. Raises:
  425. ValueError: if the optimizer type is unrecognized.
  426. """
  427. learn_rate_delta = self.initial_learning_rate - self.final_learning_rate
  428. learn_rate_dec = tf.add(
  429. tf.train.exponential_decay(learn_rate_delta, self.global_step,
  430. self.decay_steps, self.decay_rate),
  431. self.final_learning_rate)
  432. if optimizer_type == 'GradientDescent':
  433. opt = tf.train.GradientDescentOptimizer(learn_rate_dec)
  434. elif optimizer_type == 'AdaGrad':
  435. opt = tf.train.AdagradOptimizer(learn_rate_dec)
  436. elif optimizer_type == 'Momentum':
  437. opt = tf.train.MomentumOptimizer(learn_rate_dec, momentum=0.9)
  438. elif optimizer_type == 'Adam':
  439. opt = tf.train.AdamOptimizer(learning_rate=learn_rate_dec)
  440. else:
  441. raise ValueError('Invalid optimizer type: ' + optimizer_type)
  442. tf.summary.scalar('learn_rate', learn_rate_dec)
  443. self.train_op = opt.minimize(
  444. self.loss, global_step=self.global_step, name='train')
  445. def _PadLabels3d(logits, labels):
  446. """Pads or slices 3-d labels to match logits.
  447. Covers the case of 2-d softmax output, when labels is [batch, height, width]
  448. and logits is [batch, height, width, onehot]
  449. Args:
  450. logits: 4-d Pre-softmax fully-connected output.
  451. labels: 3-d, but not necessarily matching in size.
  452. Returns:
  453. labels: Resized by padding or clipping to match logits.
  454. """
  455. logits_shape = shapes.tensor_shape(logits)
  456. labels_shape = shapes.tensor_shape(labels)
  457. labels = tf.reshape(labels, [-1, labels_shape[2]])
  458. labels = _PadLabels2d(logits_shape[2], labels)
  459. labels = tf.reshape(labels, [labels_shape[0], -1])
  460. labels = _PadLabels2d(logits_shape[1] * logits_shape[2], labels)
  461. return tf.reshape(labels, [labels_shape[0], logits_shape[1], logits_shape[2]])
  462. def _PadLabels2d(logits_size, labels):
  463. """Pads or slices the 2nd dimension of 2-d labels to match logits_size.
  464. Covers the case of 1-d softmax output, when labels is [batch, seq] and
  465. logits is [batch, seq, onehot]
  466. Args:
  467. logits_size: Tensor returned from tf.shape giving the target size.
  468. labels: 2-d, but not necessarily matching in size.
  469. Returns:
  470. labels: Resized by padding or clipping the last dimension to logits_size.
  471. """
  472. pad = logits_size - tf.shape(labels)[1]
  473. def _PadFn():
  474. return tf.pad(labels, [[0, 0], [0, pad]])
  475. def _SliceFn():
  476. return tf.slice(labels, [0, 0], [-1, logits_size])
  477. return tf.cond(tf.greater(pad, 0), _PadFn, _SliceFn)
  478. def _ParseInputSpec(input_spec):
  479. """Parses input_spec and returns the numbers obtained therefrom.
  480. Args:
  481. input_spec: Specification of the input layer. See Build.
  482. Returns:
  483. shape: ImageShape with the desired shape of the input.
  484. Raises:
  485. ValueError: if syntax is incorrect.
  486. """
  487. pattern = re.compile(R'(\d+),(\d+),(\d+),(\d+)')
  488. m = pattern.match(input_spec)
  489. if m is None:
  490. raise ValueError('Failed to parse input spec:' + input_spec)
  491. batch_size = int(m.group(1))
  492. y_size = int(m.group(2)) if int(m.group(2)) > 0 else None
  493. x_size = int(m.group(3)) if int(m.group(3)) > 0 else None
  494. depth = int(m.group(4))
  495. if depth not in [1, 3]:
  496. raise ValueError('Depth must be 1 or 3, had:', depth)
  497. return vgsl_input.ImageShape(batch_size, y_size, x_size, depth)
  498. def _ParseOutputSpec(output_spec):
  499. """Parses the output spec.
  500. Args:
  501. output_spec: Output layer definition. See Build.
  502. Returns:
  503. out_dims: 2|1|0 for 2-d, 1-d, 0-d.
  504. out_func: l|s|c for logistic, softmax, softmax+CTC
  505. num_classes: Number of classes in output.
  506. Raises:
  507. ValueError: if syntax is incorrect.
  508. """
  509. pattern = re.compile(R'(O)(0|1|2)(l|s|c)(\d+)')
  510. m = pattern.match(output_spec)
  511. if m is None:
  512. raise ValueError('Failed to parse output spec:' + output_spec)
  513. out_dims = int(m.group(2))
  514. out_func = m.group(3)
  515. if out_func == 'c' and out_dims != 1:
  516. raise ValueError('CTC can only be used with a 1-D sequence!')
  517. num_classes = int(m.group(4))
  518. return out_dims, out_func, num_classes
  519. def _AddRateToSummary(tag, rate, step, sw):
  520. """Adds the given rate to the summary with the given tag.
  521. Args:
  522. tag: Name for this value.
  523. rate: Value to add to the summary. Perhaps an error rate.
  524. step: Global step of the graph for the x-coordinate of the summary.
  525. sw: Summary writer to which to write the rate value.
  526. """
  527. sw.add_summary(
  528. summary_pb2.Summary(value=[summary_pb2.Summary.Value(
  529. tag=tag, simple_value=rate)]), step)