spatial-object-tracker.py

import cv2
import depthai as dai
import time
import argparse
import blobconverter
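
# MobileNet-SSD class labels (PASCAL VOC); the list index matches the class id
# the network reports, so labelMap[15] == "person".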
labelMap = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
            "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
nnPathDefault = blobconverter.from_zoo(
    name="mobilenet-ssd",
    shaves=6,
    zoo_type="intel"
)
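# blobconverter downloads a pre-compiled .blob from the model zoo and caches it
# locally, so the first run needs internet access.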

parser = argparse.ArgumentParser()
parser.add_argument('nnPath', nargs='?', help="Path to mobilenet detection network blob", default=nnPathDefault)
parser.add_argument('-ff', '--full_frame', action="store_true", help="Perform tracking on full RGB frame", default=False)
args = parser.parse_args()

fullFrameTracking = args.full_frame
# Create pipeline
pipeline = dai.Pipeline()

# Define sources and outputs
camRgb = pipeline.create(dai.node.ColorCamera)
spatialDetectionNetwork = pipeline.create(dai.node.MobileNetSpatialDetectionNetwork)
monoLeft = pipeline.create(dai.node.MonoCamera)
monoRight = pipeline.create(dai.node.MonoCamera)
stereo = pipeline.create(dai.node.StereoDepth)
objectTracker = pipeline.create(dai.node.ObjectTracker)

xoutRgb = pipeline.create(dai.node.XLinkOut)
trackerOut = pipeline.create(dai.node.XLinkOut)

xoutRgb.setStreamName("preview")
trackerOut.setStreamName("tracklets")
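# Stream names must match the names passed to device.getOutputQueue() below.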

# Properties
camRgb.setPreviewSize(300, 300)  # mobilenet-ssd expects a 300x300 input
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)

monoLeft.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
monoLeft.setBoardSocket(dai.CameraBoardSocket.LEFT)
monoRight.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
monoRight.setBoardSocket(dai.CameraBoardSocket.RIGHT)
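# 400p mono frames keep the stereo depth computation cheap on-device.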

# Setting node configs
stereo.initialConfig.setConfidenceThreshold(255)
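# The disparity confidence threshold runs 0..255; 255 keeps even the least
# confident depth pixels (i.e. no filtering).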

spatialDetectionNetwork.setBlobPath(args.nnPath)
spatialDetectionNetwork.setConfidenceThreshold(0.5)
spatialDetectionNetwork.input.setBlocking(False)
spatialDetectionNetwork.setBoundingBoxScaleFactor(0.5)
spatialDetectionNetwork.setDepthLowerThreshold(100)
spatialDetectionNetwork.setDepthUpperThreshold(5000)
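# Spatial coordinates are averaged from depth pixels inside each detection's
# ROI, shrunk by the scale factor above; depth outside 100-5000 mm is ignored.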

objectTracker.setDetectionLabelsToTrack([15])  # track only person
# possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS
objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
# take the smallest ID when a new object is tracked; possible options: SMALLEST_ID, UNIQUE_ID
objectTracker.setTrackerIdAssigmentPolicy(dai.TrackerIdAssigmentPolicy.SMALLEST_ID)
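# Note: the "Assigment" spelling above matches the actual depthai API name.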

# Linking
monoLeft.out.link(stereo.left)
monoRight.out.link(stereo.right)

camRgb.preview.link(spatialDetectionNetwork.input)
objectTracker.passthroughTrackerFrame.link(xoutRgb.input)
objectTracker.out.link(trackerOut.input)
if fullFrameTracking:
    camRgb.setPreviewKeepAspectRatio(False)
    camRgb.video.link(objectTracker.inputTrackerFrame)
    objectTracker.inputTrackerFrame.setBlocking(False)
    # do not block the pipeline if it's too slow on full frame
    objectTracker.inputTrackerFrame.setQueueSize(2)
else:
    spatialDetectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)

spatialDetectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
spatialDetectionNetwork.out.link(objectTracker.inputDetections)
stereo.depth.link(spatialDetectionNetwork.inputDepth)
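# The three links above sit outside the if/else on purpose: detections (and the
# depth they are computed from) are produced identically in both modes; only
# the frame the tracker runs on changes.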

# Connect to device and start pipeline
with dai.Device(pipeline) as device:
    preview = device.getOutputQueue("preview", 4, False)
    tracklets = device.getOutputQueue("tracklets", 4, False)
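    # maxSize=4 with blocking=False drops the oldest messages when the host
    # falls behind, so the preview stays live instead of lagging.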

    startTime = time.monotonic()
    counter = 0
    fps = 0
    color = (255, 255, 255)

    while True:
        imgFrame = preview.get()  # blocks until a new frame arrives
        track = tracklets.get()

        counter += 1
        current_time = time.monotonic()
        if (current_time - startTime) > 1:
            # recompute the displayed FPS over a roughly one-second window
            fps = counter / (current_time - startTime)
            counter = 0
            startTime = current_time
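        # getCvFrame() converts the device ImgFrame into a BGR numpy array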
        frame = imgFrame.getCvFrame()
        trackletsData = track.tracklets
        for t in trackletsData:
            # tracklet ROIs are normalized [0..1]; convert to pixel coordinates
            roi = t.roi.denormalize(frame.shape[1], frame.shape[0])
            x1 = int(roi.topLeft().x)
            y1 = int(roi.topLeft().y)
            x2 = int(roi.bottomRight().x)
            y2 = int(roi.bottomRight().y)
            try:
                label = labelMap[t.label]
            except IndexError:
                label = t.label
            cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"ID: {t.id}", (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)

            # spatialCoordinates are in millimeters, relative to the camera
            cv2.putText(frame, f"X: {int(t.spatialCoordinates.x)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"Y: {int(t.spatialCoordinates.y)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"Z: {int(t.spatialCoordinates.z)} mm", (x1 + 10, y1 + 95), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)

        cv2.putText(frame, "NN fps: {:.2f}".format(fps), (2, frame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color)
        cv2.imshow("tracker", frame)

        if cv2.waitKey(1) == 27:  # quit on Esc
            break
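
# Usage sketch (assumes an OAK device is attached and the depthai,
# blobconverter and opencv-python packages are installed):
#   python3 spatial-object-tracker.py                 # track on the 300x300 preview
#   python3 spatial-object-tracker.py --full_frame    # track on the full RGB frame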