convert_relu6.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import tensorflow as tf
  4. import sys
  5. def makeConst6(const6_name='const6'):
  6. graph = tf.Graph()
  7. with graph.as_default():
  8. tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
  9. return graph.as_graph_def()
  10. # create base relu6 nodes
  11. def makeRelu6(output_name, input_name, const6_name='const6'):
  12. graph = tf.Graph()
  13. with graph.as_default():
  14. tf_x = tf.placeholder(tf.float32, [10, 10], name=input_name)
  15. tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
  16. with tf.name_scope(output_name):
  17. tf_y1 = tf.nn.relu(tf_x, name='relu1')
  18. tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2')
  19. #tf_y = tf.nn.relu(tf.subtract(tf_6, tf.nn.relu(tf_x, name='relu1'), name='sub'), name='relu2')
  20. #tf_y = tf.subtract(tf_6, tf_y, name=output_name)
  21. tf_y = tf.subtract(tf_y1, tf_y2, name=output_name)
  22. graph_def = graph.as_graph_def()
  23. graph_def.node[-1].name = output_name
  24. # remove unused nodes
  25. for node in graph_def.node:
  26. if node.name == input_name:
  27. graph_def.node.remove(node)
  28. for node in graph_def.node:
  29. if node.name == const6_name:
  30. graph_def.node.remove(node)
  31. for node in graph_def.node:
  32. if node.op == '_Neg':
  33. node.op = 'Neg'
  34. return graph_def
  35. def convertRelu6(graph_def, const6_name='const6'):
  36. # add constant 6
  37. has_const6 = False
  38. for node in graph_def.node:
  39. if node.name == const6_name:
  40. has_const6 = True
  41. if not has_const6:
  42. const6_graph_def = makeConst6(const6_name=const6_name)
  43. graph_def.node.extend(const6_graph_def.node)
  44. for node in graph_def.node:
  45. if node.op == 'Relu6':
  46. input_name = node.input[0]
  47. output_name = node.name
  48. relu6_graph_def = makeRelu6(output_name, input_name, const6_name=const6_name)
  49. graph_def.node.remove(node)
  50. graph_def.node.extend(relu6_graph_def.node)
  51. return graph_def
  52. if __name__ == '__main__':
  53. if len(sys.argv) != 3:
  54. print("Replaces Relu6 nodes in a frozen graph with Relu(x) - Relu(x-6)\n")
  55. print("Usage: python convert_relu6.py <frozen_graph_path> <output_frozen_graph_path>")
  56. exit()
  57. with open(sys.argv[1], 'rb') as f:
  58. graph_def = tf.GraphDef()
  59. graph_def.ParseFromString(f.read())
  60. converted_graph_def = convertRelu6(graph_def)
  61. with open(sys.argv[2], 'wb') as f:
  62. f.write(converted_graph_def.SerializeToString())
  63. f.close()