collect_data.py

# Collect dataset
# Import required modules
import cv2
import numpy as np
import depthai as dai
import os
import blobconverter

# Define directory paths
real_face_dir = os.path.join("dataset", "real")
os.makedirs(real_face_dir, exist_ok=True)

spoofed_face_dir = os.path.join("dataset", "spoofed")
os.makedirs(spoofed_face_dir, exist_ok=True)
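# Samples collected below are written as dataset/real/real_<n>.jpg and
# dataset/spoofed/spoofed_<n>.jpg (see the save logic in the main loop).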
# Define detection NN model name and input size
DET_INPUT_SIZE = (300, 300)
FACE_MODEL_NAME = "face-detection-retail-0004"
ZOO_TYPE = "depthai"

# Start defining a pipeline
pipeline = dai.Pipeline()

# Define a source - two mono (grayscale) cameras
left = pipeline.createMonoCamera()
left.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
left.setBoardSocket(dai.CameraBoardSocket.LEFT)

right = pipeline.createMonoCamera()
right.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
right.setBoardSocket(dai.CameraBoardSocket.RIGHT)

# Create a node that will produce the depth map
# (using disparity output, as it's easier to visualize depth this way)
depth = pipeline.createStereoDepth()
depth.setConfidenceThreshold(200)
depth.setOutputRectified(True)  # The rectified streams are horizontally mirrored by default
depth.setRectifyEdgeFillColor(0)  # Black, to better see the cutout
depth.setExtendedDisparity(True)  # For better close-range depth perception
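# Note: on recent depthai (2.x) releases the rectified streams are always produced and
# setOutputRectified() is a deprecated no-op; the call is kept for older API versions.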
# Convert model from OMZ to blob
if FACE_MODEL_NAME is not None:
    blob_path = blobconverter.from_zoo(
        name=FACE_MODEL_NAME,
        shaves=6,
        zoo_type=ZOO_TYPE
    )
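# blobconverter.from_zoo() fetches a compiled .blob (converting the model zoo entry via
# an online service on the first run) and caches it locally for offline reuse.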
# Define face detection NN node
faceDetNn = pipeline.createMobileNetDetectionNetwork()
faceDetNn.setConfidenceThreshold(0.75)
faceDetNn.setBlobPath(blob_path)

# Define face detection input config
faceDetManip = pipeline.createImageManip()
faceDetManip.initialConfig.setResize(DET_INPUT_SIZE[0], DET_INPUT_SIZE[1])
faceDetManip.initialConfig.setKeepAspectRatio(False)
faceDetManip.initialConfig.setFrameType(dai.RawImgFrame.Type.RGB888p)

# Linking
depth.rectifiedRight.link(faceDetManip.inputImage)
faceDetManip.out.link(faceDetNn.input)
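# Detection runs on rectifiedRight so the resulting bounding boxes line up with the
# disparity map, which is aligned to the rectified right camera by default.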
# Create face detection output
xOutDet = pipeline.createXLinkOut()
xOutDet.setStreamName('det_out')
faceDetNn.out.link(xOutDet.input)

# Options: MEDIAN_OFF, KERNEL_3x3, KERNEL_5x5, KERNEL_7x7 (default)
median = dai.StereoDepthProperties.MedianFilter.KERNEL_7x7  # For depth filtering
depth.setMedianFilter(median)
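# A 7x7 median kernel suppresses speckle noise in the disparity map at the cost of
# some fine detail; smaller kernels (or MEDIAN_OFF) trade the other way.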
left.out.link(depth.left)
right.out.link(depth.right)

# Create right camera output
xout_right = pipeline.createXLinkOut()
xout_right.setStreamName("right")
depth.rectifiedRight.link(xout_right.input)

# Create depth output
xout = pipeline.createXLinkOut()
xout.setStreamName("disparity")
depth.disparity.link(xout.input)
# Initialize wlsFilter
# wlsFilter = cv2.ximgproc.createDisparityWLSFilterGeneric(False)
# wlsFilter.setLambda(8000)
# wlsFilter.setSigmaColor(1.5)

# Frame count
count = 0
# Set the number of frames to skip
SKIP_FRAMES = 10

real_count = 0
spoofed_count = 0
save_real = False
save_spoofed = False
# Pipeline is now defined; connect to the device
with dai.Device(pipeline) as device:
    # Start pipeline (a deprecated no-op on recent depthai versions,
    # where the pipeline starts as soon as the device is created)
    device.startPipeline()

    # Output queue used to get the right camera frames from the outputs defined above
    q_right = device.getOutputQueue(name="right", maxSize=4, blocking=False)

    # Output queue used to get the disparity frames from the outputs defined above
    q_depth = device.getOutputQueue(name="disparity", maxSize=4, blocking=False)

    # Output queue used to get NN detection data from the video frames
    qDet = device.getOutputQueue(name="det_out", maxSize=1, blocking=False)
    while True:
        # Get right camera frame
        in_right = q_right.get()
        r_frame = in_right.getFrame()
        # r_frame = cv2.flip(r_frame, flipCode=1)
        # cv2.imshow("right", r_frame)

        # Get depth frame
        in_depth = q_depth.get()  # Blocking call; waits until new data has arrived
        depth_frame = in_depth.getFrame()
        depth_frame = np.ascontiguousarray(depth_frame)
        depth_frame = cv2.bitwise_not(depth_frame)
        depth_frame = cv2.flip(depth_frame, flipCode=1)
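        # bitwise_not computes 255 - value, inverting the brightness mapping of the 8-bit
        # disparity; the horizontal flip matches the mirrored rectified stream (see above).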
        # Apply wls filter
        # cv2.imshow("without wls filter", cv2.applyColorMap(depth_frame, cv2.COLORMAP_JET))
        # depth_frame = wlsFilter.filter(depth_frame, r_frame)

        # Frame is transformed; a color map is applied to highlight the depth info
        depth_frame_cmap = cv2.applyColorMap(depth_frame, cv2.COLORMAP_JET)
        # Frame is ready to be shown
        cv2.imshow("disparity", depth_frame_cmap)
        # Retrieve a 'bgr' (OpenCV format) frame from the grayscale frame
        frame = cv2.cvtColor(r_frame, cv2.COLOR_GRAY2BGR)
        img_h, img_w = frame.shape[0:2]
        bbox = None

        inDet = qDet.tryGet()
        if inDet is not None:
            detections = inDet.detections
            # for detection in detections:
            if len(detections) != 0:
                detection = detections[0]
                # print(detection.confidence)
                # Detections are normalized to [0, 1]; scale them to pixel coordinates
                x = int(detection.xmin * img_w)
                y = int(detection.ymin * img_h)
                w = int(detection.xmax * img_w - detection.xmin * img_w)
                h = int(detection.ymax * img_h - detection.ymin * img_h)
                bbox = (x, y, w, h)
        # Set default status color (red)
        status_color = (0, 0, 255)

        # Check if a face was detected in the frame
        if bbox:
            # Get face ROI from the depth and right frames
            face_d = depth_frame[max(0, bbox[1]):bbox[1] + bbox[3], max(0, bbox[0]):bbox[0] + bbox[2]]
            face_r = r_frame[max(0, bbox[1]):bbox[1] + bbox[3], max(0, bbox[0]):bbox[0] + bbox[2]]
            cv2.imshow("face_roi", face_d)
            # Display bounding box
            cv2.rectangle(frame, bbox, status_color, 2)

        # Create background for showing details
        cv2.rectangle(frame, (5, 5, 225, 100), (50, 0, 0), -1)

        # Display instructions on the frame
        cv2.putText(frame, 'Press F to save real Face.', (10, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255))
        cv2.putText(frame, 'Press S to save spoofed Face.', (10, 65), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255))
        cv2.putText(frame, 'Press Q to Quit.', (10, 85), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255))
        # Capture the key pressed
        key_pressed = cv2.waitKey(1) & 0xff

        # Toggle saving real face depth maps if 'f' was pressed
        if key_pressed == ord('f'):
            save_real = not save_real
            save_spoofed = False
        # Toggle saving spoofed face depth maps if 's' was pressed
        elif key_pressed == ord('s'):
            save_spoofed = not save_spoofed
            save_real = False
        # Stop the program if 'q' was pressed
        elif key_pressed == ord('q'):
            break
        if bbox:
            # face_d.size > 0 guards against an empty crop when the box is clipped
            if face_d.size > 0 and save_real:
                real_count += 1
                filename = f"real_{real_count}.jpg"
                cv2.imwrite(os.path.join(real_face_dir, filename), face_d)
                # Display saving status on the frame
                cv2.putText(frame, f"Saving real face: {real_count}", (20, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
            elif face_d.size > 0 and save_spoofed:
                spoofed_count += 1
                filename = f"spoofed_{spoofed_count}.jpg"
                cv2.imwrite(os.path.join(spoofed_face_dir, filename), face_d)
                # Display saving status on the frame
                cv2.putText(frame, f"Saving spoofed face: {spoofed_count}", (20, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
            else:
                cv2.putText(frame, "Face not saved", (20, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))

        # Display the final frame
        cv2.imshow("Data collection Cam", frame)
        count += 1

cv2.destroyAllWindows()
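# Usage sketch (assumes an OAK-D style stereo device on USB; package names are the
# standard PyPI ones, not stated in the original source):
#   pip install depthai blobconverter opencv-python numpy
#   python collect_data.py
# Show a real face (or a printed/on-screen photo for spoofed samples), press 'f' or 's'
# to toggle saving, and 'q' to quit.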