yolov7_keypoint.py

import matplotlib.pyplot as plt
import torch
import cv2
import numpy as np
import time

from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts
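
# Select the compute device and load the YOLOv7-W6 pose checkpoint for half-precision inference.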
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weights = torch.load('yolov7-w6-pose.pt')
model = weights['model']
model = model.half().to(device)
_ = model.eval()

video_path = '../inference_data/video_5.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print('Error while trying to read video. Please check path again')

# Get the frame width and height.
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))

# Pass the first frame through the `letterbox` function to get the resized image,
# to be used for the `VideoWriter` dimensions. Resize by the larger side.
vid_write_image = letterbox(cap.read()[1], frame_width, stride=64, auto=True)[0]
resize_height, resize_width = vid_write_image.shape[:2]

save_name = f"{video_path.split('/')[-1].split('.')[0]}"
# Define the codec and create the VideoWriter object.
out = cv2.VideoWriter(f"{save_name}_keypoint.mp4",
                      cv2.VideoWriter_fourcc(*'mp4v'), 30,
                      (resize_width, resize_height))

frame_count = 0  # To count total frames.
total_fps = 0  # To get the final frames per second.

while cap.isOpened():
    # Capture each frame of the video.
    ret, frame = cap.read()
    if ret:
        orig_image = frame
        image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
        image = letterbox(image, frame_width, stride=64, auto=True)[0]
        image_ = image.copy()
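        # Convert the RGB frame to a normalized CHW tensor, add a batch
        # dimension, and move it to the device in half precision.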
        image = transforms.ToTensor()(image)
        image = torch.tensor(np.array([image.numpy()]))
        image = image.to(device)
        image = image.half()

        # Get the start time.
        start_time = time.time()
        with torch.no_grad():
            output, _ = model(image)
        # Get the end time.
        end_time = time.time()
        # Get the fps.
        fps = 1 / (end_time - start_time)
        # Add fps to total fps.
        total_fps += fps
        # Increment frame count.
        frame_count += 1
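
        # Apply keypoint-aware NMS; the class count (nc) and keypoint count
        # (nkpt) are read from the model's YAML configuration.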
        output = non_max_suppression_kpt(output, 0.25, 0.65,
                                         nc=model.yaml['nc'],
                                         nkpt=model.yaml['nkpt'],
                                         kpt_label=True)
        output = output_to_keypoint(output)
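        # Convert the normalized input tensor back to an 8-bit BGR image for drawing.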
        nimg = image[0].permute(1, 2, 0) * 255
        nimg = nimg.cpu().numpy().astype(np.uint8)
        nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
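
        # Each detection row holds the batch id, class id, box (x, y, w, h) and
        # confidence, followed by the keypoints from index 7 onward.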
        for idx in range(output.shape[0]):
            plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
            # Comment/Uncomment the following lines to show bounding boxes around persons.
            xmin, ymin = (output[idx, 2] - output[idx, 4] / 2), (output[idx, 3] - output[idx, 5] / 2)
            xmax, ymax = (output[idx, 2] + output[idx, 4] / 2), (output[idx, 3] + output[idx, 5] / 2)
            cv2.rectangle(
                nimg,
                (int(xmin), int(ymin)),
                (int(xmax), int(ymax)),
                color=(255, 0, 0),
                thickness=2,
                lineType=cv2.LINE_AA
            )

        # Write the FPS on the current frame.
        cv2.putText(nimg, f"{fps:.3f} FPS", (15, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (0, 255, 0), 2)
        # Display the annotated frame and write it to the output video.
        cv2.imshow('image', nimg)
        out.write(nimg)
        # Press `q` to exit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    else:
        break

# Release VideoCapture().
cap.release()
# Close all frames and video windows.
cv2.destroyAllWindows()
# Calculate and print the average FPS.
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")