initial commit

John Welsh 6 years ago
commit 0aced03dc4

+ 13 - 0
CMakeLists.txt

@@ -0,0 +1,13 @@
+cmake_minimum_required(VERSION 3.1)
+project(trt_image_classification)
+
+find_package(OpenCV REQUIRED)
+find_package(CUDA REQUIRED)
+
+include_directories(${CMAKE_SOURCE_DIR})
+
+set(CUDA_NVCC_FLAGS --std=c++11)
+set(CMAKE_CXX_STANDARD 11)
+
+add_subdirectory(examples)
+add_subdirectory(src)

+ 34 - 0
INSTALL.md

@@ -0,0 +1,34 @@
+Installation
+===
+
+1. Flash the Jetson TX2 using JetPack 3.2.  Be sure to install
+  * CUDA 9.0
+  * OpenCV4Tegra
+  * cuDNN
+  * TensorRT 3.0
+
+2. Install TensorFlow on the Jetson TX2.
+  1. ...
+  2. ...
+
+3. Install the uff converter on the Jetson TX2.
+  1. Download TensorRT 3.0.4 for Ubuntu 16.04 and CUDA 9.0 tar package from https://developer.nvidia.com/nvidia-tensorrt-download.
+  2. Extract the archive
+
+            tar -xzf TensorRT-3.0.4.Ubuntu-16.04.3.x86_64.cuda-9.0.cudnn7.0.tar.gz
+
+  3. Install the uff Python package using pip
+
+            sudo pip install TensorRT-3.0.4/uff/uff-0.2.0-py2.py3-none-any.whl
+
+4. Clone and build this project
+
+    ```
+    git clone --recursive https://gitlab-master.nvidia.com/jwelsh/trt_image_classification.git
+    cd trt_image_classification
+    mkdir build
+    cd build
+    cmake ..
+    make 
+    cd ..
+    ```

+ 25 - 0
LICENSE.md

@@ -0,0 +1,25 @@
+Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+ * Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+ * Neither the name of NVIDIA CORPORATION nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 78 - 0
README.md

@@ -0,0 +1,78 @@
+TensorFlow->TensorRT Image Classification
+===
+
+This repository contains examples, scripts, and code for image classification using TensorFlow models
+(from the [TensorFlow slim model zoo](https://github.com/tensorflow/models/tree/master/research/slim#Pretrained))
+converted to TensorRT.  Converting TensorFlow models to TensorRT offers significant performance
+gains on the Jetson TX2, as seen [below](#default_models).
+
+<a name="quick_start"></a>
+## Quick Start
+
+1. Follow the [installation guide](INSTALL.md).
+2. Download the pretrained TensorFlow models and example images.
+
+    ```
+    source scripts/download_models.sh
+    source scripts/download_images.sh
+    ```
+
+3. Convert the pretrained models to frozen graphs.
+
+    ```
+    python scripts/models_to_frozen_graphs.py
+    ```
+
+4. Convert the frozen graphs to optimized TensorRT engines.
+
+    ```
+    python scripts/frozen_graphs_to_plans.py
+    ```
+
+5. Execute the Inception V1 model on a single image.
+
+    ```
+    ./build/examples/classify_image/classify_image data/images/gordon_setter.jpg data/plans/inception_v1.plan data/imagenet_labels_1001.txt input InceptionV1/Logits/SpatialSqueeze inception
+    ```
+
+For more details, read through the [examples documentation](examples/README.md).
+
+<a name="default_models"></a>
+## Default Models
+
+The table below shows details of the default models ported from the TensorFlow slim model zoo,
+including average per-image inference times on the Jetson TX2.
+
+| <sub>Model</sub> | <sub>Input Size</sub> | <sub>TensorRT (TX2 / Half)</sub> | <sub>TensorRT (TX2 / Float)</sub> | <sub>TensorFlow (TX2 / Float)</sub> | <sub>Input Name</sub> | <sub>Output Name</sub> | <sub>Preprocessing Fn.</sub> |
+|--- |:---:|:---:|:---:|:---:|---|---|---|
+| <sub>inception_v1</sub> | <sub>224x224</sub> | <sub>7.98ms</sub> | <sub>12.8ms</sub> | <sub>27.6ms</sub> | <sub>input</sub> | <sub>InceptionV1/Logits/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>inception_v3</sub> | <sub>299x299</sub> | <sub>26.3ms</sub> | <sub>46.1ms</sub> | <sub>98.4ms</sub> | <sub>input</sub> | <sub>InceptionV3/Logits/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>inception_v4</sub> | <sub>299x299</sub> | <sub>52.1ms</sub> | <sub>88.2ms</sub> | <sub>176ms</sub> | <sub>input</sub> | <sub>InceptionV4/Logits/Logits/BiasAdd</sub> | <sub>inception</sub> |
+| <sub>inception_resnet_v2</sub> | <sub>299x299</sub> | <sub>53.0ms</sub> | <sub>98.7ms</sub> | <sub>168ms</sub> | <sub>input</sub> | <sub>InceptionResnetV2/Logits/Logits/BiasAdd</sub> | <sub>inception</sub> |
+| <sub>resnet_v1_50</sub> | <sub>224x224</sub> | <sub>15.7ms</sub> | <sub>27.1ms</sub> | <sub>63.9ms</sub> | <sub>input</sub> | <sub>resnet_v1_50/SpatialSqueeze</sub> | <sub>vgg</sub> |
+| <sub>resnet_v1_101</sub> | <sub>224x224</sub> | <sub>29.9ms</sub> | <sub>51.8ms</sub> | <sub>107ms</sub> | <sub>input</sub> | <sub>resnet_v1_101/SpatialSqueeze</sub> | <sub>vgg</sub> |
+| <sub>resnet_v1_152</sub> | <sub>224x224</sub> | <sub>42.6ms</sub> | <sub>78.2ms</sub> | <sub>157ms</sub> | <sub>input</sub> | <sub>resnet_v1_152/SpatialSqueeze</sub> | <sub>vgg</sub> |
+| <sub>resnet_v2_50</sub> | <sub>299x299</sub> | <sub>27.5ms</sub> | <sub>44.4ms</sub> | <sub>92.2ms</sub> | <sub>input</sub> | <sub>resnet_v2_50/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>resnet_v2_101</sub> | <sub>299x299</sub> | <sub>49.2ms</sub> | <sub>83.1ms</sub> | <sub>160ms</sub> | <sub>input</sub> | <sub>resnet_v2_101/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>resnet_v2_152</sub> | <sub>299x299</sub> | <sub>74.6ms</sub> | <sub>124ms</sub> | <sub>230ms</sub> | <sub>input</sub> | <sub>resnet_v2_152/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>mobilenet_v1_0p25_128</sub> | <sub>128x128</sub> | <sub>2.67ms</sub> | <sub>2.65ms</sub> | <sub>15.7ms</sub> | <sub>input</sub> | <sub>MobilenetV1/Logits/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>mobilenet_v1_0p5_160</sub> | <sub>160x160</sub> | <sub>3.95ms</sub> | <sub>4.00ms</sub> | <sub>16.9ms</sub> | <sub>input</sub> | <sub>MobilenetV1/Logits/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>mobilenet_v1_1p0_224</sub> | <sub>224x224</sub> | <sub>12.9ms</sub> | <sub>12.9ms</sub> | <sub>24.4ms</sub> | <sub>input</sub> | <sub>MobilenetV1/Logits/SpatialSqueeze</sub> | <sub>inception</sub> |
+| <sub>vgg_16</sub> | <sub>224x224</sub> | <sub>38.2ms</sub> | <sub>79.2ms</sub> | <sub>171ms</sub> | <sub>input</sub> | <sub>vgg_16/fc8/BiasAdd</sub> | <sub>vgg</sub> |
+
+<!--| inception_v2 | 224x224 | 10.3ms | 16.9ms | 38.3ms | input | InceptionV2/Logits/SpatialSqueeze | inception |-->
+<!--| vgg_19 | 224x224 | 97.3ms | OOM | input | vgg_19/fc8/BiasAdd | vgg |-->
+
+
+The recorded times include data transfer to the GPU, network execution, and
+data transfer back from the GPU; they do not include preprocessing.
+See **scripts/test_tf.py**, **scripts/test_trt.py**, and **src/test/test_trt.cu** 
+for implementation details.  To reproduce the timings run
+
+```
+python scripts/test_tf.py
+python scripts/test_trt.py
+```
+
+The timing results will be written to **data/test_output_tf.txt** and **data/test_output_trt.txt**; each
+line pairs a network name with its average inference time in milliseconds.  Note that you must download
+and convert the models (as in the quick start) before running the benchmark scripts.

File diff suppressed because it is too large
+ 1000 - 0
data/imagenet_labels_1000.txt


File diff suppressed because it is too large
+ 1001 - 0
data/imagenet_labels_1001.txt


+ 1 - 0
examples/CMakeLists.txt

@@ -0,0 +1 @@
+add_subdirectory(classify_image)

+ 6 - 0
examples/README.md

@@ -0,0 +1,6 @@
+Examples
+===
+
+This directory contains examples related to image classification with TensorRT. 
+
+[Classify Image](classify_image/README.md)

+ 2 - 0
examples/classify_image/CMakeLists.txt

@@ -0,0 +1,2 @@
+cuda_add_executable(classify_image classify_image.cu)
+target_link_libraries(classify_image ${OpenCV_LIBS} nvinfer nvparsers)

+ 53 - 0
examples/classify_image/README.md

@@ -0,0 +1,53 @@
+# Classify Single Image 
+
+This example demonstrates how to classify a single image using one of the TensorFlow pretrained
+models converted to TensorRT.
+
+
+## Running the Example
+
+**If you haven't already, follow the installation instructions [here](../../INSTALL.md).**
+
+### Convert Frozen Model to PLAN
+
+Assuming you have a trained and frozen image classification model, convert it to a PLAN
+using the convert_plan.py script:
+
+```
+python scripts/convert_plan.py data/frozen_graphs/inception_v1.pb data/plans/inception_v1.plan input 224 224 InceptionV1/Logits/SpatialSqueeze 1 1048576 float
+```
+
+For reference, the inputs to the convert_plan.py script are
+
+1. frozen graph path
+2. output plan path
+3. input node name
+4. input height
+5. input width
+6. output node name
+7. max batch size
+8. max workspace size
+9. data type (float or half)
+
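+As a sketch, a half-precision engine for resnet_v1_50 could be built with the invocation below
+(node names taken from the default networks table; this assumes the frozen graph has already
+been generated):
+
+```
+python scripts/convert_plan.py data/frozen_graphs/resnet_v1_50.pb data/plans/resnet_v1_50.plan input 224 224 resnet_v1_50/SpatialSqueeze 1 1048576 half
+```
+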
+### Run the example program
+
+Once the plan file is generated, run the example C++/CUDA program to classify the image.
+
+```
+./build/examples/classify_image/classify_image data/images/gordon_setter.jpg data/plans/inception_v1.plan data/imagenet_labels_1001.txt input InceptionV1/Logits/SpatialSqueeze inception
+```
+
+You should see that the most probable index is 215, which in our label file corresponds to
+"Gordon setter".  For reference, the inputs to the classify_image example are
+
+1. input image path
+2. plan file path
+3. labels file (one label per line, line number corresponds to index in output)
+4. input node name
+5. output node name
+6. preprocessing function (either vgg or inception. see: [this](../../data/DEFAULT_NETS.md))
+
+To use other networks, supply the classify_image executable with the arguments given in the
+default networks table ([link](../../data/DEFAULT_NETS.md)).  You will need to generate the
+corresponding PLAN file first, as shown above.
+
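+As a sketch, an invocation for resnet_v1_50 might look like the following (assuming its PLAN
+has been generated, and that the 1000-entry label file aligns with this 1000-class network):
+
+```
+./build/examples/classify_image/classify_image data/images/gordon_setter.jpg data/plans/resnet_v1_50.plan data/imagenet_labels_1000.txt input resnet_v1_50/SpatialSqueeze vgg
+```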

+ 141 - 0
examples/classify_image/classify_image.cu

@@ -0,0 +1,141 @@
+/**
+ * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ * Full license terms provided in LICENSE.md file.
+ */
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+#include <NvInfer.h>
+#include <opencv2/opencv.hpp>
+#include "examples/classify_image/utils.h"
+
+
+using namespace std;
+using namespace nvinfer1;
+
+
+class Logger : public ILogger
+{
+  void log(Severity severity, const char * msg) override
+  {
+    if (severity != Severity::kINFO)
+      cout << msg << endl;
+  }
+} gLogger;
+
+
+/**
+ * image_file: path to image
+ * plan_file: path of the serialized engine file
+ * label_file: file with <class_name> per line
+ * input_name: name of the input tensor
+ * output_name: name of the output tensor
+ * preprocessing_fn: 'vgg' or 'inception'
+ */
+int main(int argc, char *argv[])
+{
+  if (argc != 7)
+  {
+    cout << "Usage: classify_image <image_file> <plan_file> <label_file> <input_name> <output_name> <preprocessing_fn>\n";
+    return 0;
+  }
+
+  string imageFilename = argv[1];
+  string planFilename = argv[2];
+  string labelFilename = argv[3];
+  string inputName = argv[4];
+  string outputName = argv[5];
+  string preprocessingFn = argv[6];
+
+  /* load the engine */
+  cout << "Loading TensorRT engine from plan file..." << endl;
+  ifstream planFile(planFilename); 
+  stringstream planBuffer;
+  planBuffer << planFile.rdbuf();
+  string plan = planBuffer.str();
+  IRuntime *runtime = createInferRuntime(gLogger);
+  ICudaEngine *engine = runtime->deserializeCudaEngine((void*)plan.data(), plan.size(), nullptr);
+  IExecutionContext *context = engine->createExecutionContext();
+  
+  /* get the input / output dimensions */
+  int inputBindingIndex, outputBindingIndex;
+  inputBindingIndex = engine->getBindingIndex(inputName.c_str());
+  outputBindingIndex = engine->getBindingIndex(outputName.c_str());
+  Dims inputDims, outputDims;
+  inputDims = engine->getBindingDimensions(inputBindingIndex);
+  outputDims = engine->getBindingDimensions(outputBindingIndex);
+  int inputWidth, inputHeight;
+  inputHeight = inputDims.d[1];
+  inputWidth = inputDims.d[2];
+
+  /* read image, convert color, and resize */
+  cout << "Preprocessing input..." << endl;
+  cv::Mat image = cv::imread(imageFilename, CV_LOAD_IMAGE_COLOR);
+  cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
+  cv::resize(image, image, cv::Size(inputWidth, inputHeight));
+
+  /* convert from uint8+NHWC to float+NCHW */
+  float *inputDataHost, *outputDataHost;
+  size_t numInput, numOutput;
+  numInput = numTensorElements(inputDims);
+  numOutput = numTensorElements(outputDims);
+  inputDataHost = (float*) malloc(numInput * sizeof(float));
+  outputDataHost = (float*) malloc(numOutput * sizeof(float));
+  cvImageToTensor(image, inputDataHost, inputDims);
+  if (preprocessingFn == "vgg")
+    preprocessVgg(inputDataHost, inputDims);
+  else if (preprocessingFn == "inception")
+    preprocessInception(inputDataHost, inputDims);
+  else
+    cout << "Unsupported preprocessing function.  Results may not be correct.\n" << endl;
+
+  /* transfer to device */
+  float *inputDataDevice, *outputDataDevice;
+  cudaMalloc(&inputDataDevice, numInput * sizeof(float));
+  cudaMalloc(&outputDataDevice, numOutput * sizeof(float));
+  cudaMemcpy(inputDataDevice, inputDataHost, numInput * sizeof(float), cudaMemcpyHostToDevice);
+  void *bindings[2];
+  bindings[inputBindingIndex] = (void*) inputDataDevice;
+  bindings[outputBindingIndex] = (void*) outputDataDevice;
+
+  /* execute engine */
+  cout << "Executing inference engine..." << endl;
+  const int kBatchSize = 1;
+  context->execute(kBatchSize, bindings);
+
+  /* transfer output back to host */
+  cudaMemcpy(outputDataHost, outputDataDevice, numOutput * sizeof(float), cudaMemcpyDeviceToHost);
+
+  /* parse output */
+  vector<size_t> sortedIndices = argsort(outputDataHost, outputDims);
+
+  cout << "\nThe top-5 indices are: ";
+  for (int i = 0; i < 5; i++)
+    cout << sortedIndices[i] << " ";
+
+  ifstream labelsFile(labelFilename);
+  vector<string> labelMap;
+  string label;
+  while(getline(labelsFile, label))
+  {
+    labelMap.push_back(label);
+  }
+
+  cout << "\nWhich corresponds to class labels: ";
+  for (int i = 0; i < 5; i++)
+    cout << endl << i << ". " << labelMap[sortedIndices[i]];
+  cout << endl;
+
+  /* clean up (destroy in reverse order of creation) */
+  context->destroy();
+  engine->destroy();
+  runtime->destroy();
+  free(inputDataHost);
+  free(outputDataHost);
+  cudaFree(inputDataDevice);
+  cudaFree(outputDataDevice);
+
+  return 0;
+}

+ 119 - 0
examples/classify_image/utils.h

@@ -0,0 +1,119 @@
+/**
+ * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ * Full license terms provided in LICENSE.md file.
+ */
+
+#ifndef TRT_IMAGE_CLASSIFICATION_UTILS_H
+#define TRT_IMAGE_CLASSIFICATION_UTILS_H
+
+
+#include <opencv2/opencv.hpp>
+#include <vector>
+#include <algorithm>
+#include <NvInfer.h>
+
+
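+// Copy an 8-bit HWC OpenCV image into a planar CHW float tensor.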
+void cvImageToTensor(const cv::Mat & image, float *tensor, nvinfer1::Dims dimensions)
+{
+  const size_t channels = dimensions.d[0];
+  const size_t height = dimensions.d[1];
+  const size_t width = dimensions.d[2];
+  // TODO: validate dimensions match
+  const size_t stridesCv[3] = { width * channels, channels, 1 };
+  const size_t strides[3] = { height * width, width, 1 };
+
+  for (int i = 0; i < height; i++) 
+  {
+    for (int j = 0; j < width; j++) 
+    {
+      for (int k = 0; k < channels; k++) 
+      {
+        const size_t offsetCv = i * stridesCv[0] + j * stridesCv[1] + k * stridesCv[2];
+        const size_t offset = k * strides[0] + i * strides[1] + j * strides[2];
+        tensor[offset] = (float) image.data[offsetCv];
+      }
+    }
+  }
+}
+
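+// VGG-style preprocessing: subtract the per-channel ImageNet mean in place.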
+void preprocessVgg(float *tensor, nvinfer1::Dims dimensions)
+{
+  size_t channels = dimensions.d[0];
+  size_t height = dimensions.d[1];
+  size_t width = dimensions.d[2];
+  const size_t strides[3] = { height * width, width, 1 };
+  const float mean[3] = { 123.68, 116.78, 103.94 }; // values from TensorFlow slim models code
+
+  for (int i = 0; i < height; i++) 
+  {
+    for (int j = 0; j < width; j++) 
+    {
+      for (int k = 0; k < channels; k++) 
+      {
+        const size_t offset = k * strides[0] + i * strides[1] + j * strides[2];
+        tensor[offset] -= mean[k];
+      }
+    }
+  }
+}
+
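+// Inception-style preprocessing: scale pixel values to the range [-1, 1] in place.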
+void preprocessInception(float *tensor, nvinfer1::Dims dimensions)
+{
+  size_t channels = dimensions.d[0];
+  size_t height = dimensions.d[1];
+  size_t width = dimensions.d[2];
+
+  const size_t numel = channels * height * width;
+  for (int i = 0; i < numel; i++)
+    tensor[i] = 2.0 * (tensor[i] / 255.0 - 0.5); // values from TensorFlow slim models code
+}
+
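+// Return the index of the largest element in the tensor.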
+int argmax(float *tensor, nvinfer1::Dims dimensions)
+{
+  size_t channels = dimensions.d[0];
+  size_t height = dimensions.d[1];
+  size_t width = dimensions.d[2];
+
+  size_t numel = channels * height * width;
+
+  if (numel <= 0)
+    return 0;
+
+  size_t maxIndex = 0;
+  float max = tensor[0]; 
+  for (int i = 0; i < numel; i++)
+  {
+    if (tensor[i] > max)
+    {
+      maxIndex = i;
+      max = tensor[i];
+    }
+  }
+
+  return maxIndex;
+}
+
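+// Total number of elements implied by the tensor dimensions.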
+size_t numTensorElements(nvinfer1::Dims dimensions)
+{
+  if (dimensions.nbDims == 0)
+    return 0;
+  size_t size = 1;
+  for (int i = 0; i < dimensions.nbDims; i++)
+    size *= dimensions.d[i];
+  return size;
+}
+
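+// Return tensor indices sorted by descending value.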
+std::vector<size_t> argsort(float *tensor, nvinfer1::Dims dimensions)
+{
+  size_t numel = numTensorElements(dimensions);
+  std::vector<size_t> indices(numel);
+  for (int i = 0; i < numel; i++)
+    indices[i] = i;
+  std::sort(indices.begin(), indices.begin() + numel, [tensor](size_t idx1, size_t idx2) {
+      return tensor[idx1] > tensor[idx2];
+  });
+
+  return indices;
+}
+
+#endif

+ 69 - 0
scripts/convert_plan.py

@@ -0,0 +1,69 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import os
+import subprocess
+import uff
+import pdb
+import sys
+
+UFF_TO_PLAN_EXE_PATH = 'build/src/uff_to_plan'
+TMP_UFF_FILENAME = 'data/tmp.uff'
+
+
+def frozenToPlan(frozen_graph_filename, plan_filename, input_name, input_height, 
+        input_width, output_name, max_batch_size, max_workspace_size, data_type):
+
+    # generate uff from frozen graph
+    uff_model = uff.from_tensorflow_frozen_model(
+        frozen_file=frozen_graph_filename,
+        output_nodes=[output_name],
+        output_filename=TMP_UFF_FILENAME,
+        text=False
+    )
+
+    # convert frozen graph to engine (plan)
+    args = [
+        TMP_UFF_FILENAME,
+        plan_filename,
+        input_name,
+        str(input_height),
+        str(input_width),
+        output_name,
+        str(max_batch_size),
+        str(max_workspace_size), 
+        data_type # float / half
+    ]
+    subprocess.call([UFF_TO_PLAN_EXE_PATH] + args)
+
+    # cleanup tmp file
+    os.remove(TMP_UFF_FILENAME)
+
+
+if __name__ == '__main__':
+
+    if len(sys.argv) != 10:
+        print("usage: python convert_plan.py <frozen_graph_path> <output_plan_path> <input_name> <input_height>"
+              " <input_width> <output_name> <max_batch_size> <max_workspace_size> <data_type>")
+        exit()
+
+    frozen_graph_filename = sys.argv[1]
+    plan_filename = sys.argv[2]
+    input_name = sys.argv[3]
+    input_height = sys.argv[4]
+    input_width = sys.argv[5]
+    output_name = sys.argv[6]
+    max_batch_size = sys.argv[7]
+    max_workspace_size = sys.argv[8]
+    data_type = sys.argv[9]
+    
+    frozenToPlan(frozen_graph_filename,
+        plan_filename,
+        input_name,
+        input_height,
+        input_width,
+        output_name,
+        max_batch_size,
+        max_workspace_size,
+        data_type
+    )

+ 83 - 0
scripts/convert_relu6.py

@@ -0,0 +1,83 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import tensorflow as tf
+import sys
+
+
+def makeConst6(const6_name='const6'):
+    graph = tf.Graph()
+    with graph.as_default():
+        tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
+    return graph.as_graph_def()
+
+
+# create base relu6 nodes
+def makeRelu6(output_name, input_name, const6_name='const6'):
+    graph = tf.Graph()
+    with graph.as_default():
+        tf_x = tf.placeholder(tf.float32, [10, 10], name=input_name)
+        tf_6 = tf.constant(dtype=tf.float32, value=6.0, name=const6_name)
+        with tf.name_scope(output_name):
+            tf_y1 = tf.nn.relu(tf_x, name='relu1')
+            tf_y2 = tf.nn.relu(tf.subtract(tf_x, tf_6, name='sub1'), name='relu2')
+
+            #tf_y = tf.nn.relu(tf.subtract(tf_6, tf.nn.relu(tf_x, name='relu1'), name='sub'), name='relu2')
+        #tf_y = tf.subtract(tf_6, tf_y, name=output_name)
+        tf_y = tf.subtract(tf_y1, tf_y2, name=output_name)
+        
+    graph_def = graph.as_graph_def()
+    graph_def.node[-1].name = output_name
+
+    # remove unused nodes
+    for node in graph_def.node:
+        if node.name == input_name:
+            graph_def.node.remove(node)
+    for node in graph_def.node:
+        if node.name == const6_name:
+            graph_def.node.remove(node)
+    for node in graph_def.node:
+        if node.op == '_Neg':
+            node.op = 'Neg'
+            
+    return graph_def
+
+
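+# Rewrite each Relu6 node using the identity relu6(x) = relu(x) - relu(x - 6),
+# expressing it with ops the UFF converter can handle.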
+def convertRelu6(graph_def, const6_name='const6'):
+    # add constant 6
+    has_const6 = False
+    for node in graph_def.node:
+        if node.name == const6_name:
+            has_const6 = True
+    if not has_const6:
+        const6_graph_def = makeConst6(const6_name=const6_name)
+        graph_def.node.extend(const6_graph_def.node)
+        
+    for node in graph_def.node:
+        if node.op == 'Relu6':
+            input_name = node.input[0]
+            output_name = node.name
+            relu6_graph_def = makeRelu6(output_name, input_name, const6_name=const6_name)
+            graph_def.node.remove(node)
+            graph_def.node.extend(relu6_graph_def.node)
+            
+    return graph_def
+
+
+if __name__ == '__main__':
+
+    if len(sys.argv) != 3:
+        print("Replaces Relu6 nodes in a frozen graph with Relu(x) - Relu(x-6)\n")
+        print("Usage: python convert_relu6.py <frozen_graph_path> <output_frozen_graph_path>")
+        exit()
+
+    with open(sys.argv[1], 'rb') as f:
+        graph_def = tf.GraphDef()
+        graph_def.ParseFromString(f.read())
+
+    converted_graph_def = convertRelu6(graph_def)
+
+    with open(sys.argv[2], 'wb') as f:
+        f.write(converted_graph_def.SerializeToString())

+ 25 - 0
scripts/download_images.sh

@@ -0,0 +1,25 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+image_urls=(
+  http://farm3.static.flickr.com/2017/2496831224_221cd963a2.jpg
+  http://farm3.static.flickr.com/2582/4106642219_190bf0f817.jpg
+  http://farm4.static.flickr.com/3226/2719028129_9aa2e27675.jpg
+)
+
+image_names=(
+  gordon_setter.jpg
+  lifeboat.jpg
+  golden_retriever.jpg
+)
+
+image_folder=data/images
+
+mkdir -p $image_folder
+
+for i in ${!image_urls[@]}
+do
+  echo ${image_urls[$i]}
+  echo ${image_names[$i]}
+  wget -O $image_folder/${image_names[$i]} ${image_urls[$i]}
+done

+ 36 - 0
scripts/download_models.sh

@@ -0,0 +1,36 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+model_urls=(
+  http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 
+  http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz 
+  http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 
+  http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz 
+  http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 
+  http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz 
+  http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz 
+  http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
+  http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz
+  http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz
+  http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz
+  http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz
+  http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz
+  http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz
+  http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz
+  http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz
+)
+
+mkdir -p data/checkpoints
+cd data/checkpoints
+
+for url in ${model_urls[@]}
+do
+  tarname=$(basename $url)
+  echo $tarname
+  wget $url
+  tar -xzf $(basename $url)
+  rm $tarname
+done
+
+cd ../..

+ 36 - 0
scripts/frozen_graphs_to_plans.py

@@ -0,0 +1,36 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import sys
+sys.path.append('third_party/models/')
+sys.path.append('third_party/models/research')
+sys.path.append('third_party/models/research/slim')
+import uff
+from model_meta import NETS, FROZEN_GRAPHS_DIR, CHECKPOINT_DIR, PLAN_DIR
+from convert_plan import frozenToPlan
+import os
+
+
+if not os.path.exists('data/plans'):
+    os.makedirs('data/plans')
+
+
+if __name__ == '__main__':
+
+    for net_name, net_meta in NETS.items():
+        
+        if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
+            continue
+
+        print("Convertings %s to PLAN" % net_name)
+         
+        frozenToPlan(net_meta['frozen_graph_filename'],
+            net_meta['plan_filename'],
+            net_meta['input_name'],
+            net_meta['input_height'],
+            net_meta['input_width'],
+            net_meta['output_names'][0],
+            1, # batch size
+            1 << 20, # workspace size
+            'half' # data type
+        )

+ 26 - 0
scripts/frozen_graphs_to_uffs.py

@@ -0,0 +1,26 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import sys
+sys.path.append('third_party/models/')
+sys.path.append('third_party/models/research')
+sys.path.append('third_party/models/research/slim')
+import uff
+from model_meta import NETS, FROZEN_GRAPHS_DIR, CHECKPOINT_DIR, UFF_DIR
+
+
+if __name__ == '__main__':
+
+    for net_name, net_meta in NETS.items():
+        
+        if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
+            continue
+
+        print("Convertings %s to UFF" % net_name)
+        
+        uff_model = uff.from_tensorflow_frozen_model(
+            frozen_file=net_meta['frozen_graph_filename'],
+            output_nodes=net_meta['output_names'],
+            output_filename=net_meta['uff_filename'],
+            text=False
+        )

+ 7 - 0
scripts/make_swap.sh

@@ -0,0 +1,7 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
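+# Create and enable a 1 GB swap file at /mnt/swapfile.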
+sudo fallocate -l 1G /mnt/swapfile
+sudo chmod 600 /mnt/swapfile
+sudo mkswap /mnt/swapfile
+sudo swapon /mnt/swapfile

+ 341 - 0
scripts/model_meta.py

@@ -0,0 +1,341 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import numpy as np
+import sys
+sys.path.append("third_party/models/research/")
+sys.path.append("third_party/models")
+sys.path.append("third_party/")
+sys.path.append("third_party/models/research/slim/")
+import tensorflow.contrib.slim as tf_slim
+import slim.nets as nets
+import slim.nets.vgg
+import slim.nets.inception
+import slim.nets.resnet_v1
+import slim.nets.resnet_v2
+import slim.nets.mobilenet_v1
+
+
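+# Build a map from class index to ImageNet label string.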
+def create_label_map(label_file='data/imagenet_labels_1001.txt'):
+    label_map = {}
+    with open(label_file, 'r') as f:
+        labels = f.readlines()
+        for i, label in enumerate(labels):
+            label_map[i] = label.strip()
+    return label_map
+        
+
+IMAGENET2012_LABEL_MAP = create_label_map()
+
+
+def preprocess_vgg(image):
+    return np.array(image, dtype=np.float32) - np.array([123.68, 116.78, 103.94])
+
+def postprocess_vgg(output):
+    output = output.flatten()
+    predictions_top5 = np.argsort(output)[::-1][0:5]
+    labels_top5 = [IMAGENET2012_LABEL_MAP[p + 1] for p in predictions_top5]
+    return labels_top5
+
+def preprocess_inception(image):
+    return 2.0 * (np.array(image, dtype=np.float32) / 255.0 - 0.5)
+
+def postprocess_inception(output):
+    output = output.flatten()
+    predictions_top5 = np.argsort(output)[::-1][0:5]
+    labels_top5 = [IMAGENET2012_LABEL_MAP[p] for p in predictions_top5]
+    return labels_top5
+
+def mobilenet_v1_1p0_224(*args, **kwargs):
+    kwargs['depth_multiplier'] = 1.0
+    return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)
+
+def mobilenet_v1_0p5_160(*args, **kwargs):
+    kwargs['depth_multiplier'] = 0.5
+    return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)
+
+def mobilenet_v1_0p25_128(*args, **kwargs):
+    kwargs['depth_multiplier'] = 0.25
+    return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)   
+
+
+CHECKPOINT_DIR = 'data/checkpoints/'
+FROZEN_GRAPHS_DIR = 'data/frozen_graphs/'
+# UFF_DIR = 'data/uff/'
+PLAN_DIR = 'data/plans/'
+
+
+NETS = {
+
+    'vgg_16': {
+        'model': nets.vgg.vgg_16,
+        'arg_scope': nets.vgg.vgg_arg_scope,
+        'num_classes': 1000,
+        'input_name': 'input',
+        'output_names': ['vgg_16/fc8/BiasAdd'],
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3, 
+        'preprocess_fn': preprocess_vgg,
+        'postprocess_fn': postprocess_vgg,
+        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_16.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_16.pb',
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'vgg_16.plan'
+    },
+
+    'vgg_19': {
+        'model': nets.vgg.vgg_19,
+        'arg_scope': nets.vgg.vgg_arg_scope,
+        'num_classes': 1000,
+        'input_name': 'input',
+        'output_names': ['vgg_19/fc8/BiasAdd'],
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3, 
+        'preprocess_fn': preprocess_vgg,
+        'postprocess_fn': postprocess_vgg,
+        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_19.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_19.pb',
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'vgg_19.plan',
+        'exclude': True
+    },
+
+    'inception_v1': {
+        'model': nets.inception.inception_v1,
+        'arg_scope': nets.inception.inception_v1_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['InceptionV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v1.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v1.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'inception_v1.plan'
+    },
+
+    'inception_v2': {
+        'model': nets.inception.inception_v2,
+        'arg_scope': nets.inception.inception_v2_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['InceptionV2/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v2.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v2.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'trt_convert_status': "bad results",
+        'plan_filename': PLAN_DIR + 'inception_v2.plan'
+    },
+
+    'inception_v3': {
+        'model': nets.inception.inception_v3,
+        'arg_scope': nets.inception.inception_v3_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionV3/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v3.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v3.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'inception_v3.plan'
+    },
+
+    'inception_v4': {
+        'model': nets.inception.inception_v4,
+        'arg_scope': nets.inception.inception_v4_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionV4/Logits/Logits/BiasAdd'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v4.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v4.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'inception_v4.plan'
+    },
+    
+    'inception_resnet_v2': {
+        'model': nets.inception.inception_resnet_v2,
+        'arg_scope': nets.inception.inception_resnet_v2_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionResnetV2/Logits/Logits/BiasAdd'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_resnet_v2_2016_08_30.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_resnet_v2.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'trt_convert_status': "works",
+        'plan_filename': PLAN_DIR + 'inception_resnet_v2.plan'
+    },
+
+    'resnet_v1_50': {
+        'model': nets.resnet_v1.resnet_v1_50,
+        'arg_scope': nets.resnet_v1.resnet_arg_scope,
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_50/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_50.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_50.pb',
+        'preprocess_fn': preprocess_vgg,
+        'postprocess_fn': postprocess_vgg,
+        'plan_filename': PLAN_DIR + 'resnet_v1_50.plan'
+    },
+
+    'resnet_v1_101': {
+        'model': nets.resnet_v1.resnet_v1_101,
+        'arg_scope': nets.resnet_v1.resnet_arg_scope,
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_101/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_101.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_101.pb',
+        'preprocess_fn': preprocess_vgg,
+        'postprocess_fn': postprocess_vgg,
+        'plan_filename': PLAN_DIR + 'resnet_v1_101.plan'
+    },
+
+    'resnet_v1_152': {
+        'model': nets.resnet_v1.resnet_v1_152,
+        'arg_scope': nets.resnet_v1.resnet_arg_scope,
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_152/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_152.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_152.pb',
+        'preprocess_fn': preprocess_vgg,
+        'postprocess_fn': postprocess_vgg,
+        'plan_filename': PLAN_DIR + 'resnet_v1_152.plan'
+    },
+
+    'resnet_v2_50': {
+        'model': nets.resnet_v2.resnet_v2_50,
+        'arg_scope': nets.resnet_v2.resnet_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_50/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_50.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_50.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'plan_filename': PLAN_DIR + 'resnet_v2_50.plan'
+    },
+
+    'resnet_v2_101': {
+        'model': nets.resnet_v2.resnet_v2_101,
+        'arg_scope': nets.resnet_v2.resnet_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_101/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_101.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_101.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'plan_filename': PLAN_DIR + 'resnet_v2_101.plan'
+    },
+
+    'resnet_v2_152': {
+        'model': nets.resnet_v2.resnet_v2_152,
+        'arg_scope': nets.resnet_v2.resnet_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_152/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_152.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_152.pb',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+        'plan_filename': PLAN_DIR + 'resnet_v2_152.plan'
+    },
+
+    #'resnet_v2_200': {
+
+    #},
+
+    'mobilenet_v1_1p0_224': {
+        'model': mobilenet_v1_1p0_224,
+        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_1.0_224.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_1p0_224.pb',
+        'plan_filename': PLAN_DIR + 'mobilenet_v1_1p0_224.plan',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+    },
+
+    'mobilenet_v1_0p5_160': {
+        'model': mobilenet_v1_0p5_160,
+        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 160,
+        'input_height': 160,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_0.50_160.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p5_160.pb',
+        'plan_filename': PLAN_DIR + 'mobilenet_v1_0p5_160.plan',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+    },
+
+    'mobilenet_v1_0p25_128': {
+        'model': mobilenet_v1_0p25_128,
+        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 128,
+        'input_height': 128,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_0.25_128.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p25_128.pb',
+        'plan_filename': PLAN_DIR + 'mobilenet_v1_0p25_128.plan',
+        'preprocess_fn': preprocess_inception,
+        'postprocess_fn': postprocess_inception,
+    },
+}
+
+

+ 233 - 0
scripts/model_meta_nodeps.py

@@ -0,0 +1,233 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+CHECKPOINT_DIR = 'data/checkpoints/'
+FROZEN_GRAPHS_DIR = 'data/frozen_graphs/'
+UFF_DIR = 'data/uff/'
+
+NETS = {
+
+    'vgg_16': {
+        'num_classes': 1000,
+        'input_name': 'input',
+        'output_names': ['vgg_16/fc8/BiasAdd'],
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3, 
+        'preprocess_fn': 'preprocess_vgg',
+        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_16.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_16.pb',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'vgg_16.uff'
+    },
+
+    'vgg_19': {
+        'num_classes': 1000,
+        'input_name': 'input',
+        'output_names': ['vgg_19/fc8/BiasAdd'],
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3, 
+        'preprocess_fn': 'preprocess_vgg',
+        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_19.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_19.pb',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'vgg_19.uff'
+    },
+
+    'inception_v1': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['InceptionV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v1.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v1.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'inception_v1.uff'
+    },
+
+    'inception_v2': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['InceptionV2/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v2.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v2.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'trt_convert_status': "bad results",
+        'uff_filename': UFF_DIR + 'inception_v2.uff'
+    },
+
+    'inception_v3': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionV3/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v3.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v3.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'inception_v3.uff'
+    },
+
+    'inception_v4': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionV4/Logits/Logits/BiasAdd'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v4.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v4.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'inception_v4.uff'
+    },
+    
+    'inception_resnet_v2': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['InceptionResnetV2/Logits/Logits/BiasAdd'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'inception_resnet_v2_2016_08_30.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_resnet_v2.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'trt_convert_status': "works",
+        'uff_filename': UFF_DIR + 'inception_resnet_v2.uff'
+    },
+
+    'resnet_v1_50': {
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_50/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_50.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_50.pb',
+        'preprocess_fn': 'preprocess_vgg',
+        'uff_filename': UFF_DIR + 'resnet_v1_50.uff'
+    },
+
+    'resnet_v1_101': {
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_101/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_101.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_101.pb',
+        'preprocess_fn': 'preprocess_vgg',
+        'uff_filename': UFF_DIR + 'resnet_v1_101.uff'
+    },
+
+    'resnet_v1_152': {
+        'num_classes': 1000,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['resnet_v1_152/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_152.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_152.pb',
+        'preprocess_fn': 'preprocess_vgg',
+        'uff_filename': UFF_DIR + 'resnet_v1_152.uff'
+    },
+
+    'resnet_v2_50': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_50/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_50.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_50.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'uff_filename': UFF_DIR + 'resnet_v2_50.uff'
+    },
+
+    'resnet_v2_101': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_101/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_101.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_101.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'uff_filename': UFF_DIR + 'resnet_v2_101.uff'
+    },
+
+    'resnet_v2_152': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 299,
+        'input_height': 299,
+        'input_channels': 3,
+        'output_names': ['resnet_v2_152/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_152.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_152.pb',
+        'preprocess_fn': 'preprocess_inception',
+        'uff_filename': UFF_DIR + 'resnet_v2_152.uff'
+    },
+
+    #'resnet_v2_200': {
+
+    #},
+
+    'mobilenet_v1_1p0_224': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 224,
+        'input_height': 224,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_1.0_224.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_1p0_224.pb',
+        'uff_filename': UFF_DIR + 'mobilenet_v1_1p0_224.uff',
+        'preprocess_fn': 'preprocess_inception',
+    },
+
+    'mobilenet_v1_0p5_160': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 160,
+        'input_height': 160,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_0.50_160.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p5_160.pb',
+        'uff_filename': UFF_DIR + 'mobilenet_v1_0p5_160.uff',
+        'preprocess_fn': 'preprocess_inception',
+    },
+
+    'mobilenet_v1_0p25_128': {
+        'num_classes': 1001,
+        'input_name': 'input',
+        'input_width': 128,
+        'input_height': 128,
+        'input_channels': 3,
+        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
+        'checkpoint_filename': CHECKPOINT_DIR + 
+            'mobilenet_v1_0.25_128.ckpt',
+        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p25_128.pb',
+        'uff_filename': UFF_DIR + 'mobilenet_v1_0p25_128.uff',
+        'preprocess_fn': 'preprocess_inception',
+    },
+}
+
+

+ 82 - 0
scripts/models_to_frozen_graphs.py

@@ -0,0 +1,82 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import tensorflow as tf
+import sys
+sys.path.append("third_party/models/research/")
+sys.path.append("third_party/models")
+sys.path.append("third_party/")
+sys.path.append("third_party/models/research/slim/")
+sys.path.append("scripts")
+import tensorflow.contrib.slim as tf_slim
+import slim.nets as nets
+import slim.nets.vgg
+from model_meta import NETS, CHECKPOINT_DIR, FROZEN_GRAPHS_DIR
+from convert_relu6 import convertRelu6
+import os
+
+
+
+if __name__ == '__main__':
+
+    if not os.path.exists(CHECKPOINT_DIR):
+        print("%s does not exist.  Exiting." % CHECKPOINT_DIR)
+        exit()
+
+    if not os.path.exists(FROZEN_GRAPHS_DIR):
+        print("%s does not exist.  Creating it now." % FROZEN_GRAPHS_DIR)
+        os.makedirs(FROZEN_GRAPHS_DIR)
+
+    for net_name, net_meta in NETS.items():
+
+        if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
+            continue
+
+        print("Converting %s" % net_name)
+        print(net_meta)
+
+        tf.reset_default_graph()
+        tf_config = tf.ConfigProto()
+        tf_config.gpu_options.allow_growth = True
+        tf_sess = tf.Session(config=tf_config)
+        tf_input = tf.placeholder(
+            tf.float32, 
+            (
+                None, 
+                net_meta['input_height'], 
+                net_meta['input_width'], 
+                net_meta['input_channels']
+            ),
+            name=net_meta['input_name']
+        )
+
+        with tf_slim.arg_scope(net_meta['arg_scope']()):
+            tf_net, tf_end_points = net_meta['model'](
+                tf_input, 
+                is_training=False,
+                num_classes=net_meta['num_classes']
+            )
+
+        tf_saver = tf.train.Saver()
+        tf_saver.restore(
+            save_path=net_meta['checkpoint_filename'], 
+            sess=tf_sess
+        )
+        frozen_graph = tf.graph_util.convert_variables_to_constants(
+            tf_sess,
+            tf_sess.graph_def,
+            output_node_names=net_meta['output_names']
+        )
+
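+        # rewrite Relu6 nodes into UFF-compatible ops (see scripts/convert_relu6.py)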
+        frozen_graph = convertRelu6(frozen_graph)
+
+        with open(net_meta['frozen_graph_filename'], 'wb') as f:
+            f.write(frozen_graph.SerializeToString())
+
+        del tf_config
+        del tf_sess
+        del tf_input
+        del tf_saver
+        del frozen_graph

+ 43 - 0
scripts/print_slim_output.py

@@ -0,0 +1,43 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import tensorflow as tf
+import sys
+sys.path.append("third_party/models/research/")
+sys.path.append("third_party/models")
+sys.path.append("third_party/")
+sys.path.append("third_party/models/research/slim/")
+sys.path.append("scripts")
+import tensorflow.contrib.slim as tf_slim
+import slim.nets as nets
+import slim.nets.vgg
+from model_meta import NETS
+
+
+if __name__ == '__main__':
+
+    with open("data/output_names.txt", 'w') as f:
+        for net_name, net_meta in NETS.items():
+
+            tf.reset_default_graph()
+            tf_sess = tf.Session()
+            tf_input = tf.placeholder(
+                tf.float32, 
+                (
+                    None, 
+                    net_meta['input_height'], 
+                    net_meta['input_width'], 
+                    net_meta['input_channels']
+                ),
+                name=net_meta['input_name']
+            )
+
+            with tf_slim.arg_scope(net_meta['arg_scope']()):
+                tf_net, tf_end_points = net_meta['model'](
+                    tf_input, 
+                    is_training=False,
+                    num_classes=net_meta['num_classes']
+                )
+                print("Output name for %s is %s" % (net_name, tf_net.name))
+            f.write("%s\t%s\n" % (net_name, tf_net.name))
+        f.close()

+ 63 - 0
scripts/test_tf.py

@@ -0,0 +1,63 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import sys
+sys.path.append('third_party/models/')
+sys.path.append('third_party/models/research')
+sys.path.append('third_party/models/research/slim')
+#from PIL import Image
+#import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+from model_meta import NETS, FROZEN_GRAPHS_DIR, CHECKPOINT_DIR
+import time
+import cv2
+
+
+TEST_IMAGE_PATH='data/images/gordon_setter.jpg'
+TEST_OUTPUT_PATH='data/test_output_tf.txt'
+NUM_RUNS=50
+
+if __name__ == '__main__':
+
+    with open(TEST_OUTPUT_PATH, 'w') as test_f:
+        for net_name, net_meta in NETS.items():
+
+            if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
+                continue
+
+            print("Testing %s" % net_name)
+
+            with open(net_meta['frozen_graph_filename'], 'rb') as f:
+                graph_def = tf.GraphDef()
+                graph_def.ParseFromString(f.read())
+
+            with tf.Graph().as_default() as graph:
+                tf.import_graph_def(graph_def, name="")
+            
+            tf_config = tf.ConfigProto()
+            tf_config.gpu_options.allow_growth = True
+            tf_config.allow_soft_placement = True
+
+            tf_sess = tf.Session(config=tf_config, graph=graph)
+            tf_input = tf_sess.graph.get_tensor_by_name(net_meta['input_name'] + ':0')
+            tf_output = tf_sess.graph.get_tensor_by_name(net_meta['output_names'][0] + ':0')
+
+            # load and preprocess image
+            image = cv2.imread(TEST_IMAGE_PATH)
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            image = cv2.resize(image, (net_meta['input_width'], net_meta['input_height']))
+            image = net_meta['preprocess_fn'](image)
+
+
+            # run network
+            times = []
+            for i in range(NUM_RUNS + 1):
+                t0 = time.time()
+                output = tf_sess.run([tf_output], feed_dict={
+                    tf_input: image[None, ...]
+                })[0]
+                t1 = time.time()
+                times.append(1000 * (t1 - t0))
+            avg_time = np.mean(times[1:]) # don't include first run
+
+            # parse output
+            top5 = net_meta['postprocess_fn'](output)
+            print(top5)
+            test_f.write("%s %s\n" % (net_name, avg_time))

+ 43 - 0
scripts/test_trt.py

@@ -0,0 +1,43 @@
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Full license terms provided in LICENSE.md file.
+
+import sys
+#sys.path.append("../scripts")
+#sys.path.append(".")
+from model_meta import NETS
+import os
+import subprocess
+import pdb
+
+TEST_IMAGE_PATH='data/images/gordon_setter.jpg'
+TEST_OUTPUT_PATH='data/test_output_trt.txt'
+TEST_EXE_PATH='./build/src/test/test_trt'
+
+if __name__ == '__main__':
+    
+    # delete output file 
+    if os.path.isfile(TEST_OUTPUT_PATH):
+       os.remove(TEST_OUTPUT_PATH)
+
+    for net_name, net_meta in NETS.items():
+        if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
+            continue
+
+        args = [
+            TEST_IMAGE_PATH,
+            net_meta['plan_filename'],
+            net_meta['input_name'],
+            str(net_meta['input_height']),
+            str(net_meta['input_width']),
+            net_meta['output_names'][0],
+            str(net_meta['num_classes']), 
+            net_meta['preprocess_fn'].__name__,
+            str(50), # numRuns
+            "half", # dataType 
+            str(1), # maxBatchSize 
+            str(1 << 20), # workspaceSize 
+            str(0), # useMappedMemory 
+            TEST_OUTPUT_PATH
+        ]
+        print("Running %s" % net_name)
+        subprocess.call([TEST_EXE_PATH] + args)

+ 4 - 0
src/CMakeLists.txt

@@ -0,0 +1,4 @@
+add_subdirectory(test)
+
+add_executable(uff_to_plan uff_to_plan.cpp)
+target_link_libraries(uff_to_plan nvinfer nvparsers)

+ 2 - 0
src/test/CMakeLists.txt

@@ -0,0 +1,2 @@
+cuda_add_executable(test_trt test_trt.cu)
+target_link_libraries(test_trt ${OpenCV_LIBS} nvinfer nvparsers)

+ 353 - 0
src/test/test_trt.cu

@@ -0,0 +1,353 @@
+/**
+ * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ * Full license terms provided in LICENSE.md file.
+ */
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <sstream>
+#include <chrono>
+#include <stdexcept>
+#include <fstream>
+
+#include <opencv2/opencv.hpp>
+#include <NvInfer.h>
+
+
+#define MS_PER_SEC 1000.0
+
+
+using namespace std;
+using namespace nvinfer1;
+
+class TestConfig;
+
+typedef void (*preprocess_fn_t)(float *input, size_t channels, size_t height, size_t width);
+float * imageToTensor(const cv::Mat & image);
+void preprocessVgg(float *input, size_t channels, size_t height, size_t width);
+void preprocessInception(float *input, size_t channels, size_t height, size_t width);
+size_t argmax(float *input, size_t numel);
+void test(const TestConfig &testConfig);
+
+class TestConfig
+{
+public:
+  string imagePath;
+  string planPath;
+  string inputNodeName;
+  string outputNodeName;
+  string preprocessFnName;
+  string inputHeight;
+  string inputWidth;
+  string numOutputCategories;
+  string dataType;
+  string maxBatchSize;
+  string workspaceSize;
+  string numRuns;
+  string useMappedMemory;
+  string statsPath;
+  
+  TestConfig(int argc, char * argv[])
+  {
+    imagePath = argv[1];
+    planPath = argv[2];
+    inputNodeName = argv[3];
+    inputHeight = argv[4];
+    inputWidth = argv[5];
+    outputNodeName = argv[6];
+    numOutputCategories = argv[7];
+    preprocessFnName = argv[8];
+    numRuns = argv[9];
+    dataType = argv[10];
+    maxBatchSize = argv[11];
+    workspaceSize = argv[12];
+    useMappedMemory = argv[13];
+    statsPath = argv[14];
+  }
+
+  static string UsageString()
+  {
+    return string("Usage: test_trt <imagePath> <planPath> <inputNodeName> ")
+        + "<inputHeight> <inputWidth> <outputNodeName> <numOutputCategories> "
+        + "<preprocessFnName> <numRuns> <dataType> <maxBatchSize> "
+        + "<workspaceSize> <useMappedMemory> <statsPath>";
+  }
+
+  string ToString()
+  {
+    string s = "";
+    s += "imagePath: " + imagePath + "\n";
+    s += "planPath: " + planPath + "\n";
+    s += "inputNodeName: " + inputNodeName + "\n";
+    s += "inputHeight: " + inputHeight + "\n";
+    s += "inputWidth: " + inputWidth + "\n";
+    s += "outputNodeName: " + outputNodeName + "\n";
+    s += "numOutputCategories: " + numOutputCategories + "\n";
+    s += "preprocessFnName: " + preprocessFnName + "\n";
+    s += "numRuns: " + numRuns + "\n";
+    s += "dataType: " + dataType + "\n";
+    s += "maxBatchSize: " + maxBatchSize + "\n";
+    s += "workspaceSize: " + workspaceSize + "\n";
+    s += "useMappedMemory: " + useMappedMemory + "\n";
+    s += "statsPath: " + statsPath + "\n";
+    return s;
+  }
+
+  static int ToInteger(string value)
+  {
+    int valueInt;
+    stringstream ss;
+    ss << value;
+    ss >> valueInt;
+    return valueInt;
+  }
+
+  preprocess_fn_t PreprocessFn() const {
+    if (preprocessFnName == "preprocess_vgg")
+       return preprocessVgg;
+    else if (preprocessFnName == "preprocess_inception")
+       return preprocessInception;
+    else
+       throw runtime_error("Invalid preprocessing function name.");
+  }
+
+  int InputWidth() const { return ToInteger(inputWidth); }
+  int InputHeight() const { return ToInteger(inputHeight); }
+  int NumOutputCategories() const { return ToInteger(numOutputCategories); }
+
+  nvinfer1::DataType DataType() const { 
+    if (dataType == "float")
+      return nvinfer1::DataType::kFLOAT;
+    else if (dataType == "half")
+      return nvinfer1::DataType::kHALF;
+    else
+      throw runtime_error("Invalid data type.");
+  }
+  
+  int MaxBatchSize() const { return ToInteger(maxBatchSize); }
+  int WorkspaceSize() const { return ToInteger(workspaceSize); }
+  int NumRuns() const { return ToInteger(numRuns); } 
+  int UseMappedMemory() const { return ToInteger(useMappedMemory); } 
+};
+
+
+class Logger : public ILogger
+{
+  void log(Severity severity, const char * msg) override
+  {
+      cout << msg << endl;
+  }
+} gLogger;
+
+
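+/*
+ * Example invocation (file and node names are illustrative; scripts/test_trt.py
+ * passes the same fourteen arguments in this order):
+ *
+ *   ./build/src/test/test_trt data/images/gordon_setter.jpg vgg_16.plan \
+ *       input 224 224 vgg_16/fc8/BiasAdd 1000 preprocess_vgg 50 half \
+ *       1 1048576 0 data/test_output_trt.txt
+ */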
+int main(int argc, char * argv[])
+{
+
+  if (argc != 15)
+  {
+    cout << TestConfig::UsageString() << endl;
+    return 1;
+  }
+
+  TestConfig testConfig(argc, argv); 
+  cout << "\ntestConfig: \n" << testConfig.ToString() << endl;
+
+  test(testConfig);
+
+  return 0;
+}
+
+
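+/*
+ * Convert an 8-bit HWC OpenCV image to a float CHW tensor, allocated in
+ * mapped (zero-copy) host memory so the GPU can read it directly.
+ */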
+float *imageToTensor(const cv::Mat & image)
+{
+  const size_t height = image.rows;
+  const size_t width = image.cols;
+  const size_t channels = image.channels();
+  const size_t numel = height * width * channels;
+
+  const size_t stridesCv[3] = { width * channels, channels, 1 };
+  const size_t strides[3] = { height * width, width, 1 };
+
+  float * tensor;
+  cudaHostAlloc((void**)&tensor, numel * sizeof(float), cudaHostAllocMapped);
+
+  for (int i = 0; i < height; i++) 
+  {
+    for (int j = 0; j < width; j++) 
+    {
+      for (int k = 0; k < channels; k++) 
+      {
+        const size_t offsetCv = i * stridesCv[0] + j * stridesCv[1] + k * stridesCv[2];
+        const size_t offset = k * strides[0] + i * strides[1] + j * strides[2];
+        tensor[offset] = (float) image.data[offsetCv];
+      }
+    }
+  }
+
+  return tensor;
+}
+
+
+void preprocessVgg(float * tensor, size_t channels, size_t height, size_t width)
+{
+  const size_t strides[3] = { height * width, width, 1 };
+  const float mean[3] = { 123.68f, 116.78f, 103.94f };  // per-channel RGB mean
+
+  for (int i = 0; i < height; i++) 
+  {
+    for (int j = 0; j < width; j++) 
+    {
+      for (int k = 0; k < channels; k++) 
+      {
+        const size_t offset = k * strides[0] + i * strides[1] + j * strides[2];
+        tensor[offset] -= mean[k];
+      }
+    }
+  }
+}
+
+
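+/* Scale pixel values from [0, 255] to the [-1, 1] range expected by
+ * Inception-style models. */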
+void preprocessInception(float * tensor, size_t channels, size_t height, size_t width)
+{
+  const size_t numel = channels * height * width;
+  for (int i = 0; i < numel; i++)
+    tensor[i] = 2.0f * (tensor[i] / 255.0f - 0.5f);
+}
+
+
+size_t argmax(float * tensor, size_t numel)
+{
+  if (numel == 0)
+    return 0;
+
+  size_t maxIndex = 0;
+  float max = tensor[0];
+  for (size_t i = 1; i < numel; i++)
+  {
+    if (tensor[i] > max)
+    {
+      maxIndex = i;
+      max = tensor[i];
+    }
+  }
+
+  return maxIndex;
+}
+
+
+void test(const TestConfig &testConfig)
+{
+  ifstream planFile(testConfig.planPath, ios::in | ios::binary);
+  stringstream planBuffer;
+  planBuffer << planFile.rdbuf();
+  string plan = planBuffer.str();
+  IRuntime *runtime = createInferRuntime(gLogger);
+  ICudaEngine *engine = runtime->deserializeCudaEngine((void*)plan.data(),
+      plan.size(), nullptr);  // nullptr: no plugin factory needed
+  IExecutionContext *context = engine->createExecutionContext();
+
+  int inputBindingIndex, outputBindingIndex;
+  inputBindingIndex = engine->getBindingIndex(testConfig.inputNodeName.c_str());
+  outputBindingIndex = engine->getBindingIndex(testConfig.outputNodeName.c_str());
+
+  // load and preprocess image
+  cv::Mat image = cv::imread(testConfig.imagePath, CV_LOAD_IMAGE_COLOR);
+  cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
+  cv::resize(image, image, cv::Size(testConfig.InputWidth(), testConfig.InputHeight()));
+  float *input = imageToTensor(image);
+  testConfig.PreprocessFn()(input, 3, testConfig.InputHeight(), testConfig.InputWidth());
+
+  // allocate memory on host / device for input / output
+  float *output;
+  float *inputDevice;
+  float *outputDevice;
+  size_t inputSize = testConfig.InputHeight() * testConfig.InputWidth() * 3 * sizeof(float);
+
+  cudaHostAlloc(&output, testConfig.NumOutputCategories() * sizeof(float), cudaHostAllocMapped);
+
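+  // with mapped memory, the device pointers alias the pinned host buffers,
+  // so inference runs without explicit cudaMemcpy calls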
+  if (testConfig.UseMappedMemory())
+  {
+    cudaHostGetDevicePointer(&inputDevice, input, 0);
+    cudaHostGetDevicePointer(&outputDevice, output, 0);
+  }
+  else
+  {
+    cudaMalloc(&inputDevice, inputSize);
+    cudaMalloc(&outputDevice, testConfig.NumOutputCategories() * sizeof(float));
+  }
+
+  float *bindings[2];
+  bindings[inputBindingIndex] = inputDevice;
+  bindings[outputBindingIndex] = outputDevice;
+
+  // run numRuns + 1 iterations and average, discarding the first (warm-up) run
+  double avgTime = 0;
+  for (int i = 0; i < testConfig.NumRuns() + 1; i++)
+  {
+    chrono::duration<double> diff;
+
+    if (testConfig.UseMappedMemory())
+    {
+      auto t0 = chrono::steady_clock::now();
+      context->execute(1, (void**)bindings);
+      auto t1 = chrono::steady_clock::now();
+      diff = t1 - t0;
+    } 
+    else 
+    {
+      auto t0 = chrono::steady_clock::now();
+      cudaMemcpy(inputDevice, input, inputSize, cudaMemcpyHostToDevice);
+      context->execute(1, (void**)bindings);
+      cudaMemcpy(output, outputDevice, testConfig.NumOutputCategories() * sizeof(float), cudaMemcpyDeviceToHost);
+      auto t1 = chrono::steady_clock::now();
+      diff = t1 - t0;
+    }
+
+
+    if (i != 0)
+      avgTime += MS_PER_SEC * diff.count();
+  }
+  avgTime /= testConfig.NumRuns();
+
+  // report results and append the average latency to the stats file
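+  // shift the argmax index so 1000-class outputs (no background class)
+  // line up with the 1001-category ImageNet label ids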
+  int maxCategoryIndex = argmax(output, testConfig.NumOutputCategories()) + 1001 - testConfig.NumOutputCategories();
+  cout << "Most likely category id is " << maxCategoryIndex << endl;
+  cout << "Average execution time in ms is " << avgTime << endl;
+  ofstream outfile;
+  outfile.open(testConfig.statsPath, ios_base::app);
+  outfile << "\n" << testConfig.planPath 
+    << " " << avgTime;
+    // << " " << maxCategoryIndex
+    // << " " << testConfig.InputWidth() 
+    // << " " << testConfig.InputHeight()
+    // << " " << testConfig.MaxBatchSize() 
+    // << " " << testConfig.WorkspaceSize() 
+    // << " " << testConfig.dataType 
+    // << " " << testConfig.NumRuns() 
+    // << " " << testConfig.UseMappedMemory();
+  outfile.close();
+
+  // when using mapped memory, inputDevice/outputDevice alias the host
+  // allocations and must not be freed with cudaFree
+  if (!testConfig.UseMappedMemory())
+  {
+    cudaFree(inputDevice);
+    cudaFree(outputDevice);
+  }
+
+  cudaFreeHost(input);
+  cudaFreeHost(output);
+
+  context->destroy();
+  engine->destroy();
+  runtime->destroy();
+}

+ 105 - 0
src/uff_to_plan.cpp

@@ -0,0 +1,105 @@
+/**
+ * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ * Full license terms provided in LICENSE.md file.
+ */
+
+#include <iostream>
+#include <string>
+#include <sstream>
+#include <fstream>
+
+#include <NvInfer.h>
+#include <NvUffParser.h>
+
+
+using namespace std;
+using namespace nvinfer1;
+using namespace nvuffparser;
+
+
+class Logger : public ILogger
+{
+  void log(Severity severity, const char * msg) override
+  {
+      cout << msg << endl;
+  }
+} gLogger;
+
+int toInteger(string value)
+{
+  int valueInteger;
+  stringstream ss;
+  ss << value;
+  ss >> valueInteger;
+  return valueInteger;
+}
+
+DataType toDataType(string value)
+{
+  if (value == "float")
+    return DataType::kFLOAT;
+  else if (value == "half")
+    return DataType::kHALF;
+  else
+    throw runtime_error("Unsupported data type");
+}
+
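+/*
+ * Example invocation (file and node names are illustrative):
+ *
+ *   ./build/src/uff_to_plan vgg_16.uff vgg_16.plan input 224 224 \
+ *       vgg_16/fc8/BiasAdd 1 1048576 half
+ */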
+int main(int argc, char *argv[])
+{
+  if (argc != 10)
+  {
+    cout << "Usage: <uff_filename> <plan_filename> <input_name> <input_height> <input_width>"
+      << " <output_name> <max_batch_size> <max_workspace_size> <data_type>\n";
+    return 1;
+  }
+
+  /* parse command line arguments */
+  string uffFilename = argv[1];
+  string planFilename = argv[2];
+  string inputName = argv[3];
+  int inputHeight = toInteger(argv[4]);
+  int inputWidth = toInteger(argv[5]);
+  string outputName = argv[6];
+  int maxBatchSize = toInteger(argv[7]);
+  int maxWorkspaceSize = toInteger(argv[8]);
+  DataType dataType = toDataType(argv[9]);
+
+  /* parse uff */
+  IBuilder *builder = createInferBuilder(gLogger);
+  INetworkDefinition *network = builder->createNetwork();
+  IUffParser *parser = createUffParser();
+  parser->registerInput(inputName.c_str(), DimsCHW(3, inputHeight, inputWidth));
+  parser->registerOutput(outputName.c_str());
+  if (!parser->parse(uffFilename.c_str(), *network, dataType))
+  {
+    cout << "Failed to parse UFF\n";
+    builder->destroy();
+    parser->destroy();
+    network->destroy();
+    return 1;
+  }
+
+  /* build engine */
+  if (dataType == DataType::kHALF)
+    builder->setHalf2Mode(true);
+
+  builder->setMaxBatchSize(maxBatchSize);
+  builder->setMaxWorkspaceSize(maxWorkspaceSize);
+  ICudaEngine *engine = builder->buildCudaEngine(*network);
+  if (!engine)
+  {
+    cout << "Failed to build TensorRT engine\n";
+    builder->destroy();
+    parser->destroy();
+    network->destroy();
+    return 1;
+  }
+
+  /* serialize engine and write to file */
+  ofstream planFile;
+  planFile.open(planFilename, ios::out | ios::binary);
+  IHostMemory *serializedEngine = engine->serialize();
+  planFile.write((char *)serializedEngine->data(), serializedEngine->size());
+  planFile.close(); 
+  
+  /* break down */
+  builder->destroy();
+  parser->destroy();
+  network->destroy();
+  engine->destroy();
+  serializedEngine->destroy();
+
+  return 0;
+}