3131
3232
3333class DynamicThresholdCalibrator :
34- """Dynamic threshold calibrator using Inverse Power model.
34+ """Dynamic threshold calibrator using Exponential model.
3535
3636 Calibration Algorithm:
3737 1. For each threshold λ_j in threshold_trials:
3838 - Run ALL samples through forward_loop
3939 - For each sample i with length L_i, collect sparsity S_ij
4040 - Compute scale_factor_ij = λ_j × L_i
4141
42- 2. Fit Inverse Power model to ALL individual (sf_ij, S_ij) pairs:
43- scale_factor = k / (1 - sparsity)^p
42+ 2. Fit Exponential model to ALL individual (sf_ij, S_ij) pairs:
43+ scale_factor = a * exp(b * sparsity)
4444
45- 3. Return fitted k and p parameters (model-specific)
45+ 3. Return fitted a and b parameters
4646
4747 At inference time (user specifies target_sparsity S*):
48- scale_factor = k / (1 - S*)^p
48+ scale_factor = a * exp(b * S*)
4949 threshold = scale_factor / seqlen
5050
5151 Key insight: Using all individual data points (N_thresholds × N_samples)
@@ -88,20 +88,20 @@ def __init__(
8888 ]
8989
9090 def calibrate (self , model : nn .Module , forward_loop : Callable , phase : str ) -> dict [str , Any ]:
91- """Calibrate k and p parameters for Inverse Power model.
91+ """Calibrate a and b parameters for Exponential model.
9292
9393 Algorithm:
9494 1. For each threshold λ_j in threshold_trials:
9595 - Run ALL samples, collect sparsities S_ij for each sample i
9696 - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length)
9797
98- 2. Fit Inverse Power model to ALL (sf_ij, S_ij) pairs:
99- scale_factor = k / (1 - sparsity)^p
98+ 2. Fit Exponential model to ALL (sf_ij, S_ij) pairs:
99+ scale_factor = a * exp(b * sparsity)
100100
101- 3. Return fitted k and p parameters
101+ 3. Return fitted a and b parameters
102102
103103 At inference time (user specifies target_sparsity S*):
104- scale_factor = k / (1 - S*)^p
104+ scale_factor = a * exp(b * S*)
105105 threshold = scale_factor / seqlen
106106
107107 Args:
@@ -110,15 +110,15 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
110110 phase: Phase to calibrate ('prefill' or 'decode')
111111
112112 Returns:
113- Dict with calibration results including k, p , r_squared, and num_data_points
113+ Dict with calibration results including a, b, r_squared, rmse, and num_data_points
114114 """
115115 # Extract attention modules
116116 attention_modules = [m for m in model .modules () if isinstance (m , SparseAttentionModule )]
117117
118118 if not attention_modules :
119119 raise ValueError ("No sparse attention modules found for calibration" )
120120
121- print (f"Starting Inverse Power model calibration ({ phase } phase)" )
121+ print (f"Starting Exponential model calibration ({ phase } phase)" )
122122 print (f"Threshold trials: { len (self .threshold_trials )} " )
123123
124124 # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples
@@ -162,15 +162,16 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
162162
163163 print (f"Collected { len (all_data_points )} individual (scale_factor, sparsity) pairs" )
164164
165- # Stage 2: Fit Inverse Power model: scale_factor = k / (1 - sparsity)^p
166- print ("\n Stage 2: Fitting Inverse Power model to all data points..." )
165+ # Stage 2: Fit Exponential model: scale_factor = a * exp(b * sparsity)
166+ print ("\n Stage 2: Fitting Exponential model to all data points..." )
167167
168168 # Extract data for fitting
169- scale_factors = np .array ([p ["scale_factor" ] for p in all_data_points ])
170- sparsities = np .array ([p ["sparsity" ] for p in all_data_points ])
169+ scale_factors = np .array ([pt ["scale_factor" ] for pt in all_data_points ])
170+ sparsities = np .array ([pt ["sparsity" ] for pt in all_data_points ])
171171
172- # Filter out invalid sparsities (must be in (0, 1))
173- valid_mask = (sparsities > 0.01 ) & (sparsities < 0.99 )
172+ # Filter out extreme sparsities (keep only those in [10%, 90%], inclusive)
173+ # Extreme values are unreliable for fitting
174+ valid_mask = (sparsities >= 0.10 ) & (sparsities <= 0.90 )
174175 scale_factors = scale_factors [valid_mask ]
175176 sparsities = sparsities [valid_mask ]
176177
@@ -180,44 +181,46 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
180181 )
181182 return {}
182183
183- # Define Inverse Power model: sf = k / (1 - S)^p
184- def inverse_power (sparsity , k , p ):
185- return k / np .power ( 1 - sparsity , p )
184+ # Define Exponential model: sf = a * exp(b * S)
185+ def exponential (sparsity , a , b ):
186+ return a * np .exp ( b * sparsity )
186187
187188 # Fit the model
188189 try :
189190 popt , pcov = curve_fit (
190- inverse_power ,
191+ exponential ,
191192 sparsities ,
192193 scale_factors ,
193- p0 = [100 , 1.5 ], # Initial guess
194- bounds = ([0.1 , 0.1 ], [1e7 , 10 ]), # Bounds for k and p
194+ p0 = [1.0 , 5.0 ], # Initial guess
195+ bounds = ([0.0 , 0.0 ], [np . inf , 20.0 ]), # Bounds for a and b
195196 maxfev = 10000 ,
196197 )
197- k , p = popt
198+ a , b = popt
198199 except Exception as e :
199200 warnings .warn (f"Curve fitting failed: { e } " )
200201 return {}
201202
202- # Calculate R-squared
203- pred_scale_factors = inverse_power (sparsities , k , p )
203+ # Calculate R-squared and RMSE
204+ pred_scale_factors = exponential (sparsities , a , b )
204205 ss_res = np .sum ((scale_factors - pred_scale_factors ) ** 2 )
205206 ss_tot = np .sum ((scale_factors - np .mean (scale_factors )) ** 2 )
206207 r_squared = 1 - (ss_res / ss_tot ) if ss_tot > 0 else 0
208+ rmse = np .sqrt (np .mean ((scale_factors - pred_scale_factors ) ** 2 ))
207209
208- print (f"\n { phase .capitalize ()} Calibration Results (Inverse Power Model):" )
209- print (" Model: scale_factor = k / (1 - sparsity)^p " )
210- print (f" Fitted k : { k :.4f } " )
211- print (f" Fitted p : { p :.4f} " )
210+ print (f"\n { phase .capitalize ()} Calibration Results (Exponential Model):" )
211+ print (" Model: scale_factor = a * exp(b * sparsity)" )
212+ print (f" Fitted a : { a :.6f } " )
213+ print (f" Fitted b : { b :.4f} " )
212214 print (f" R-squared: { r_squared :.6f} " )
215+ print (f" RMSE: { rmse :.2f} " )
213216 print (f" Data points used: { int (np .sum (valid_mask ))} / { len (all_data_points )} " )
214217
215218 # Show scale_factor for various target sparsities
216219 print ("\n Scale factors for different target sparsities:" )
217220 print (f" { 'Target' :<10} { 'Scale Factor' :<15} " )
218221 print (f" { '-' * 10 } { '-' * 15 } " )
219222 for target in [0.5 , 0.7 , 0.8 , 0.9 , 0.95 ]:
220- sf = k / ( 1 - target ) ** p
223+ sf = a * np . exp ( b * target )
221224 print (f" { target :<10.0%} { sf :<15.2f} " )
222225
223226 # Print calibration data summary by threshold
@@ -238,12 +241,13 @@ def inverse_power(sparsity, k, p):
238241
239242 return {
240243 "phase" : phase ,
241- "k " : float (k ),
242- "p " : float (p ),
244+ "a " : float (a ),
245+ "b " : float (b ),
243246 "r_squared" : float (r_squared ),
247+ "rmse" : float (rmse ),
244248 "num_data_points" : int (np .sum (valid_mask )),
245249 "total_samples" : len (all_data_points ),
246- "calibration_type" : "inverse_power " ,
250+ "calibration_type" : "exponential " ,
247251 }
248252
249253 def _enable_calibration_mode (self , modules : list [nn .Module ]):
0 commit comments