classify_image.cu 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. /**
  2. * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. * Full license terms provided in LICENSE.md file.
  4. */
  5. #include <iostream>
  6. #include <fstream>
  7. #include <sstream>
  8. #include <vector>
  9. #include <NvInfer.h>
  10. #include <opencv2/opencv.hpp>
  11. #include "examples/classify_image/utils.h"
  12. using namespace std;
  13. using namespace nvinfer1;
  14. class Logger : public ILogger
  15. {
  16. void log(Severity severity, const char * msg) override
  17. {
  18. if (severity != Severity::kINFO)
  19. cout << msg << endl;
  20. }
  21. } gLogger;
  22. /**
  23. * image_file: path to image
  24. * plan_file: path of the serialized engine file
  25. * label_file: file with <class_name> per line
  26. * input_name: name of the input tensor
  27. * output_name: name of the output tensor
  28. * preprocessing_fn: 'vgg' or 'inception'
  29. */
  30. int main(int argc, char *argv[])
  31. {
  32. if (argc != 7)
  33. {
  34. cout << "Usage: classify_image <image_file> <plan_file> <label_file> <input_name> <output_name> <preprocessing_fn>\n";
  35. return 0;
  36. }
  37. string imageFilename = argv[1];
  38. string planFilename = argv[2];
  39. string labelFilename = argv[3];
  40. string inputName = argv[4];
  41. string outputName = argv[5];
  42. string preprocessingFn = argv[6];
  43. /* load the engine */
  44. cout << "Loading TensorRT engine from plan file..." << endl;
  45. ifstream planFile(planFilename);
  46. if (!planFile.is_open())
  47. {
  48. cout << "Could not open plan file." << endl;
  49. return 1;
  50. }
  51. stringstream planBuffer;
  52. planBuffer << planFile.rdbuf();
  53. string plan = planBuffer.str();
  54. IRuntime *runtime = createInferRuntime(gLogger);
  55. ICudaEngine *engine = runtime->deserializeCudaEngine((void*)plan.data(), plan.size(), nullptr);
  56. IExecutionContext *context = engine->createExecutionContext();
  57. /* get the input / output dimensions */
  58. int inputBindingIndex, outputBindingIndex;
  59. inputBindingIndex = engine->getBindingIndex(inputName.c_str());
  60. outputBindingIndex = engine->getBindingIndex(outputName.c_str());
  61. if (inputBindingIndex < 0)
  62. {
  63. cout << "Invalid input name." << endl;
  64. return 1;
  65. }
  66. if (outputBindingIndex < 0)
  67. {
  68. cout << "Invalid output name." << endl;
  69. return 1;
  70. }
  71. Dims inputDims, outputDims;
  72. inputDims = engine->getBindingDimensions(inputBindingIndex);
  73. outputDims = engine->getBindingDimensions(outputBindingIndex);
  74. int inputWidth, inputHeight;
  75. inputHeight = inputDims.d[1];
  76. inputWidth = inputDims.d[2];
  77. /* read image, convert color, and resize */
  78. cout << "Preprocessing input..." << endl;
  79. cv::Mat image = cv::imread(imageFilename, CV_LOAD_IMAGE_COLOR);
  80. if (image.data == NULL)
  81. {
  82. cout << "Could not read image from file." << endl;
  83. return 1;
  84. }
  85. cv::cvtColor(image, image, cv::COLOR_BGR2RGB, 3);
  86. cv::resize(image, image, cv::Size(inputWidth, inputHeight));
  87. /* convert from uint8+NHWC to float+NCHW */
  88. float *inputDataHost, *outputDataHost;
  89. size_t numInput, numOutput;
  90. numInput = numTensorElements(inputDims);
  91. numOutput = numTensorElements(outputDims);
  92. inputDataHost = (float*) malloc(numInput * sizeof(float));
  93. outputDataHost = (float*) malloc(numOutput * sizeof(float));
  94. cvImageToTensor(image, inputDataHost, inputDims);
  95. if (preprocessingFn == "vgg")
  96. preprocessVgg(inputDataHost, inputDims);
  97. else if (preprocessingFn == "inception")
  98. preprocessInception(inputDataHost, inputDims);
  99. else
  100. {
  101. cout << "Invalid preprocessing function argument, must be vgg or inception. \n" << endl;
  102. return 1;
  103. }
  104. /* transfer to device */
  105. float *inputDataDevice, *outputDataDevice;
  106. cudaMalloc(&inputDataDevice, numInput * sizeof(float));
  107. cudaMalloc(&outputDataDevice, numOutput * sizeof(float));
  108. cudaMemcpy(inputDataDevice, inputDataHost, numInput * sizeof(float), cudaMemcpyHostToDevice);
  109. void *bindings[2];
  110. bindings[inputBindingIndex] = (void*) inputDataDevice;
  111. bindings[outputBindingIndex] = (void*) outputDataDevice;
  112. /* execute engine */
  113. cout << "Executing inference engine..." << endl;
  114. const int kBatchSize = 1;
  115. context->execute(kBatchSize, bindings);
  116. /* transfer output back to host */
  117. cudaMemcpy(outputDataHost, outputDataDevice, numOutput * sizeof(float), cudaMemcpyDeviceToHost);
  118. /* parse output */
  119. vector<size_t> sortedIndices = argsort(outputDataHost, outputDims);
  120. cout << "\nThe top-5 indices are: ";
  121. for (int i = 0; i < 5; i++)
  122. cout << sortedIndices[i] << " ";
  123. ifstream labelsFile(labelFilename);
  124. if (!labelsFile.is_open())
  125. {
  126. cout << "\nCould not open label file." << endl;
  127. return 1;
  128. }
  129. vector<string> labelMap;
  130. string label;
  131. while(getline(labelsFile, label))
  132. {
  133. labelMap.push_back(label);
  134. }
  135. cout << "\nWhich corresponds to class labels: ";
  136. for (int i = 0; i < 5; i++)
  137. cout << endl << i << ". " << labelMap[sortedIndices[i]];
  138. cout << endl;
  139. /* clean up */
  140. runtime->destroy();
  141. engine->destroy();
  142. context->destroy();
  143. free(inputDataHost);
  144. free(outputDataHost);
  145. cudaFree(inputDataDevice);
  146. cudaFree(outputDataDevice);
  147. return 0;
  148. }