@@ -39,18 +39,15 @@ class StreamGrid:
3939 video_writer: OpenCV video writer object.
4040 """
4141
42- def __init__ (self , sources = None , model = None , save = True ):
42+ def __init__ (self , sources = None , model = None , save = True , device = "cpu" ):
4343 """Initialize StreamGrid with video sources and configuration.
4444
4545 Args:
46- sources (list): List of video sources. Can be:
47- - File paths (str): "video.mp4", "stream.avi"
48- - Camera indices (int): 0, 1, 2
49- - Stream URLs (str): "rtsp://camera_url"
46+ sources (list): List of video sources. Can be, File paths (str): "video.mp4", "stream.avi",
47+ Camera indices (int): 0, 1, 2, Stream URLs (str): "rtsp://camera_url"
5048 model (optional): YOLO model instance for object detection.
51- If None, only displays video without detection.
52- save (bool, optional): Whether to save output video. Defaults to True.
53- Output will be saved as "streamgrid_output_{N}_streams.mp4".
49+ save (bool, optional): Save output video. Output will be saved as "streamgrid_output_{N}_streams.mp4".
50+ device (str, optional): Wheather to run inference on GPU or CPU device.
5451 """
5552 # GitHub repository URLs for default videos
5653 self .GITHUB_ASSETS_BASE = "https://github.com/RizwanMunawar/streamgrid/releases/download/v1.0.0/"
@@ -66,6 +63,7 @@ def __init__(self, sources=None, model=None, save=True):
6663 sources = self .get_default_videos ()
6764
6865 self .sources = sources
66+ self .device = device
6967 self .max_sources = self .batch_size = self .active_streams = len (sources )
7068 self .cols = int (math .ceil (math .sqrt (self .max_sources )))
7169 self .rows = int (math .ceil (self .max_sources / self .cols ))
@@ -238,11 +236,9 @@ def capture_video(self, source, source_id):
238236 no_frame_count = 0 # Reset retry counter on successful frame read
239237 frame_count += 1
240238 try :
241- # Non-blocking queue insertion
242- self .frame_queue .put ((source_id , frame ), timeout = 0.01 )
239+ self .frame_queue .put ((source_id , frame ), timeout = 0.01 ) # Non-blocking queue insertion
243240 except queue .Full :
244- # Drop frame if queue is full to prevent memory buildup
245- pass
241+ pass # Drop frame if queue is full to prevent memory buildup
246242 time .sleep (0.05 ) # Throttle for CPU efficiency
247243
248244 cap .release ()
@@ -257,22 +253,13 @@ def capture_video(self, source, source_id):
257253 with self .lock :
258254 self .active_streams -= 1
259255
260- def _batch_worker (self ):
261- """Process frames in batches for efficient YOLO inference.
262-
263- Collects frames from multiple sources into batches and processes them
264- together for better GPU/CPU utilization. Calculates and maintains
265- prediction FPS statistics.
266-
267- Note:
268- This method runs in a separate thread and handles all ML inference
269- and FPS calculation logic.
270- """
256+ def process_batch (self ):
257+ """Collects frames from multiple sources into batches and processes them together for better GPU/CPU
258+ utilization. Calculates and maintains prediction FPS statistics."""
271259 batch_frames , batch_ids = [], []
272260
273261 while self .running :
274- # Collect frames up to batch_size
275- while len (batch_frames ) < self .batch_size :
262+ while len (batch_frames ) < self .batch_size : # Collect frames up to batch_size
276263 try :
277264 source_id , frame = self .frame_queue .get (timeout = 0.01 )
278265 batch_frames .append (frame )
@@ -290,7 +277,7 @@ def _batch_worker(self):
290277 batch_frames ,
291278 conf = 0.25 ,
292279 verbose = False ,
293- device = 'cpu' ,
280+ device = self . device ,
294281 )
295282
296283 # Update each source with its results
@@ -400,7 +387,7 @@ def run(self):
400387 thread .start ()
401388
402389 # Start batch processing thread
403- batch_thread = threading .Thread (target = self ._batch_worker , daemon = True )
390+ batch_thread = threading .Thread (target = self .process_batch , daemon = True )
404391 batch_thread .start ()
405392
406393 # Initialize display window
0 commit comments