Skip to content

Commit c2ba88a

Browse files
Add device selection argument (#19)
1 parent 3c284df commit c2ba88a

2 files changed

Lines changed: 17 additions & 27 deletions

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ StreamGrid(model=model)
3030
sources = ["video1.mp4", "video2.mp4", "video3.mp4", "video4.mp4"]
3131
StreamGrid(sources=sources, model=model)
3232

33+
# Inference on GPU
34+
StreamGrid(sources=sources, device="cuda")
35+
3336
```
3437

3538
## Performance (Beta, final benchmarks will be released soon)

streamgrid/grid.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,15 @@ class StreamGrid:
3939
video_writer: OpenCV video writer object.
4040
"""
4141

42-
def __init__(self, sources=None, model=None, save=True):
42+
def __init__(self, sources=None, model=None, save=True, device="cpu"):
4343
"""Initialize StreamGrid with video sources and configuration.
4444
4545
Args:
46-
sources (list): List of video sources. Can be:
47-
- File paths (str): "video.mp4", "stream.avi"
48-
- Camera indices (int): 0, 1, 2
49-
- Stream URLs (str): "rtsp://camera_url"
46+
sources (list): List of video sources. Can be, File paths (str): "video.mp4", "stream.avi",
47+
Camera indices (int): 0, 1, 2, Stream URLs (str): "rtsp://camera_url"
5048
model (optional): YOLO model instance for object detection.
51-
If None, only displays video without detection.
52-
save (bool, optional): Whether to save output video. Defaults to True.
53-
Output will be saved as "streamgrid_output_{N}_streams.mp4".
49+
save (bool, optional): Save output video. Output will be saved as "streamgrid_output_{N}_streams.mp4".
50+
device (str, optional): Wheather to run inference on GPU or CPU device.
5451
"""
5552
# GitHub repository URLs for default videos
5653
self.GITHUB_ASSETS_BASE = "https://github.com/RizwanMunawar/streamgrid/releases/download/v1.0.0/"
@@ -66,6 +63,7 @@ def __init__(self, sources=None, model=None, save=True):
6663
sources = self.get_default_videos()
6764

6865
self.sources = sources
66+
self.device = device
6967
self.max_sources = self.batch_size = self.active_streams = len(sources)
7068
self.cols = int(math.ceil(math.sqrt(self.max_sources)))
7169
self.rows = int(math.ceil(self.max_sources / self.cols))
@@ -238,11 +236,9 @@ def capture_video(self, source, source_id):
238236
no_frame_count = 0 # Reset retry counter on successful frame read
239237
frame_count += 1
240238
try:
241-
# Non-blocking queue insertion
242-
self.frame_queue.put((source_id, frame), timeout=0.01)
239+
self.frame_queue.put((source_id, frame), timeout=0.01) # Non-blocking queue insertion
243240
except queue.Full:
244-
# Drop frame if queue is full to prevent memory buildup
245-
pass
241+
pass # Drop frame if queue is full to prevent memory buildup
246242
time.sleep(0.05) # Throttle for CPU efficiency
247243

248244
cap.release()
@@ -257,22 +253,13 @@ def capture_video(self, source, source_id):
257253
with self.lock:
258254
self.active_streams -= 1
259255

260-
def _batch_worker(self):
261-
"""Process frames in batches for efficient YOLO inference.
262-
263-
Collects frames from multiple sources into batches and processes them
264-
together for better GPU/CPU utilization. Calculates and maintains
265-
prediction FPS statistics.
266-
267-
Note:
268-
This method runs in a separate thread and handles all ML inference
269-
and FPS calculation logic.
270-
"""
256+
def process_batch(self):
257+
"""Collects frames from multiple sources into batches and processes them together for better GPU/CPU
258+
utilization. Calculates and maintains prediction FPS statistics."""
271259
batch_frames, batch_ids = [], []
272260

273261
while self.running:
274-
# Collect frames up to batch_size
275-
while len(batch_frames) < self.batch_size:
262+
while len(batch_frames) < self.batch_size: # Collect frames up to batch_size
276263
try:
277264
source_id, frame = self.frame_queue.get(timeout=0.01)
278265
batch_frames.append(frame)
@@ -290,7 +277,7 @@ def _batch_worker(self):
290277
batch_frames,
291278
conf=0.25,
292279
verbose=False,
293-
device='cpu',
280+
device=self.device,
294281
)
295282

296283
# Update each source with its results
@@ -400,7 +387,7 @@ def run(self):
400387
thread.start()
401388

402389
# Start batch processing thread
403-
batch_thread = threading.Thread(target=self._batch_worker, daemon=True)
390+
batch_thread = threading.Thread(target=self.process_batch, daemon=True)
404391
batch_thread.start()
405392

406393
# Initialize display window

0 commit comments

Comments
 (0)