yolov8_region_counter.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import argparse
  3. from collections import defaultdict
  4. from pathlib import Path
  5. import cv2
  6. import numpy as np
  7. from shapely.geometry import Polygon
  8. from shapely.geometry.point import Point
  9. from ultralytics import YOLO
  10. from ultralytics.utils.files import increment_path
  11. from ultralytics.utils.plotting import Annotator, colors
  # Recent bounding-box centers keyed by tracker id; run() appends to these lists to draw trails.
  track_history = defaultdict(list)

  # Region dict currently grabbed by the mouse (assigned in mouse_callback); None until a click.
  current_region = None

  # Regions in which detections are counted. Each entry pairs a shapely Polygon with drawing
  # settings; "counts" is incremented per detection and reset after each frame by run(), and
  # "dragging" is toggled by mouse_callback.
  counting_regions = [
      dict(
          name="YOLOv8 Rectangle Region",
          polygon=Polygon([(0, 0), (0, 1280), (720, 1280), (720, 0)]),  # (x, y) corners: tl, bl, br, tr
          counts=0,
          dragging=False,
          region_color=(37, 255, 225),  # outline / label background, BGR
          text_color=(0, 0, 0),  # label text color, BGR
      ),
  ]
  def mouse_callback(event, x, y, flags, param):
      """
      Handle OpenCV mouse events so counting regions can be dragged with the left button.

      Parameters:
          event (int): The mouse event type (e.g., cv2.EVENT_LBUTTONDOWN).
          x (int): The x-coordinate of the mouse pointer.
          y (int): The y-coordinate of the mouse pointer.
          flags (int): Additional flags passed by OpenCV (not used).
          param: Additional parameters passed to the callback (not used).

      Global Variables:
          current_region (dict | None): The region most recently selected for dragging.

      Mouse Events:
          - LBUTTONDOWN: Selects the region containing the clicked point and starts dragging it.
            When several regions contain the point, the last one in ``counting_regions`` (the
            one drawn on top) is selected, and only that region is flagged as dragging.
          - MOUSEMOVE: Translates the selected region by the pointer delta while dragging.
          - LBUTTONUP: Ends dragging for the selected region.

      Notes:
          - This function is intended to be used as a callback for OpenCV mouse events.
          - Requires the existence of the 'counting_regions' list and the 'Polygon' class.

      Example:
          >>> cv2.setMouseCallback(window_name, mouse_callback)
      """
      global current_region

      if event == cv2.EVENT_LBUTTONDOWN:
          click = Point((x, y))
          # Regions are drawn in list order, so later ones appear on top. Searching from the end
          # and stopping at the first hit selects the visible region without leaving any other
          # overlapping region stuck with dragging=True.
          for region in reversed(counting_regions):
              if region["polygon"].contains(click):
                  current_region = region
                  region["dragging"] = True
                  # Last pointer position; MOUSEMOVE uses it to compute the drag delta.
                  region["offset_x"] = x
                  region["offset_y"] = y
                  break

      elif event == cv2.EVENT_MOUSEMOVE:
          if current_region is not None and current_region["dragging"]:
              dx = x - current_region["offset_x"]
              dy = y - current_region["offset_y"]
              # Build a new polygon shifted by (dx, dy) from the current exterior ring.
              current_region["polygon"] = Polygon(
                  [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords]
              )
              current_region["offset_x"] = x
              current_region["offset_y"] = y

      elif event == cv2.EVENT_LBUTTONUP:
          if current_region is not None and current_region["dragging"]:
              current_region["dragging"] = False
  def _resolve_device(device):
      """
      Translate the ``device`` option into a device string for ``model.to``.

      "" / None / "cpu" -> "cpu"; a CUDA index such as "0" or "1" -> "cuda:<index>". For a
      comma-separated list such as "0,1,2,3" only the first index is used, since a single model
      is moved to one device. Any other value (e.g. "mps") is passed through unchanged.
      """
      dev = "" if device is None else str(device).strip()
      if not dev or dev.lower() == "cpu":
          return "cpu"
      first = dev.split(",")[0].strip()
      return f"cuda:{first}" if first.isdigit() else dev


  def _annotate_tracks(frame, boxes, names, line_thickness, track_thickness):
      """Draw labelled boxes and track trails on ``frame`` and count box centers inside each region."""
      xyxy = boxes.xyxy.cpu()
      track_ids = boxes.id.int().cpu().tolist()
      class_ids = boxes.cls.cpu().tolist()
      annotator = Annotator(frame, line_width=line_thickness, example=str(names))
      for box, track_id, cls in zip(xyxy, track_ids, class_ids):
          cls_id = int(cls)  # .tolist() yields floats; names/colors are keyed by int class id
          color = colors(cls_id, True)
          annotator.box_label(box, str(names[cls_id]), color=color)

          # Bounding-box center, converted once to plain floats.
          center_x = float((box[0] + box[2]) / 2)
          center_y = float((box[1] + box[3]) / 2)

          # Trail of the last 30 centers for this track id.
          track = track_history[track_id]
          track.append((center_x, center_y))
          if len(track) > 30:
              track.pop(0)
          points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
          cv2.polylines(frame, [points], isClosed=False, color=color, thickness=track_thickness)

          # A detection counts toward every region that contains its center.
          center = Point((center_x, center_y))
          for region in counting_regions:
              if region["polygon"].contains(center):
                  region["counts"] += 1


  def _draw_region(frame, region, line_thickness, region_thickness):
      """Draw a region's outline plus a filled label box showing its count at the polygon centroid."""
      label = str(region["counts"])
      polygon = region["polygon"]
      polygon_coords = np.array(polygon.exterior.coords, dtype=np.int32)
      centroid_x, centroid_y = int(polygon.centroid.x), int(polygon.centroid.y)
      (text_w, text_h), _ = cv2.getTextSize(
          label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness
      )
      # Center the label text on the centroid.
      text_x = centroid_x - text_w // 2
      text_y = centroid_y + text_h // 2
      cv2.rectangle(
          frame,
          (text_x - 5, text_y - text_h - 5),
          (text_x + text_w + 5, text_y + 5),
          region["region_color"],
          -1,
      )
      cv2.putText(
          frame, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region["text_color"], line_thickness
      )
      cv2.polylines(frame, [polygon_coords], isClosed=True, color=region["region_color"], thickness=region_thickness)


  def run(
      weights="yolov8n.pt",
      source=None,
      device="cpu",
      view_img=False,
      save_img=False,
      exist_ok=False,
      classes=None,
      line_thickness=2,
      track_thickness=2,
      region_thickness=2,
  ):
      """
      Run region counting on a video using YOLOv8 and ByteTrack.

      Supports movable regions for real-time counting inside specific areas, multiple regions,
      and polygon or rectangle shaped regions.

      Args:
          weights (str): Model weights path.
          source (str): Video file path.
          device (str): Processing device: "cpu", or a CUDA index such as "0" or "1"
              (for a list like "0,1" the first index is used).
          view_img (bool): Show results in a window; regions can be dragged with the mouse.
          save_img (bool): Save the annotated video under ``ultralytics_rc_output/exp*``.
          exist_ok (bool): Reuse an existing output directory instead of incrementing its name.
          classes (list[int] | None): Class ids to detect and track; None keeps all classes.
          line_thickness (int): Bounding box and region label thickness.
          track_thickness (int): Tracking line thickness.
          region_thickness (int): Region outline thickness.

      Raises:
          FileNotFoundError: If ``source`` is None or does not exist.
      """
      if source is None or not Path(source).exists():
          raise FileNotFoundError(f"Source path '{source}' does not exist.")

      # Setup model
      model = YOLO(f"{weights}")
      model.to(_resolve_device(device))
      names = model.model.names

      # Video setup
      videocapture = cv2.VideoCapture(source)
      frame_width = int(videocapture.get(cv2.CAP_PROP_FRAME_WIDTH))
      frame_height = int(videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
      fps = int(videocapture.get(cv2.CAP_PROP_FPS))

      # Output setup: only create the directory and writer when saving was requested.
      video_writer = None
      if save_img:
          save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok)
          save_dir.mkdir(parents=True, exist_ok=True)
          video_writer = cv2.VideoWriter(
              str(save_dir / f"{Path(source).stem}.mp4"),
              cv2.VideoWriter_fourcc(*"mp4v"),
              fps,
              (frame_width, frame_height),
          )

      window_name = "Ultralytics YOLOv8 Region Counter Movable"
      vid_frame_count = 0
      try:
          while videocapture.isOpened():
              success, frame = videocapture.read()
              if not success:
                  break
              vid_frame_count += 1

              results = model.track(frame, persist=True, classes=classes)
              if results[0].boxes.id is not None:
                  _annotate_tracks(frame, results[0].boxes, names, line_thickness, track_thickness)

              for region in counting_regions:
                  _draw_region(frame, region, line_thickness, region_thickness)

              if view_img:
                  if vid_frame_count == 1:
                      # Create the window once so the mouse callback can be attached to it.
                      cv2.namedWindow(window_name)
                      cv2.setMouseCallback(window_name, mouse_callback)
                  cv2.imshow(window_name, frame)
              if video_writer is not None:
                  video_writer.write(frame)

              # Counts are per-frame; reset them before the next frame.
              for region in counting_regions:
                  region["counts"] = 0

              if cv2.waitKey(1) & 0xFF == ord("q"):
                  break
      finally:
          # Release resources even if tracking or drawing raises.
          if video_writer is not None:
              video_writer.release()
          videocapture.release()
          cv2.destroyAllWindows()
  def parse_opt():
      """Build the region-counter command-line interface and return the parsed namespace."""
      # (flag, add_argument keyword arguments), registered in this order so --help is unchanged.
      cli_spec = (
          ("--weights", dict(type=str, default="yolov8n.pt", help="initial weights path")),
          ("--device", dict(default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")),
          ("--source", dict(type=str, required=True, help="video file path")),
          ("--view-img", dict(action="store_true", help="show results")),
          ("--save-img", dict(action="store_true", help="save results")),
          ("--exist-ok", dict(action="store_true", help="existing project/name ok, do not increment")),
          ("--classes", dict(nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")),
          ("--line-thickness", dict(type=int, default=2, help="bounding box thickness")),
          ("--track-thickness", dict(type=int, default=2, help="Tracking line thickness")),
          ("--region-thickness", dict(type=int, default=4, help="Region thickness")),
      )
      cli = argparse.ArgumentParser()
      for flag, options in cli_spec:
          cli.add_argument(flag, **options)
      return cli.parse_args()
  def main(opt):
      """Script entry point: forward every parsed CLI option to ``run`` as a keyword argument."""
      options = vars(opt)
      run(**options)
  if __name__ == "__main__":
      # Parse the command line (keeping it bound as module-level ``opt``) and start counting.
      main(opt := parse_opt())