77import numpy as np
88from collections import deque
99from .stream import StreamManager
10- from .plot import StreamAnnotator
10+ from .plotting import StreamPlotter
1111from .utils import LOGGER , get_optimal_grid_size
1212from .analytics import StreamAnalytics
1313
1414
1515class StreamGrid :
1616 """Ultra-fast multi-stream video display with object detection."""
1717
18- def __init__ (
19- self , sources = None , model = None , save = True , device = "cpu" , analytics = False
20- ):
18+ def __init__ (self , sources = None , model = None , save = True , device = "cpu" , analytics = False ):
2119 # Initialize components
2220 self .stream_manager = StreamManager (sources )
2321 self .model = model
@@ -30,13 +28,11 @@ def __init__(
3028 self .rows = int (math .ceil (self .max_sources / self .cols ))
3129 self .cell_w , self .cell_h = get_optimal_grid_size (self .max_sources , self .cols )
3230
33- # Initialize stream plotter
34- self .plotter = StreamAnnotator (self .cell_w , self .cell_h )
31+ # Initialize plotter
32+ self .plotter = StreamPlotter (self .cell_w , self .cell_h , self . max_sources )
3533
3634 # Display state
37- self .grid = np .zeros (
38- (self .rows * self .cell_h , self .cols * self .cell_w , 3 ), dtype = np .uint8
39- )
35+ self .grid = np .zeros ((self .rows * self .cell_h , self .cols * self .cell_w , 3 ), dtype = np .uint8 )
4036 self .frames = {}
4137 self .show_stats = True
4238 self .running = False
@@ -55,17 +51,26 @@ def setup_video_writer(self):
5551 """Setup video writer for saving output."""
5652 fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
5753 return cv2 .VideoWriter (
58- f"streamgrid_output_{ self .max_sources } _streams.mp4" ,
59- fourcc ,
60- 30 ,
61- (self .cols * self .cell_w , self .rows * self .cell_h ),
54+ f"streamgrid_output_{ self .max_sources } _streams.mp4" , fourcc , 30 ,
55+ (self .cols * self .cell_w , self .rows * self .cell_h )
6256 )
6357
6458 def process_batch (self ):
65- """Process frames in batches for better performance."""
59+ """Process frames in batches with consistent timing."""
60+ batch_interval = 0.033 # ~30 FPS processing
61+ last_batch_time = time .time ()
62+
6663 while self .running :
64+ current_time = time .time ()
65+
66+ # Maintain consistent batch processing rate
67+ if current_time - last_batch_time < batch_interval :
68+ time .sleep (0.001 )
69+ continue
70+
6771 frame_data = self .stream_manager .get_frames (self .max_sources )
6872 if not frame_data :
73+ time .sleep (0.001 ) # Small sleep if no frames
6974 continue
7075
7176 batch_start = time .time ()
@@ -74,9 +79,7 @@ def process_batch(self):
7479
7580 # Run inference if model available
7681 if self .model :
77- results = self .model .predict (
78- frames , conf = 0.25 , verbose = False , device = self .device , batch = 16
79- )
82+ results = self .model .predict (frames , conf = 0.25 , verbose = False , device = self .device )
8083 for source_id , frame , result in zip (ids , frames , results ):
8184 self .update_source (source_id , frame , result )
8285 else :
@@ -85,6 +88,7 @@ def process_batch(self):
8588
8689 # Update performance metrics
8790 self .update_fps (len (frames ), time .time () - batch_start )
91+ last_batch_time = current_time
8892
8993 def update_fps (self , frame_count , batch_time ):
9094 """Update FPS calculations."""
@@ -106,14 +110,10 @@ def update_source(self, source_id, frame, results=None):
106110 detections = 0
107111 if results and results .boxes is not None :
108112 detections = len (results .boxes )
109- resized = self .plotter .draw_detections (
110- resized , results , frame .shape [:2 ]
111- )
113+ resized = self .plotter .draw_detections (resized , results , frame .shape [:2 ])
112114
113115 # Add source label
114- resized = self .plotter .draw_source_label (
115- resized , source_id , self .show_stats
116- )
116+ resized = self .plotter .draw_source_label (resized , source_id , self .show_stats )
117117
118118 # Store processed frame
119119 self .frames [source_id ] = resized
@@ -140,10 +140,8 @@ def update_display(self):
140140 # Add FPS overlay
141141 if self .show_stats :
142142 self .grid = self .plotter .draw_fps_overlay (
143- self .grid ,
144- self .prediction_fps ,
145- self .cols * self .cell_w ,
146- self .rows * self .cell_h ,
143+ self .grid , self .prediction_fps ,
144+ self .cols * self .cell_w , self .rows * self .cell_h
147145 )
148146
149147 # Display and save
@@ -168,7 +166,7 @@ def run(self):
168166 key = cv2 .waitKey (1 ) & 0xFF
169167 if key == 27 : # ESC
170168 break
171- elif key == ord ("s" ):
169+ elif key == ord ('s' ):
172170 self .show_stats = not self .show_stats
173171 finally :
174172 self .stop ()
@@ -187,9 +185,7 @@ def stop(self):
187185 self .analytics .summary ()
188186 if self .video_writer :
189187 self .video_writer .release ()
190- LOGGER .info (
191- f"✅ Video saved: streamgrid_output_{ self .max_sources } _streams.mp4"
192- )
188+ LOGGER .info (f"✅ Video saved: streamgrid_output_{ self .max_sources } _streams.mp4" )
193189
194190 with self .lock :
195191 self .frames .clear ()
0 commit comments