123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- # Full license terms provided in LICENSE.md file.
- import tensorflow as tf
- import sys
def makeConst6(const6_name='const6'):
    """Return a GraphDef holding a single float32 constant equal to 6.0.

    Args:
        const6_name: Name given to the constant node.

    Returns:
        A ``tf.GraphDef`` whose only node is the ``Const`` named
        ``const6_name``.
    """
    const_graph = tf.Graph()
    with const_graph.as_default():
        # The returned tensor handle is not needed; creating the op is what
        # registers the node in ``const_graph``.
        tf.constant(value=6.0, dtype=tf.float32, name=const6_name)
    return const_graph.as_graph_def()
def makeRelu6(output_name, input_name, const6_name='const6'):
    """Build replacement nodes computing Relu6 as ``Relu(x) - Relu(x - 6)``.

    For real ``x``, ``Relu(x) - Relu(x - 6) == min(max(x, 0), 6)``, which is
    exactly Relu6, expressed with ops that lack a native Relu6.

    Args:
        output_name: Name of the Relu6 node being replaced. The final ``Sub``
            node is given exactly this name so downstream consumers keep
            resolving. The intermediate nodes are named
            ``<output_name>/relu1``, ``<output_name>/sub1`` and
            ``<output_name>/relu2``.
        input_name: Input reference of the replaced node, either a bare node
            name (``"conv"``) or one with an output port (``"conv:1"``). The
            generated nodes read from it directly.
        const6_name: Name of the float32 constant 6 node. It is referenced,
            not included, so it must already exist in the destination graph.

    Returns:
        A ``tf.GraphDef`` containing only the new nodes. The build-time input
        placeholder and constant are stripped from it.
    """
    # Build against a placeholder with a fixed, always-valid name instead of
    # ``input_name`` itself. References such as "conv:1" are not legal node
    # names, so ``tf.placeholder(name=input_name)`` would raise for them.
    placeholder_name = 'relu6_replacement_input'
    graph = tf.Graph()
    with graph.as_default():
        # Shape is irrelevant: the placeholder node is removed below.
        tf_x = tf.placeholder(tf.float32, name=placeholder_name)
        tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
        with tf.name_scope(output_name):
            tf_y1 = tf.nn.relu(tf_x, name='relu1')
            tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2')
            tf.subtract(tf_y1, tf_y2, name=output_name)

    graph_def = graph.as_graph_def()
    # The final Sub is created last and carries the scoped name
    # "<output_name>/<output_name>"; give it the replaced node's name.
    graph_def.node[-1].name = output_name

    # Drop the build-time placeholder and constant. Remove from a snapshot:
    # mutating the repeated field while iterating over it skips elements.
    scaffolding = [node for node in graph_def.node
                   if node.name in (placeholder_name, const6_name)]
    for node in scaffolding:
        graph_def.node.remove(node)

    for node in graph_def.node:
        # Point inputs at the real tensor instead of the placeholder.
        for i in range(len(node.input)):
            if node.input[i] == placeholder_name:
                node.input[i] = input_name
        # Normalize the '_Neg' op type to 'Neg' (presumably for consumers
        # that reject the internal op name -- TODO confirm still needed).
        if node.op == '_Neg':
            node.op = 'Neg'

    return graph_def
def convertRelu6(graph_def, const6_name='const6'):
    """Replace every ``Relu6`` node in ``graph_def`` with ``Relu(x) - Relu(x - 6)``.

    The graph is modified in place and also returned. A float32 constant 6
    named ``const6_name`` is appended unless a node with that name already
    exists. Each Relu6 node is removed and its replacement nodes (built by
    ``makeRelu6``) are appended to the end of ``graph_def.node``. The final
    replacement node keeps the Relu6 node's name, so consumers are unaffected.

    Args:
        graph_def: A ``tf.GraphDef`` (e.g. a frozen graph) to rewrite.
        const6_name: Name of the shared constant-6 node.

    Returns:
        The same ``graph_def`` object, rewritten.
    """
    # NOTE(review): an existing node named ``const6_name`` is reused as-is;
    # it is not checked to actually be a float32 constant 6.
    if not any(node.name == const6_name for node in graph_def.node):
        graph_def.node.extend(makeConst6(const6_name=const6_name).node)

    # Snapshot the Relu6 nodes first. Removing from / appending to the
    # repeated field while iterating over it skips the node following each
    # removed one, which left back-to-back Relu6 nodes unconverted.
    relu6_nodes = [node for node in graph_def.node if node.op == 'Relu6']
    for node in relu6_nodes:
        output_name = node.name
        # GraphDef lists control inputs ("^name") after the data inputs.
        # Carry them onto the replacement output node instead of dropping them.
        control_inputs = [name for name in node.input if name.startswith('^')]
        replacement = makeRelu6(output_name, node.input[0], const6_name=const6_name)
        for new_node in replacement.node:
            if new_node.name == output_name:
                new_node.input.extend(control_inputs)
        graph_def.node.remove(node)
        graph_def.node.extend(replacement.node)

    return graph_def
def main(argv=None):
    """Convert the frozen graph named on the command line.

    Reads a serialized ``GraphDef`` from the first path, replaces its Relu6
    nodes via ``convertRelu6`` and writes the result to the second path.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``. Expects ``[input_path, output_path]``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 2:
        # Report usage on stderr and fail with a non-zero status so calling
        # scripts can detect misuse. sys.stderr.write keeps this valid on
        # both Python 2 and 3.
        sys.stderr.write("Replaces Relu6 nodes in a frozen graph with Relu(x) - Relu(x-6)\n\n")
        sys.stderr.write("Usage: python convert_relu6.py <frozen_graph_path> <output_frozen_graph_path>\n")
        return 1

    input_path, output_path = argv
    graph_def = tf.GraphDef()
    with open(input_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    converted_graph_def = convertRelu6(graph_def)
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(output_path, 'wb') as f:
        f.write(converted_graph_def.SerializeToString())
    return 0


if __name__ == '__main__':
    sys.exit(main())
|