# losses.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains convenience wrappers for various Neural Network TensorFlow losses.

All the losses defined here add themselves to the LOSSES_COLLECTION
collection.

  l1_loss: Define a L1 Loss, useful for regularization, i.e. lasso.
  l2_loss: Define a L2 Loss, useful for regularization, i.e. weight decay.
  cross_entropy_loss: Define a cross entropy loss using
    softmax_cross_entropy_with_logits. Useful for classification.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# In order to gather all losses in a network, the user should use this
# key for get_collection, i.e:
#   losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
LOSSES_COLLECTION = '_losses'
  31. def l1_loss(tensor, weight=1.0, scope=None):
  32. """Define a L1Loss, useful for regularize, i.e. lasso.
  33. Args:
  34. tensor: tensor to regularize.
  35. weight: scale the loss by this factor.
  36. scope: Optional scope for op_scope.
  37. Returns:
  38. the L1 loss op.
  39. """
  40. with tf.op_scope([tensor], scope, 'L1Loss'):
  41. weight = tf.convert_to_tensor(weight,
  42. dtype=tensor.dtype.base_dtype,
  43. name='loss_weight')
  44. loss = tf.mul(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
  45. tf.add_to_collection(LOSSES_COLLECTION, loss)
  46. return loss
  47. def l2_loss(tensor, weight=1.0, scope=None):
  48. """Define a L2Loss, useful for regularize, i.e. weight decay.
  49. Args:
  50. tensor: tensor to regularize.
  51. weight: an optional weight to modulate the loss.
  52. scope: Optional scope for op_scope.
  53. Returns:
  54. the L2 loss op.
  55. """
  56. with tf.op_scope([tensor], scope, 'L2Loss'):
  57. weight = tf.convert_to_tensor(weight,
  58. dtype=tensor.dtype.base_dtype,
  59. name='loss_weight')
  60. loss = tf.mul(weight, tf.nn.l2_loss(tensor), name='value')
  61. tf.add_to_collection(LOSSES_COLLECTION, loss)
  62. return loss
  63. def cross_entropy_loss(logits, one_hot_labels, label_smoothing=0,
  64. weight=1.0, scope=None):
  65. """Define a Cross Entropy loss using softmax_cross_entropy_with_logits.
  66. It can scale the loss by weight factor, and smooth the labels.
  67. Args:
  68. logits: [batch_size, num_classes] logits outputs of the network .
  69. one_hot_labels: [batch_size, num_classes] target one_hot_encoded labels.
  70. label_smoothing: if greater than 0 then smooth the labels.
  71. weight: scale the loss by this factor.
  72. scope: Optional scope for op_scope.
  73. Returns:
  74. A tensor with the softmax_cross_entropy loss.
  75. """
  76. logits.get_shape().assert_is_compatible_with(one_hot_labels.get_shape())
  77. with tf.op_scope([logits, one_hot_labels], scope, 'CrossEntropyLoss'):
  78. num_classes = one_hot_labels.get_shape()[-1].value
  79. one_hot_labels = tf.cast(one_hot_labels, logits.dtype)
  80. if label_smoothing > 0:
  81. smooth_positives = 1.0 - label_smoothing
  82. smooth_negatives = label_smoothing / num_classes
  83. one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
  84. cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
  85. one_hot_labels,
  86. name='xentropy')
  87. weight = tf.convert_to_tensor(weight,
  88. dtype=logits.dtype.base_dtype,
  89. name='loss_weight')
  90. loss = tf.mul(weight, tf.reduce_mean(cross_entropy), name='value')
  91. tf.add_to_collection(LOSSES_COLLECTION, loss)
  92. return loss