# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Full license terms provided in LICENSE.md file.

"""Replace Relu6 nodes in a frozen TensorFlow graph with Relu(x) - Relu(x - 6).

Relu(x) - Relu(x - 6) equals min(max(x, 0), 6), i.e. Relu6(x), but uses only
Relu and Sub ops.

Usage::

    python convert_relu6.py <input_graph.pb> <output_graph.pb>
"""

import tensorflow as tf
import sys


def makeConst6(const6_name='const6'):
    """Return a GraphDef holding a single float32 constant 6.0.

    Args:
        const6_name: name given to the constant node.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
    return graph.as_graph_def()


def _remove_nodes(graph_def, names):
    """Delete, in place, every node of *graph_def* whose name is in *names*."""
    # Walk indices from the back: deleting while iterating forward shifts the
    # remaining nodes and silently skips the one after each deletion.
    for index in reversed(range(len(graph_def.node))):
        if graph_def.node[index].name in names:
            del graph_def.node[index]


# create base relu6 nodes
def makeRelu6(output_name, input_name, const6_name='const6'):
    """Build the nodes computing Relu(x) - Relu(x - 6) for one Relu6 op.

    Args:
        output_name: name of the final node. Helper nodes are placed under the
            ``output_name/`` scope.
        input_name: tensor name fed into the computation; it may carry an
            output port (e.g. ``"split:1"``).
        const6_name: name of the float32 constant 6.0 node the helpers read.

    Returns:
        A GraphDef containing only the replacement nodes. The input node and
        the constant node are referenced by name but not included.
    """
    # Op names may not contain ':', so a port-qualified input cannot be used
    # directly as the placeholder name. Build against a temporary placeholder
    # and rewrite the references afterwards. Deriving the temporary name from
    # const6_name guarantees the two never collide in this fresh graph.
    placeholder_name = const6_name + '_relu6_input'

    graph = tf.Graph()
    with graph.as_default():
        tf_x = tf.placeholder(tf.float32, name=placeholder_name)
        tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
        with tf.name_scope(output_name):
            tf_y1 = tf.nn.relu(tf_x, name='relu1')
            tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2')
            tf_y = tf.subtract(tf_y1, tf_y2, name=output_name)
    graph_def = graph.as_graph_def()

    # The final subtraction was created inside the name scope; look its name
    # up rather than assuming it is the last node.
    final_op_name = tf_y.op.name
    for node in graph_def.node:
        if node.name == final_op_name:
            # Take over the original name so existing consumers stay wired.
            node.name = output_name
        for i in range(len(node.input)):
            if node.input[i] == placeholder_name:
                node.input[i] = input_name
        # Kept from the original implementation; presumably for TF versions
        # that emitted an internal '_Neg' op -- TODO confirm still needed.
        if node.op == '_Neg':
            node.op = 'Neg'

    # Drop the scaffolding nodes; the real graph supplies input and const6.
    _remove_nodes(graph_def, (placeholder_name, const6_name))
    return graph_def


def convertRelu6(graph_def, const6_name='const6'):
    """Replace every Relu6 node in *graph_def* with Relu(x) - Relu(x - 6).

    *graph_def* is modified in place and also returned. A float32 constant
    6.0 named *const6_name* is added unless a node with that name already
    exists. The value of an existing node is not verified, only its op type.

    Raises:
        ValueError: if a node named *const6_name* exists but is not a Const.
    """
    # add constant 6
    existing = [node for node in graph_def.node if node.name == const6_name]
    if existing:
        if existing[0].op != 'Const':
            raise ValueError(
                "node %r already exists and is a %r, not a Const; pass a "
                "different const6_name" % (const6_name, existing[0].op))
    else:
        const6_graph_def = makeConst6(const6_name=const6_name)
        graph_def.node.extend(const6_graph_def.node)

    # Build all replacements first, then mutate the node list. Removing while
    # iterating would skip the node that follows each removed Relu6.
    relu6_indices = [i for i, node in enumerate(graph_def.node)
                     if node.op == 'Relu6']
    replacements = []
    for index in relu6_indices:
        node = graph_def.node[index]
        output_name = node.name
        input_name = node.input[0]
        relu6_graph_def = makeRelu6(output_name, input_name,
                                    const6_name=const6_name)
        # Carry control dependencies ('^name' inputs) onto the output node so
        # the result is still not produced before them.
        control_inputs = [inp for inp in node.input[1:] if inp.startswith('^')]
        if control_inputs:
            for new_node in relu6_graph_def.node:
                if new_node.name == output_name:
                    new_node.input.extend(control_inputs)
        replacements.append(relu6_graph_def)

    for index in reversed(relu6_indices):
        del graph_def.node[index]
    for relu6_graph_def in replacements:
        graph_def.node.extend(relu6_graph_def.node)
    return graph_def


if __name__ == '__main__':
    if len(sys.argv) != 3:
        print("Replaces Relu6 nodes in a frozen graph with Relu(x) - Relu(x-6)\n")
        print("Usage: python convert_relu6.py <input_graph.pb> <output_graph.pb>")
        sys.exit(1)

    with open(sys.argv[1], 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    converted_graph_def = convertRelu6(graph_def)

    with open(sys.argv[2], 'wb') as f:
        f.write(converted_graph_def.SerializeToString())