expression_ssd_detect.py

  1. """
  2. Emotion Detection:
  3. Model from: https://github.com/onnx/models/blob/main/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx
  4. Model name: emotion-ferplus-8.onnx
  5. """
import time
from math import ceil

import cv2
import numpy as np
from cv2 import dnn

image_mean = np.array([127, 127, 127])
image_std = 128.0
iou_threshold = 0.3
center_variance = 0.1
size_variance = 0.2
min_boxes = [
    [10.0, 16.0, 24.0],
    [32.0, 48.0],
    [64.0, 96.0],
    [128.0, 192.0, 256.0]
]
strides = [8.0, 16.0, 32.0, 64.0]
threshold = 0.5
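

# Anchor layout: with a 320x240 input the strides [8, 16, 32, 64] give feature
# maps of 40x30, 20x15, 10x8 and 5x4 cells, and with 3, 2, 2 and 3 box sizes
# per cell the detector ends up with 40*30*3 + 20*15*2 + 10*8*2 + 5*4*3 = 4420
# priors. define_img_size() builds these grids and generate_priors() turns
# them into normalized center-form boxes (cx, cy, w, h).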
def define_img_size(image_size):
    shrinkage_list = []
    feature_map_w_h_list = []
    for size in image_size:
        feature_map = [int(ceil(size / stride)) for stride in strides]
        feature_map_w_h_list.append(feature_map)
    for _ in range(0, len(image_size)):
        shrinkage_list.append(strides)
    priors = generate_priors(
        feature_map_w_h_list, shrinkage_list, image_size, min_boxes
    )
    return priors


def generate_priors(feature_map_list, shrinkage_list, image_size, min_boxes):
    priors = []
    for index in range(0, len(feature_map_list[0])):
        scale_w = image_size[0] / shrinkage_list[0][index]
        scale_h = image_size[1] / shrinkage_list[1][index]
        for j in range(0, feature_map_list[1][index]):
            for i in range(0, feature_map_list[0][index]):
                x_center = (i + 0.5) / scale_w
                y_center = (j + 0.5) / scale_h
                for min_box in min_boxes[index]:
                    w = min_box / image_size[0]
                    h = min_box / image_size[1]
                    priors.append([x_center, y_center, w, h])
    print("priors nums:{}".format(len(priors)))
    return np.clip(priors, 0.0, 1.0)
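

# Greedy non-maximum suppression: repeatedly keep the highest-scoring box and
# discard every remaining candidate whose IoU with it exceeds iou_threshold.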
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    indexes = np.argsort(scores)
    indexes = indexes[-candidate_size:]
    while len(indexes) > 0:
        current = indexes[-1]
        picked.append(current)
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[:-1]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(
            rest_boxes,
            np.expand_dims(current_box, axis=0),
        )
        indexes = indexes[iou <= iou_threshold]
    return box_scores[picked, :]


def area_of(left_top, right_bottom):
    hw = np.clip(right_bottom - left_top, 0.0, None)
    return hw[..., 0] * hw[..., 1]


def iou_of(boxes0, boxes1, eps=1e-5):
    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)
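

# Per-class post-processing: threshold the confidences, run NMS per class, and
# scale the surviving boxes from normalized coordinates to pixel coordinates.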
def predict(width, height, confidences, boxes, prob_threshold,
            iou_threshold=0.3, top_k=-1):
    boxes = boxes[0]
    confidences = confidences[0]
    picked_box_probs = []
    picked_labels = []
    for class_index in range(1, confidences.shape[1]):
        probs = confidences[:, class_index]
        mask = probs > prob_threshold
        probs = probs[mask]
        if probs.shape[0] == 0:
            continue
        subset_boxes = boxes[mask, :]
        box_probs = np.concatenate(
            [subset_boxes, probs.reshape(-1, 1)], axis=1
        )
        box_probs = hard_nms(
            box_probs,
            iou_threshold=iou_threshold,
            top_k=top_k,
        )
        picked_box_probs.append(box_probs)
        picked_labels.extend([class_index] * box_probs.shape[0])
    if not picked_box_probs:
        return np.array([]), np.array([]), np.array([])
    picked_box_probs = np.concatenate(picked_box_probs)
    picked_box_probs[:, 0] *= width
    picked_box_probs[:, 1] *= height
    picked_box_probs[:, 2] *= width
    picked_box_probs[:, 3] *= height
    return (
        picked_box_probs[:, :4].astype(np.int32),
        np.array(picked_labels),
        picked_box_probs[:, 4]
    )
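

# SSD box decoding: the network regresses offsets relative to each prior, so
# the box center is loc_xy * center_variance * prior_wh + prior_xy and the box
# size is exp(loc_wh * size_variance) * prior_wh; center-form boxes
# (cx, cy, w, h) are then converted to corner form (x1, y1, x2, y2).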
def convert_locations_to_boxes(locations, priors, center_variance,
                               size_variance):
    if len(priors.shape) + 1 == len(locations.shape):
        priors = np.expand_dims(priors, 0)
    return np.concatenate([
        locations[..., :2] * center_variance * priors[..., 2:] + priors[..., :2],
        np.exp(locations[..., 2:] * size_variance) * priors[..., 2:]
    ], axis=len(locations.shape) - 1)


def center_form_to_corner_form(locations):
    return np.concatenate(
        [locations[..., :2] - locations[..., 2:] / 2,
         locations[..., :2] + locations[..., 2:] / 2],
        len(locations.shape) - 1
    )
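

# End-to-end pipeline: detect faces with the RFB-320 SSD, crop each detection,
# turn it into the 1x1x64x64 grayscale blob the FER+ model expects, and label
# it with the highest-scoring emotion from emotion_dict.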
def FER_live_cam():
    emotion_dict = {
        0: 'neutral',
        1: 'happiness',
        2: 'surprise',
        3: 'sadness',
        4: 'anger',
        5: 'disgust',
        6: 'fear'
    }
    cap = cv2.VideoCapture('video3.mp4')
    # cap = cv2.VideoCapture(0)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    size = (frame_width, frame_height)
    result = cv2.VideoWriter('infer2-test.avi',
                             cv2.VideoWriter_fourcc(*'MJPG'),
                             10, size)
    # Read the ONNX emotion-recognition model.
    model = cv2.dnn.readNetFromONNX('emotion-ferplus-8.onnx')
    # Read the Caffe face detector.
    model_path = 'RFB-320/RFB-320.caffemodel'
    proto_path = 'RFB-320/RFB-320.prototxt'
    net = dnn.readNetFromCaffe(proto_path, model_path)
    input_size = [320, 240]
    width = input_size[0]
    height = input_size[1]
    priors = define_img_size(input_size)
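
    # Per frame: resize to the detector's 320x240 input, normalize by
    # (pixel - 127) / 128 via blobFromImage, decode the raw "boxes"/"scores"
    # outputs against the priors, and keep detections above the score threshold.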
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            img_ori = frame
            rect = cv2.resize(img_ori, (width, height))
            rect = cv2.cvtColor(rect, cv2.COLOR_BGR2RGB)
            net.setInput(dnn.blobFromImage(
                rect, 1 / image_std, (width, height), 127
            ))
            start_time = time.time()
            boxes, scores = net.forward(["boxes", "scores"])
            boxes = np.expand_dims(np.reshape(boxes, (-1, 4)), axis=0)
            scores = np.expand_dims(np.reshape(scores, (-1, 2)), axis=0)
            boxes = convert_locations_to_boxes(
                boxes, priors, center_variance, size_variance
            )
            boxes = center_form_to_corner_form(boxes)
            boxes, labels, probs = predict(
                img_ori.shape[1],
                img_ori.shape[0],
                scores,
                boxes,
                threshold
            )
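            # Classify each detected face: crop from the grayscale frame,
            # resize to 64x64 and feed the FER+ network as a 1x1x64x64 blob.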
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            for (x1, y1, x2, y2) in boxes:
                w = x2 - x1
                h = y2 - y1
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
                resize_frame = cv2.resize(
                    gray[y1:y1 + h, x1:x1 + w], (64, 64)
                )
                # FER+ expects a float32 blob of shape (1, 1, 64, 64).
                resize_frame = resize_frame.reshape(1, 1, 64, 64).astype(np.float32)
                model.setInput(resize_frame)
                output = model.forward()
                end_time = time.time()
                fps = 1 / (end_time - start_time)
                print(f"FPS: {fps:.1f}")
                pred = emotion_dict[list(output[0]).index(max(output[0]))]
                cv2.rectangle(
                    img_ori,
                    (x1, y1),
                    (x2, y2),
                    (215, 5, 247),
                    2,
                    lineType=cv2.LINE_AA
                )
                cv2.putText(
                    frame,
                    pred,
                    (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.8,
                    (215, 5, 247),
                    2,
                    lineType=cv2.LINE_AA
                )
            result.write(frame)
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        else:
            break
    cap.release()
    result.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    FER_live_cam()