Commit 4c9e1bd

Switch to exponential model for fitting from inverse power
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent: d8e5d44

9 files changed: 156 additions & 94 deletions

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -350,7 +350,7 @@ def calibrate_sparse_attention(
     )
     prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill")
 
-    if "k" in prefill_result and "p" in prefill_result:
+    if "a" in prefill_result and "b" in prefill_result:
         calibration_results["prefill"] = prefill_result
     else:
         warnings.warn("Prefill calibration did not produce valid results")
@@ -372,7 +372,7 @@ def calibrate_sparse_attention(
     )
     decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode")
 
-    if "k" in decode_result and "p" in decode_result:
+    if "a" in decode_result and "b" in decode_result:
         calibration_results["decode"] = decode_result
     else:
         warnings.warn("Decode calibration did not produce valid results")
@@ -382,14 +382,14 @@ def calibrate_sparse_attention(
         warnings.warn("No calibration produced valid results")
         return {}
 
-    # Extract k and p for each phase
+    # Extract a and b for each phase
     calibration_params: dict[str, dict[str, float]] = {}
     for phase in ["prefill", "decode"]:
         if phase in calibration_results:
             result = calibration_results[phase]
             calibration_params[phase] = {
-                "k": result["k"],
-                "p": result["p"],
+                "a": result["a"],
+                "b": result["b"],
             }
 
     # Apply calibration params to all modules
@@ -400,7 +400,7 @@ def calibrate_sparse_attention(
     for phase, params in calibration_params.items():
         result = calibration_results[phase]
         print(f" {phase}:")
-        print(f" Model: scale_factor = {params['k']:.4f} / (1 - sparsity)^{params['p']:.4f}")
+        print(f" Model: scale_factor = {params['a']:.6f} * exp({params['b']:.4f} * sparsity)")
        print(f" R-squared: {result['r_squared']:.6f}")
 
     for module_name, module in sparse_modules:
```

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py

Lines changed: 39 additions & 35 deletions

```diff
@@ -31,21 +31,21 @@
 
 
 class DynamicThresholdCalibrator:
-    """Dynamic threshold calibrator using Inverse Power model.
+    """Dynamic threshold calibrator using Exponential model.
 
     Calibration Algorithm:
     1. For each threshold λ_j in threshold_trials:
        - Run ALL samples through forward_loop
        - For each sample i with length L_i, collect sparsity S_ij
        - Compute scale_factor_ij = λ_j × L_i
 
-    2. Fit Inverse Power model to ALL individual (sf_ij, S_ij) pairs:
-       scale_factor = k / (1 - sparsity)^p
+    2. Fit Exponential model to ALL individual (sf_ij, S_ij) pairs:
+       scale_factor = a * exp(b * sparsity)
 
-    3. Return fitted k and p parameters (model-specific)
+    3. Return fitted a and b parameters
 
     At inference time (user specifies target_sparsity S*):
-        scale_factor = k / (1 - S*)^p
+        scale_factor = a * exp(b * S*)
         threshold = scale_factor / seqlen
 
     Key insight: Using all individual data points (N_thresholds × N_samples)
@@ -88,20 +88,20 @@ def __init__(
         ]
 
     def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]:
-        """Calibrate k and p parameters for Inverse Power model.
+        """Calibrate a and b parameters for Exponential model.
 
         Algorithm:
         1. For each threshold λ_j in threshold_trials:
            - Run ALL samples, collect sparsities S_ij for each sample i
            - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length)
 
-        2. Fit Inverse Power model to ALL (sf_ij, S_ij) pairs:
-           scale_factor = k / (1 - sparsity)^p
+        2. Fit Exponential model to ALL (sf_ij, S_ij) pairs:
+           scale_factor = a * exp(b * sparsity)
 
-        3. Return fitted k and p parameters
+        3. Return fitted a and b parameters
 
         At inference time (user specifies target_sparsity S*):
-            scale_factor = k / (1 - S*)^p
+            scale_factor = a * exp(b * S*)
             threshold = scale_factor / seqlen
 
         Args:
@@ -110,15 +110,15 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
             phase: Phase to calibrate ('prefill' or 'decode')
 
         Returns:
-            Dict with calibration results including k, p, r_squared, and num_data_points
+            Dict with calibration results including a, b, r_squared, and num_data_points
         """
         # Extract attention modules
         attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]
 
         if not attention_modules:
             raise ValueError("No sparse attention modules found for calibration")
 
-        print(f"Starting Inverse Power model calibration ({phase} phase)")
+        print(f"Starting Exponential model calibration ({phase} phase)")
         print(f"Threshold trials: {len(self.threshold_trials)}")
 
         # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples
@@ -162,15 +162,16 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
 
         print(f"Collected {len(all_data_points)} individual (scale_factor, sparsity) pairs")
 
-        # Stage 2: Fit Inverse Power model: scale_factor = k / (1 - sparsity)^p
-        print("\nStage 2: Fitting Inverse Power model to all data points...")
+        # Stage 2: Fit Exponential model: scale_factor = a * exp(b * sparsity)
+        print("\nStage 2: Fitting Exponential model to all data points...")
 
         # Extract data for fitting
-        scale_factors = np.array([p["scale_factor"] for p in all_data_points])
-        sparsities = np.array([p["sparsity"] for p in all_data_points])
+        scale_factors = np.array([pt["scale_factor"] for pt in all_data_points])
+        sparsities = np.array([pt["sparsity"] for pt in all_data_points])
 
-        # Filter out invalid sparsities (must be in (0, 1))
-        valid_mask = (sparsities > 0.01) & (sparsities < 0.99)
+        # Filter out extreme sparsities (must be in (10%, 90%))
+        # Extreme values are unreliable for fitting
+        valid_mask = (sparsities >= 0.10) & (sparsities <= 0.90)
         scale_factors = scale_factors[valid_mask]
         sparsities = sparsities[valid_mask]
 
@@ -180,44 +181,46 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
             )
             return {}
 
-        # Define Inverse Power model: sf = k / (1 - S)^p
-        def inverse_power(sparsity, k, p):
-            return k / np.power(1 - sparsity, p)
+        # Define Exponential model: sf = a * exp(b * S)
+        def exponential(sparsity, a, b):
+            return a * np.exp(b * sparsity)
 
         # Fit the model
         try:
             popt, pcov = curve_fit(
-                inverse_power,
+                exponential,
                 sparsities,
                 scale_factors,
-                p0=[100, 1.5],  # Initial guess
-                bounds=([0.1, 0.1], [1e7, 10]),  # Bounds for k and p
+                p0=[1.0, 5.0],  # Initial guess
+                bounds=([0.0, 0.0], [np.inf, 20.0]),  # Bounds for a and b
                 maxfev=10000,
             )
-            k, p = popt
+            a, b = popt
         except Exception as e:
             warnings.warn(f"Curve fitting failed: {e}")
             return {}
 
-        # Calculate R-squared
-        pred_scale_factors = inverse_power(sparsities, k, p)
+        # Calculate R-squared and RMSE
+        pred_scale_factors = exponential(sparsities, a, b)
         ss_res = np.sum((scale_factors - pred_scale_factors) ** 2)
         ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2)
         r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
+        rmse = np.sqrt(np.mean((scale_factors - pred_scale_factors) ** 2))
 
-        print(f"\n{phase.capitalize()} Calibration Results (Inverse Power Model):")
-        print(" Model: scale_factor = k / (1 - sparsity)^p")
-        print(f" Fitted k: {k:.4f}")
-        print(f" Fitted p: {p:.4f}")
+        print(f"\n{phase.capitalize()} Calibration Results (Exponential Model):")
+        print(" Model: scale_factor = a * exp(b * sparsity)")
+        print(f" Fitted a: {a:.6f}")
+        print(f" Fitted b: {b:.4f}")
         print(f" R-squared: {r_squared:.6f}")
+        print(f" RMSE: {rmse:.2f}")
         print(f" Data points used: {int(np.sum(valid_mask))} / {len(all_data_points)}")
 
         # Show scale_factor for various target sparsities
         print("\nScale factors for different target sparsities:")
         print(f" {'Target':<10} {'Scale Factor':<15}")
         print(f" {'-' * 10} {'-' * 15}")
         for target in [0.5, 0.7, 0.8, 0.9, 0.95]:
-            sf = k / (1 - target) ** p
+            sf = a * np.exp(b * target)
             print(f" {target:<10.0%} {sf:<15.2f}")
 
         # Print calibration data summary by threshold
@@ -238,12 +241,13 @@ def inverse_power(sparsity, k, p):
 
         return {
             "phase": phase,
-            "k": float(k),
-            "p": float(p),
+            "a": float(a),
+            "b": float(b),
             "r_squared": float(r_squared),
+            "rmse": float(rmse),
             "num_data_points": int(np.sum(valid_mask)),
             "total_samples": len(all_data_points),
-            "calibration_type": "inverse_power",
+            "calibration_type": "exponential",
         }
 
     def _enable_calibration_mode(self, modules: list[nn.Module]):
```
modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -461,8 +461,8 @@ def find_optimal_haystack_size(
     upper_bound = max(estimated_max, incremental * 2)
     optimal_num_haystack = None
 
-    logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack")
-    logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}")
+    logger.debug(f"Estimated {tokens_per_haystack:.1f} tokens per haystack")
+    logger.debug(f"Binary search bounds: {lower_bound} to {upper_bound}")
 
     while lower_bound <= upper_bound:
         mid = (lower_bound + upper_bound) // 2
@@ -486,6 +486,6 @@ def find_optimal_haystack_size(
             upper_bound = mid - 1
 
     final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental
-    logger.info(f"Optimal haystack size: {final_size}")
+    logger.debug(f"Optimal haystack size: {final_size}")
 
     return final_size
```
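
The surrounding function (unchanged by this commit apart from log levels) binary-searches for the largest haystack count whose rendered prompt still fits a token budget. A hedged sketch of that search pattern follows; `token_count_for` and the lower-bound initialization are assumptions standing in for the repository's actual prompt-building and tokenization logic:

```python
def token_count_for(num_haystack: int, tokens_per_haystack: float = 512.0) -> int:
    # Hypothetical stand-in: in the real code this would build the RULER
    # prompt for num_haystack haystacks and count tokens with the tokenizer.
    return int(num_haystack * tokens_per_haystack)


def find_haystack_size(max_tokens: int, tokens_per_haystack: float, incremental: int) -> int:
    estimated_max = int(max_tokens / tokens_per_haystack)
    lower_bound, upper_bound = incremental, max(estimated_max, incremental * 2)
    optimal_num_haystack = None
    while lower_bound <= upper_bound:
        mid = (lower_bound + upper_bound) // 2
        if token_count_for(mid, tokens_per_haystack) <= max_tokens:
            optimal_num_haystack = mid  # fits the budget: remember, try larger
            lower_bound = mid + 1
        else:
            upper_bound = mid - 1  # overflows the budget: try smaller
    # Fall back to the incremental step size if nothing fit
    return optimal_num_haystack if optimal_num_haystack is not None else incremental


print(find_haystack_size(max_tokens=131072, tokens_per_haystack=512.0, incremental=8))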

modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -147,10 +147,10 @@ def validate_threshold(cls, v):
 class CalibrationConfig(ModeloptBaseConfig):
     """Configuration for automatic threshold calibration using RULER dataset.
 
-    Calibration fits an Inverse Power model to determine dynamic thresholds that
-    achieve target sparsity. The model learns parameters k and p per phase:
+    Calibration fits an Exponential model to determine dynamic thresholds that
+    achieve target sparsity. The model learns parameters a and b per phase:
 
-        scale_factor = k / (1 - target_sparsity)^p
+        scale_factor = a * exp(b * target_sparsity)
 
     At inference time, the threshold is computed as:
 
@@ -160,6 +160,7 @@ class CalibrationConfig(ModeloptBaseConfig):
     - Target sparsity can be changed at runtime without recalibration
     - Threshold automatically adapts to sequence length
     - Supports independent prefill and decode phase calibration
+    - Exponential model provides better fit (lower RMSE)
     """
 
     target_sparse_ratio: dict[str, float] = ModeloptField(
```
modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 74 additions & 17 deletions

```diff
@@ -244,28 +244,85 @@ def update_sparse_attention_metadata(
 def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None:
     """Extract sparse attention config for export to config.json.
 
-    Extracts the calibration parameters (k, p) and target_sparse_ratio from the first
-    sparse attention module that has calibrated thresholds.
+    Extracts the calibration parameters (a, b) for the exponential threshold model
+    from the first sparse attention module that has calibrated thresholds.
+
+    The exported config allows computing threshold at runtime:
+        scale_factor = a * exp(b * target_sparsity)
+        threshold = scale_factor / seqlen
 
     Args:
         model: Model with sparse attention applied
 
     Returns:
-        Dictionary with sparse attention config, or None if no calibrated config found.
-        Contains "calibration_params" with k and p per phase, and "target_sparse_ratio".
+        Dictionary with sparse attention config for HuggingFace config.json export.
+        Returns None if no calibrated sparse attention modules found.
+
+        Example output::
+
+            {
+                "config_groups": {
+                    "group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]}
+                },
+                "threshold_scale_factor": {
+                    "formula": "a * exp(b * target_sparsity)",
+                    "prefill": {"a": 7.93, "b": 8.61},
+                    "decode": {"a": 0.12, "b": 9.85},
+                },
+                "producer": {"name": "modelopt", "version": "0.37.0"},
+            }
     """
+    import modelopt
+
+    # Collect sparse attention module info
+    calibration_params = None
+    target_classes: set[str] = set()
+
     for module in model.modules():
         if isinstance(module, SparseAttentionModule):
-            calibration_params = getattr(module._sparse_method_instance, "calibration_params", None)
-            target_sparse_ratio = getattr(
-                module._sparse_method_instance, "target_sparse_ratio", None
-            )
-            if calibration_params is not None:
-                return {
-                    "calibration_params": calibration_params,
-                    "target_sparse_ratio": target_sparse_ratio,
-                }
-    return None
+            # Get the original wrapped module's class name
+            if hasattr(module, "get_original_cls_by_level"):
+                original_cls = module.get_original_cls_by_level(level=0)
+                if original_cls is not None:
+                    target_classes.add(original_cls.__name__)
+
+            # Get calibration params from first module that has them
+            if calibration_params is None:
+                calibration_params = getattr(
+                    module._sparse_method_instance, "calibration_params", None
+                )
+
+    # Return None if no calibration params found
+    if calibration_params is None:
+        return None
+
+    # Build threshold_scale_factor with model parameters
+    threshold_scale_factor: dict[str, Any] = {
+        "formula": "a * exp(b * target_sparsity)",
+    }
+    for phase in ["prefill", "decode"]:
+        if phase in calibration_params:
+            threshold_scale_factor[phase] = {
+                "a": calibration_params[phase]["a"],
+                "b": calibration_params[phase]["b"],
+            }
+
+    # Build the export config
+    export_config: dict[str, Any] = {
+        "config_groups": {
+            "group_0": {
+                "sparse_algo": "softmax_skip",
+                "targets": sorted(target_classes) if target_classes else ["Attention"],
+            }
+        },
+        "threshold_scale_factor": threshold_scale_factor,
+        "producer": {
+            "name": "modelopt",
+            "version": modelopt.__version__,
+        },
+    }
+
+    return export_config
 
 
 def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable):
@@ -332,15 +389,15 @@ def _format_threshold(info: dict) -> str:
     """Format threshold info for display."""
     t = info.get("type")
     if t == "dynamic_calibrated":
-        # Inverse Power model: threshold = k / (1 - sparsity)^p / seqlen
+        # Exponential model: threshold = a * exp(b * sparsity) / seqlen
         params = info.get("calibration_params", {})
         target = info.get("target_sparse_ratio", {})
         parts = []
         for phase in ["prefill", "decode"]:
             if phase in params:
-                k, p = params[phase]["k"], params[phase]["p"]
+                a, b = params[phase]["a"], params[phase]["b"]
                 s = target.get(phase, 0.5)
-                parts.append(f"{phase}: k={k:.1f}, p={p:.2f}, target={s:.0%}")
+                parts.append(f"{phase}: a={a:.4f}, b={b:.2f}, target={s:.0%}")
         return f"calibrated({', '.join(parts)})"
     if t == "static":
         v = info.get("value")
```