# mbnet.py

import cv2
import numpy as np
# import time
# video_path = 'D:/OfficeWork/VS_code_exp/exp/video_1.mp4'
# image_path = 'D:/OfficeWork/VS_code_exp/exp/test.jpg.jpg'

def load_model():
    # Load the pretrained SSD MobileNet v2 (COCO) graph through OpenCV's DNN module.
    model = cv2.dnn.readNet(
        model="frozen_inference_graph.pb",
        config="ssd_mobilenet_v2_coco_2018_03_29.pbtxt.txt",
        framework="TensorFlow",
    )
    # COCO class names, one per line.
    with open("object_detection_classes_coco.txt", "r") as f:
        class_names = f.read().split("\n")
    # One random colour per class for drawing boxes and labels.
    COLORS = np.random.uniform(0, 255, size=(len(class_names), 3))
    return model, class_names, COLORS

def load_img(img_path):
    img = cv2.imread(img_path)
    img = cv2.resize(img, None, fx=0.4, fy=0.4)
    height, width, channels = img.shape
    return img, height, width, channels

def detect_objects(img, net):
    # SSD MobileNet expects a 300x300 input; swapRB converts OpenCV's BGR to RGB.
    blob = cv2.dnn.blobFromImage(img, size=(300, 300), mean=(104, 117, 123), swapRB=True)
    net.setInput(blob)
    outputs = net.forward()
    print(outputs.shape)
    # print(outputs)
    return blob, outputs

def get_box_dimensions(outputs, height, width):
    # Each detection is [image_id, class_id, confidence, x_min, y_min, x_max, y_max],
    # with the box corners normalised to [0, 1].
    boxes = []
    class_ids = []
    for detect in outputs[0, 0, :, :]:
        score = detect[2]
        class_id = detect[1]
        if score > 0.3:
            x_min = int(detect[3] * width)
            y_min = int(detect[4] * height)
            x_max = int(detect[5] * width)
            y_max = int(detect[6] * height)
            boxes.append([x_min, y_min, x_max, y_max])
            class_ids.append(class_id)
    return boxes, class_ids

def draw_labels(boxes, colors, class_ids, classes, img):
    font = cv2.FONT_HERSHEY_PLAIN
    for i in range(len(boxes)):
        x_min, y_min, x_max, y_max = boxes[i]
        # Class ids in the COCO graph start at 1, so offset into the name/colour lists.
        label = classes[int(class_ids[i]) - 1]
        color = colors[int(class_ids[i]) - 1]
        cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color, 5)
        cv2.putText(img, label, (x_min, y_min - 5), font, 5, color, 5)
    return img

def image_detect(img_path):
    model, classes, colors = load_model()
    image, height, width, channels = load_img(img_path)
    blob, outputs = detect_objects(image, model)
    boxes, class_ids = get_box_dimensions(outputs, height, width)
    image1 = draw_labels(boxes, colors, class_ids, classes, image)
    return image1

def start_video(video_path):
    model, classes, colors = load_model()
    cap = cv2.VideoCapture(video_path)
    while True:
        ret, frame = cap.read()
        if not ret:
            break  # end of stream or failed read
        height, width, channels = frame.shape
        blob, outputs = detect_objects(frame, model)
        boxes, class_ids = get_box_dimensions(outputs, height, width)
        frame = draw_labels(boxes, colors, class_ids, classes, frame)
        # Yield annotated frames in RGB order for downstream display libraries.
        yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    cap.release()
    cv2.destroyAllWindows()
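
# A minimal usage sketch, assuming the model and class-name files sit next to this
# script. "test.jpg" and "video_1.mp4" are placeholder paths (the commented-out paths
# at the top hint at the author's local files); point them at your own image/video.
if __name__ == "__main__":
    # Single image: image_detect returns the annotated BGR image.
    result = image_detect("test.jpg")
    cv2.imshow("detections", result)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    # Video: start_video is a generator that yields annotated RGB frames,
    # so convert back to BGR before handing them to cv2.imshow.
    for rgb_frame in start_video("video_1.mp4"):
        cv2.imshow("video detections", cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cv2.destroyAllWindows()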