yolov7-pose.py

import time
import torch
import cv2
import numpy as np
from torchvision import transforms

from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts

def pose_video(frame):
    # Letterbox resizing: keeps the aspect ratio and pads to a multiple of the stride.
    img = letterbox(frame, input_size, stride=64, auto=True)[0]
    # The model expects RGB input; OpenCV frames are BGR.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Convert the array to a tensor in [c, h, w] format, scaled to [0, 1].
    img = transforms.ToTensor()(img)
    # Add a batch dimension: [c, h, w] -> [1, c, h, w].
    img = img.unsqueeze(0)
    # Load the image onto the computation device.
    img = img.to(device)
    # Gradients are stored during training; they are not required for inference.
    with torch.no_grad():
        t1 = time.time()
        output, _ = model(img)
        t2 = time.time()
        fps = 1 / (t2 - t1)
        output = non_max_suppression_kpt(output,
                                         0.25,            # Conf. threshold.
                                         0.65,            # IoU threshold.
                                         nc=1,            # Number of classes.
                                         nkpt=17,         # Number of keypoints.
                                         kpt_label=True)
    output = output_to_keypoint(output)
    # Change format [b, c, h, w] to [h, w, c] for displaying the image.
    nimg = img[0].permute(1, 2, 0) * 255
    nimg = nimg.cpu().numpy().astype(np.uint8)
    # Back to BGR so the frame can be shown and written with OpenCV directly.
    nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
    # Each row of output is [batch_id, class_id, x, y, w, h, conf, 17 x (kx, ky, kconf)],
    # so the keypoints start at index 7.
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
    return nimg, fps
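
# A minimal single-image usage sketch (hypothetical, not part of the original
# script), assuming the model/device globals below are initialized and a file
# 'person.jpg' exists:
#
#   result, fps_ = pose_video(cv2.imread('person.jpg'))
#   cv2.imwrite('person_pose.jpg', result)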
#------------------------------------------------------------------------------#
# Change forward pass input size.
input_size = 256
#---------------------------INITIALIZATIONS------------------------------------#
# Select the device based on hardware configuration.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Selected Device :', device)
# Load the keypoint detection model.
weights = torch.load('yolov7-w6-pose.pt', map_location=device)
model = weights['model']
# Set the model to evaluation mode.
_ = model.float().eval()
# Move the model to the computation device.
model.to(device)
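
# Optional speedup (an assumption, not part of the original script): on CUDA,
# the model can run in half precision via model.half(), provided the input
# tensor in pose_video() is also converted with img.half() after img.to(device).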
# Video capture and writer init.
videos = ['dance',
          'dark',
          'far-away',
          'occlusion-example',
          'skydiving',
          'yoga-1']
file_name = videos[0] + '.mp4'
vid_path = '../Media/' + file_name

cap = cv2.VideoCapture(vid_path)
fps = int(cap.get(cv2.CAP_PROP_FPS))
ret, frame = cap.read()
if not ret:
    raise IOError('Unable to read the video file: ' + vid_path)
# The letterbox function reshapes the frame, so probe the first frame once to
# get the actual output size; the writer size must match the frames written
# to it, or the output video will be broken.
h, w, _ = letterbox(frame, input_size, stride=64, auto=True)[0].shape
# Rewind so the probed first frame is processed in the main loop as well.
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
out = cv2.VideoWriter('pose_outputs/' + file_name,
                      cv2.VideoWriter_fourcc(*'mp4v'),
                      fps, (w, h))
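
# Note: cv2.VideoWriter fails silently when the output directory is missing,
# so make sure 'pose_outputs/' exists before running, e.g. with
# os.makedirs('pose_outputs', exist_ok=True) (requires importing os).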
#-------------------------------------------------------------------------------#
if __name__ == '__main__':
    while True:
        ret, frame = cap.read()
        if not ret:
            print('Unable to read frame. Exiting ..')
            break
        img, fps_ = pose_video(frame)
        cv2.putText(img, 'FPS : {:.2f}'.format(fps_), (200, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.putText(img, 'YOLOv7', (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        # pose_video() already returns a BGR frame, so no channel flip is needed.
        cv2.imshow('Output', img)
        out.write(img)
        key = cv2.waitKey(1)
        if key == ord('q'):
            break
    cap.release()
    out.release()
    cv2.destroyAllWindows()