losses.py 6.2 KB

  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains convenience wrappers for various Neural Network TensorFlow losses.
  16. All the losses defined here add themselves to the LOSSES_COLLECTION
  17. collection.
  18. l1_loss: Define a L1 Loss, useful for regularization, i.e. lasso.
  19. l2_loss: Define a L2 Loss, useful for regularization, i.e. weight decay.
  20. cross_entropy_loss: Define a cross entropy loss using
  21. softmax_cross_entropy_with_logits. Useful for classification.
  22. """
  23. from __future__ import absolute_import
  24. from __future__ import division
  25. from __future__ import print_function
  26. import tensorflow as tf
  27. # In order to gather all losses in a network, the user should use this
  28. # key for get_collection, i.e:
  29. # losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
  30. LOSSES_COLLECTION = '_losses'
  31. def l1_regularizer(weight=1.0, scope=None):
  32. """Define a L1 regularizer.
  33. Args:
  34. weight: scale the loss by this factor.
  35. scope: Optional scope for name_scope.
  36. Returns:
  37. a regularizer function.
  38. """
  39. def regularizer(tensor):
  40. with tf.name_scope(scope, 'L1Regularizer', [tensor]):
  41. l1_weight = tf.convert_to_tensor(weight,
  42. dtype=tensor.dtype.base_dtype,
  43. name='weight')
  44. return tf.multiply(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
  45. return regularizer
  46. def l2_regularizer(weight=1.0, scope=None):
  47. """Define a L2 regularizer.
  48. Args:
  49. weight: scale the loss by this factor.
  50. scope: Optional scope for name_scope.
  51. Returns:
  52. a regularizer function.
  53. """
  54. def regularizer(tensor):
  55. with tf.name_scope(scope, 'L2Regularizer', [tensor]):
  56. l2_weight = tf.convert_to_tensor(weight,
  57. dtype=tensor.dtype.base_dtype,
  58. name='weight')
  59. return tf.multiply(l2_weight, tf.nn.l2_loss(tensor), name='value')
  60. return regularizer
  61. def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
  62. """Define a L1L2 regularizer.
  63. Args:
  64. weight_l1: scale the L1 loss by this factor.
  65. weight_l2: scale the L2 loss by this factor.
  66. scope: Optional scope for name_scope.
  67. Returns:
  68. a regularizer function.
  69. """
  70. def regularizer(tensor):
  71. with tf.name_scope(scope, 'L1L2Regularizer', [tensor]):
  72. weight_l1_t = tf.convert_to_tensor(weight_l1,
  73. dtype=tensor.dtype.base_dtype,
  74. name='weight_l1')
  75. weight_l2_t = tf.convert_to_tensor(weight_l2,
  76. dtype=tensor.dtype.base_dtype,
  77. name='weight_l2')
  78. reg_l1 = tf.multiply(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
  79. name='value_l1')
  80. reg_l2 = tf.multiply(weight_l2_t, tf.nn.l2_loss(tensor),
  81. name='value_l2')
  82. return tf.add(reg_l1, reg_l2, name='value')
  83. return regularizer
  84. def l1_loss(tensor, weight=1.0, scope=None):
  85. """Define a L1Loss, useful for regularize, i.e. lasso.
  86. Args:
  87. tensor: tensor to regularize.
  88. weight: scale the loss by this factor.
  89. scope: Optional scope for name_scope.
  90. Returns:
  91. the L1 loss op.
  92. """
  93. with tf.name_scope(scope, 'L1Loss', [tensor]):
  94. weight = tf.convert_to_tensor(weight,
  95. dtype=tensor.dtype.base_dtype,
  96. name='loss_weight')
  97. loss = tf.multiply(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
  98. tf.add_to_collection(LOSSES_COLLECTION, loss)
  99. return loss
  100. def l2_loss(tensor, weight=1.0, scope=None):
  101. """Define a L2Loss, useful for regularize, i.e. weight decay.
  102. Args:
  103. tensor: tensor to regularize.
  104. weight: an optional weight to modulate the loss.
  105. scope: Optional scope for name_scope.
  106. Returns:
  107. the L2 loss op.
  108. """
  109. with tf.name_scope(scope, 'L2Loss', [tensor]):
  110. weight = tf.convert_to_tensor(weight,
  111. dtype=tensor.dtype.base_dtype,
  112. name='loss_weight')
  113. loss = tf.multiply(weight, tf.nn.l2_loss(tensor), name='value')
  114. tf.add_to_collection(LOSSES_COLLECTION, loss)
  115. return loss
  116. def cross_entropy_loss(logits, one_hot_labels, label_smoothing=0,
  117. weight=1.0, scope=None):
  118. """Define a Cross Entropy loss using softmax_cross_entropy_with_logits.
  119. It can scale the loss by weight factor, and smooth the labels.
  120. Args:
  121. logits: [batch_size, num_classes] logits outputs of the network .
  122. one_hot_labels: [batch_size, num_classes] target one_hot_encoded labels.
  123. label_smoothing: if greater than 0 then smooth the labels.
  124. weight: scale the loss by this factor.
  125. scope: Optional scope for name_scope.
  126. Returns:
  127. A tensor with the softmax_cross_entropy loss.
  128. """
  129. logits.get_shape().assert_is_compatible_with(one_hot_labels.get_shape())
  130. with tf.name_scope(scope, 'CrossEntropyLoss', [logits, one_hot_labels]):
  131. num_classes = one_hot_labels.get_shape()[-1].value
  132. one_hot_labels = tf.cast(one_hot_labels, logits.dtype)
  133. if label_smoothing > 0:
  134. smooth_positives = 1.0 - label_smoothing
  135. smooth_negatives = label_smoothing / num_classes
  136. one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
  137. cross_entropy = tf.contrib.nn.deprecated_flipped_softmax_cross_entropy_with_logits(
  138. logits, one_hot_labels, name='xentropy')
  139. weight = tf.convert_to_tensor(weight,
  140. dtype=logits.dtype.base_dtype,
  141. name='loss_weight')
  142. loss = tf.multiply(weight, tf.reduce_mean(cross_entropy), name='value')
  143. tf.add_to_collection(LOSSES_COLLECTION, loss)
  144. return loss