overlay_with_mask.py

import argparse
import csv
import os
import pprint
from collections import OrderedDict

import cv2
import numpy as np
import torch

import lib.models as models
from lib.config import config, update_config
from lib.core.evaluation import decode_preds
from lib.utils import utils
from lib.utils.transforms import crop


def parse_args():
    parser = argparse.ArgumentParser(description="Face Mask Overlay")
    parser.add_argument(
        "--cfg", help="experiment configuration filename", required=True, type=str,
    )
    parser.add_argument(
        "--landmark_model",
        help="path to the model for landmark extraction",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--detector_model",
        help="path to detector model",
        type=str,
        default="detection/face_detector.prototxt",
    )
    parser.add_argument(
        "--detector_weights",
        help="path to detector weights",
        type=str,
        default="detection/face_detector.caffemodel",
    )
    parser.add_argument(
        "--mask_image", help="path to a .png file with a mask", required=True, type=str,
    )
    parser.add_argument("--device", default="cpu", help="device to run inference on")
    args = parser.parse_args()
    update_config(config, args)
    return args


def main():
    # parse script arguments
    args = parse_args()
    device = torch.device(args.device)

    # initialize the logger
    logger, final_output_dir, tb_log_dir = utils.create_logger(config, args.cfg, "demo")

    # log arguments and config values
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # initialize the landmark model
    model = models.get_face_alignment_net(config)

    # get the input size from the config
    input_size = config.MODEL.IMAGE_SIZE

    # load the pre-trained weights
    state_dict = torch.load(args.landmark_model, map_location=device)

    # remove the `module.` prefix (added by DataParallel training) from the weight names
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        name = key[7:]  # strip "module."
        new_state_dict[name] = value

    # load the weights without the prefix
    model.load_state_dict(new_state_dict)

    # move the model to the device and switch to evaluation mode
    model = model.to(device)
    model.eval()

    # mean and std values for normalizing the landmark model's input
    mean = np.array(config.MODEL.MEAN, dtype=np.float32)
    std = np.array(config.MODEL.STD, dtype=np.float32)

    # prototxt and caffemodel paths for the face detector
    detector_model = args.detector_model
    detector_weights = args.detector_weights

    # load the face detector
    detector = cv2.dnn.readNetFromCaffe(detector_model, detector_weights)
    capture = cv2.VideoCapture(0)

    frame_num = 0
    while True:
        # capture frame-by-frame
        success, frame = capture.read()
        # break if no frame
        if not success:
            break

        frame_num += 1
        print("frame_num: ", frame_num)

        landmarks_img = frame.copy()
        result = frame.copy()
        result = result.astype(np.float32) / 255.0

        # get the frame's height and width
        height, width = frame.shape[:2]

        # resize and subtract the BGR mean values, since Caffe uses BGR images as input
        blob = cv2.dnn.blobFromImage(
            frame, scalefactor=1.0, size=(300, 300), mean=(104.0, 177.0, 123.0),
        )
        # pass the blob through the network to detect faces
        detector.setInput(blob)
        # detector output format:
        # [image_id, class, confidence, left, top, right, bottom]
        face_detections = detector.forward()

        # loop over the detections
        for i in range(face_detections.shape[2]):
            # extract the confidence
            confidence = face_detections[0, 0, i, 2]
            # keep only detections above the minimum confidence threshold
            if confidence > 0.5:
                # scale the bounding box coordinates to the frame size
                box = face_detections[0, 0, i, 3:7] * np.array(
                    [width, height, width, height],
                )
                (x1, y1, x2, y2) = box.astype("int")

                # show the original image
                cv2.imshow("original image", frame)

                # crop to the detection and resize
                resized = crop(
                    frame,
                    torch.Tensor([x1 + (x2 - x1) / 2, y1 + (y2 - y1) / 2]),
                    1.5,
                    tuple(input_size),
                )

                # convert from BGR to RGB, since HRNet expects RGB input
                resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
                img = resized.astype(np.float32) / 255.0
                # normalize the landmark net input
                normalized_img = (img - mean) / std

                # predict face landmarks
                with torch.no_grad():
                    inp = torch.Tensor(normalized_img.transpose([2, 0, 1]))
                    inp = inp.to(device)
                    output = model(inp.unsqueeze(0))
                    score_map = output.data.cpu()
                    # decode the heatmap peaks back into pixel coordinates in the
                    # original frame, using the same center and scale as the crop
                    preds = decode_preds(
                        score_map,
                        [torch.Tensor([x1 + (x2 - x1) / 2, y1 + (y2 - y1) / 2])],
                        [1.5],
                        score_map.shape[2:4],
                    )
                    landmarks = preds.squeeze(0).cpu().numpy()

                # draw the landmarks (cv2.circle needs integer pixel coordinates)
                for k, landmark in enumerate(landmarks, 1):
                    landmarks_img = cv2.circle(
                        landmarks_img,
                        center=(int(landmark[0]), int(landmark[1])),
                        radius=3,
                        color=(0, 0, 255),
                        thickness=-1,
                    )
                    # draw the landmarks' labels
                    landmarks_img = cv2.putText(
                        img=landmarks_img,
                        text=str(k),
                        org=(int(landmark[0]) + 5, int(landmark[1]) + 5),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=(0, 0, 255),
                    )

                # show the predicted landmarks and their labels
                cv2.imshow("image with landmarks", landmarks_img)
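
                # the overlay is computed as a planar homography: the chosen face
                # landmarks below are matched one-to-one with the annotated points
                # on the mask image, and cv2.findHomography estimates the
                # perspective warp between the two point sets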
                # use landmarks 2-16 and 30 as the destination points
                # (landmark numbering starts from 0, hence indices 1-15 and 29)
                dst_pts = np.array(
                    [*landmarks[1:16], landmarks[29]],
                    dtype="float32",
                )

                # load mask annotations as the source points, from the csv file
                # that sits next to the mask image and shares its base name
                mask_annotation = os.path.splitext(os.path.basename(args.mask_image))[0]
                mask_annotation = os.path.join(
                    os.path.dirname(args.mask_image), mask_annotation + ".csv",
                )
                with open(mask_annotation) as csv_file:
                    csv_reader = csv.reader(csv_file, delimiter=",")
                    src_pts = []
                    for row in csv_reader:
                        # skip the header or any empty lines
                        try:
                            src_pts.append(np.array([float(row[1]), float(row[2])]))
                        except (ValueError, IndexError):
                            continue
                src_pts = np.array(src_pts, dtype="float32")

                # overlay the mask only if all landmarks have positive coordinates
                if (landmarks > 0).all():
                    # load the mask image together with its alpha channel
                    mask_img = cv2.imread(args.mask_image, cv2.IMREAD_UNCHANGED)
                    mask_img = mask_img.astype(np.float32) / 255.0

                    # get the perspective transformation matrix
                    M, _ = cv2.findHomography(src_pts, dst_pts)

                    # warp the mask onto the frame
                    transformed_mask = cv2.warpPerspective(
                        mask_img,
                        M,
                        (result.shape[1], result.shape[0]),
                        None,
                        cv2.INTER_LINEAR,
                        cv2.BORDER_CONSTANT,
                    )

                    # alpha-blend the mask over the frame
                    alpha_mask = transformed_mask[:, :, 3]
                    alpha_image = 1.0 - alpha_mask
                    for c in range(3):
                        result[:, :, c] = (
                            alpha_mask * transformed_mask[:, :, c]
                            + alpha_image * result[:, :, c]
                        )

        # display the resulting frame
        cv2.imshow("image with mask overlay", result)

        # exit when the Escape key is pressed
        k = cv2.waitKey(1)
        if k == 27:
            break

    # when everything is done, release the capture
    capture.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
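

# Example invocation (the file paths here are illustrative; only the flags
# themselves are defined by the script):
#
#   python overlay_with_mask.py \
#       --cfg experiments/face_alignment.yaml \
#       --landmark_model weights/hrnet_landmarks.pth \
#       --mask_image masks/mask.png \
#       --device cpu
#
# The annotation CSV is looked up next to --mask_image with the same base name
# (masks/mask.csv in this example). Columns 2 and 3 of each row are read as the
# x and y coordinates of one source point (the first column is ignored), and a
# non-numeric header row is skipped by the try/except above.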