Browse Source

added argument checking to classify_image

John Welsh 6 years ago
parent
commit
9c9b213b1f
2 changed files with 35 additions and 1 deletions
  1. 34 0
      examples/classify_image/classify_image.cu
  2. 1 1
      scripts/convert_plan.py

+ 34 - 0
examples/classify_image/classify_image.cu

@@ -52,6 +52,13 @@ int main(int argc, char *argv[])
   /* load the engine */
   cout << "Loading TensorRT engine from plan file..." << endl;
   ifstream planFile(planFilename); 
+
+  if (!planFile.is_open())
+  {
+    cout << "Could not open plan file." << endl;
+    return 1;
+  }
+
   stringstream planBuffer;
   planBuffer << planFile.rdbuf();
   string plan = planBuffer.str();
@@ -63,6 +70,19 @@ int main(int argc, char *argv[])
   int inputBindingIndex, outputBindingIndex;
   inputBindingIndex = engine->getBindingIndex(inputName.c_str());
   outputBindingIndex = engine->getBindingIndex(outputName.c_str());
+
+  if (inputBindingIndex < 0)
+  {
+    cout << "Invalid input name." << endl;
+    return 1;
+  }
+
+  if (outputBindingIndex < 0)
+  {
+    cout << "Invalid output name." << endl;
+    return 1;
+  }
+
   Dims inputDims, outputDims;
   inputDims = engine->getBindingDimensions(inputBindingIndex);
   outputDims = engine->getBindingDimensions(outputBindingIndex);
@@ -73,6 +93,13 @@ int main(int argc, char *argv[])
   /* read image, convert color, and resize */
   cout << "Preprocessing input..." << endl;
   cv::Mat image = cv::imread(imageFilename, CV_LOAD_IMAGE_COLOR);
+
+  if (image.data == NULL)
+  {
+    cout << "Could not read image from file." << endl;
+    return 1;
+  }
+
   cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
   cv::resize(image, image, cv::Size(inputWidth, inputHeight));
 
@@ -119,6 +146,13 @@ int main(int argc, char *argv[])
     cout << sortedIndices[i] << " ";
 
   ifstream labelsFile(labelFilename);
+
+  if (!labelsFile.is_open())
+  {
+    cout << "\nCould not open label file." << endl;
+    return 1;
+  }
+
   vector<string> labelMap;
   string label;
   while(getline(labelsFile, label))

+ 1 - 1
scripts/convert_plan.py

@@ -19,7 +19,7 @@ def frozenToPlan(frozen_graph_filename, plan_filename, input_name, input_height,
         frozen_file=frozen_graph_filename,
         output_nodes=[output_name],
         output_filename=TMP_UFF_FILENAME,
-        text=False
+        text=False,
     )
 
     # convert frozen graph to engine (plan)