Commit d8e5d44

Implement Inverse Power calibration for sparse attention
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent e915ff0 commit d8e5d44

11 files changed: 432 additions & 310 deletions
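
The "Inverse Power" in the title is the curve the calibrator now fits: rather than storing a single threshold_scale_factor per phase, calibration produces parameters (k, p) of an inverse power model, scale_factor = k / (1 - sparsity)^p, as the log lines added in calibrate.py below spell out. A minimal sketch of that relation, assuming the parameters are used exactly as the log line states:

    def threshold_scale(k: float, p: float, sparsity: float) -> float:
        """Inverse power model fitted during calibration:
        scale_factor = k / (1 - sparsity)^p.
        Diverges as sparsity -> 1, so sparsity must stay in [0, 1).
        """
        assert 0.0 <= sparsity < 1.0
        return k / (1.0 - sparsity) ** p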

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 31 additions & 37 deletions
@@ -17,12 +17,12 @@
 """Example script for applying sparse attention to HuggingFace models."""

 import argparse
+import copy
 import random
 from pathlib import Path

 import numpy as np
 import torch
-from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

 import modelopt.torch.opt as mto
@@ -46,41 +46,13 @@
 }


-def get_narrativeqa_samples(num_samples=3):
-    """Load samples from NarrativeQA dataset for testing.
-
-    Args:
-        num_samples: Number of samples to generate
-
-    Raises:
-        RuntimeError: If dataset loading fails
-        ValueError: If no valid samples could be loaded
-    """
-    # Load NarrativeQA dataset with retry logic
-    try:
-        dataset = load_dataset("narrativeqa", split="test", streaming=True)
-    except Exception as e:
-        raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}")
-
-    samples = []
-    for i, item in enumerate(dataset):
-        if i >= num_samples:
-            break
-
-        # Combine document context and question
-        context = item.get("document", {}).get("text", "")
-        question = item.get("question", {}).get("text", "")
-
-        if context and question:
-            # Use the full context as-is
-            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
-            samples.append(prompt)
-
-    if not samples:
-        raise ValueError("Could not load NarrativeQA samples")
-
-    print(f"Loaded {len(samples)} NarrativeQA samples")
-    return samples
+def get_test_prompts():
+    """Get simple test prompts for sample output generation."""
+    return [
+        "What is the capital of France? Answer:",
+        "Explain the theory of relativity in simple terms:",
+        "Write a short poem about the ocean:",
+    ]


 def truncate_text(text: str, tokenizer, max_length: int):
@@ -130,7 +102,7 @@ def generate_sample_output(model, tokenizer, args):
         Tuple of (generated_text, input_prompt, input_ids)
     """
     # Load test sample
-    prompts = get_narrativeqa_samples(num_samples=1)
+    prompts = get_test_prompts()
     prompt = prompts[0]

     # Prepare inputs
@@ -198,6 +170,20 @@ def main(args):
     # Apply sparse attention with optional calibration
     print(f"\nApplying sparse attention: {args.sparse_attn}")
     sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
+
+    # Override target_sparse_ratio if provided via CLI
+    if args.target_sparse_ratio is not None:
+        sparse_config = copy.deepcopy(sparse_config)
+        sparse_cfg = sparse_config.get("sparse_cfg", {})
+        if isinstance(sparse_cfg, dict) and "calibration" in sparse_cfg:
+            calibration_cfg = sparse_cfg["calibration"]
+            if isinstance(calibration_cfg, dict):
+                calibration_cfg["target_sparse_ratio"] = {
+                    "prefill": args.target_sparse_ratio,
+                    "decode": args.target_sparse_ratio,
+                }
+                print(f"Overriding target_sparse_ratio to {args.target_sparse_ratio}")
+
 model = mtsa.sparsify(model, config=sparse_config)
 print("Sparse attention applied successfully!")

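For reference, the nesting the override walks (sparse_cfg -> calibration -> target_sparse_ratio) implies a config shaped roughly as below; the key names follow the lookups in the code above, while the values and elided fields are illustrative only:

    # Illustrative shape only; values are hypothetical.
    sparse_config = {
        "sparse_cfg": {
            "calibration": {
                "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5},
                # ... other calibration fields (samples, max_seqlen, ...)
            },
            # ... method-specific settings
        },
    }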
@@ -287,5 +273,13 @@ def main(args):
         help="Directory to export the model with sparse attention applied",
     )

+    # Calibration arguments
+    parser.add_argument(
+        "--target_sparse_ratio",
+        type=float,
+        default=None,
+        help="Target sparsity ratio for calibration (0.0 to 1.0). Overrides config value.",
+    )
+
     args = parser.parse_args()
     main(args)
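
A hypothetical invocation exercising the new flag (only --sparse_attn and --target_sparse_ratio appear in this diff; any model or output flags the script takes are omitted):

    python examples/llm_sparsity/attention_sparsity/hf_sa.py \
        --sparse_attn <choice> \
        --target_sparse_ratio 0.5

Note the single CLI value is applied to both the prefill and decode phases.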

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 103 additions & 22 deletions
@@ -15,8 +15,11 @@

 """Calibration functions for sparse attention."""

+import hashlib
+import json
 import warnings
 from collections.abc import Callable
+from pathlib import Path
 from typing import Any

 import torch
@@ -30,6 +33,54 @@
 from .dataset import RulerDatasetBuilder


+def _get_cache_path(
+    tokenizer_path: str, samples: int, max_seqlen: int, cache_dir: str | None = None
+) -> Path:
+    """Generate cache file path based on calibration parameters.
+
+    Args:
+        tokenizer_path: Path to tokenizer (used in hash)
+        samples: Number of calibration samples
+        max_seqlen: Maximum sequence length
+        cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/sparse_attention/
+    """
+    # Create a hash of the parameters for the cache filename
+    key = f"{tokenizer_path}_{samples}_{max_seqlen}"
+    hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
+    filename = f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json"
+
+    if cache_dir:
+        base_dir = Path(cache_dir)
+    else:
+        base_dir = Path.home() / ".cache" / "modelopt" / "sparse_attention"
+
+    return base_dir / filename
+
+
+def _load_cached_data(cache_path: Path) -> list[dict[str, Any]] | None:
+    """Load calibration data from cache if it exists."""
+    if cache_path.exists():
+        try:
+            with open(cache_path) as f:
+                data = json.load(f)
+            print(f"Loaded {len(data)} cached calibration samples from {cache_path}")
+            return data
+        except Exception as e:
+            print(f"Warning: Failed to load cache: {e}")
+    return None
+
+
+def _save_cached_data(cache_path: Path, data: list[dict[str, Any]]) -> None:
+    """Save calibration data to cache."""
+    try:
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(cache_path, "w") as f:
+            json.dump(data, f)
+        print(f"Saved calibration samples to cache: {cache_path}")
+    except Exception as e:
+        print(f"Warning: Failed to save cache: {e}")
+
+
 def _extract_tokenizer_from_model(model: nn.Module) -> str:
     """Extract tokenizer name/path from model config.

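As a worked example of where the cache lands, mirroring _get_cache_path above (the tokenizer path is hypothetical; the 12-character hash depends on all three parameters):

    import hashlib
    from pathlib import Path

    tokenizer_path, samples, max_seqlen = "meta-llama/Llama-3.1-8B", 32, 8192
    key = f"{tokenizer_path}_{samples}_{max_seqlen}"
    hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
    print(Path.home() / ".cache" / "modelopt" / "sparse_attention"
          / f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json")
    # -> ~/.cache/modelopt/sparse_attention/ruler_cache_32s_8192l_<hash>.json

Because the key omits num_length_bins, changing only that setting reuses a stale cache entry; changing samples, max_seqlen, or the tokenizer path generates a fresh one (max_length_filter is derived from max_seqlen, so it is covered).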
@@ -255,18 +306,31 @@ def calibrate_sparse_attention(
     calibration_data = None

     if calibrate_prefill or calibrate_decode:
-        builder = RulerDatasetBuilder(
-            samples=calib_config.samples,
-            max_seqlen=calib_config.max_seqlen,
-            tokenizer_name_or_path=tokenizer,
-            num_length_bins=calib_config.num_length_bins,
-            max_length_filter=int(calib_config.max_seqlen * 1.5),
+        # Try to load from cache first
+        cache_path = _get_cache_path(
+            tokenizer,
+            calib_config.samples,
+            calib_config.max_seqlen,
+            cache_dir=calib_config.cache_dir,
         )
-        calibration_data = builder.build_calibration_dataset()
-        print(f"Generated {len(calibration_data)} calibration samples")
+        calibration_data = _load_cached_data(cache_path)
+
+        # Generate if not cached
+        if calibration_data is None:
+            builder = RulerDatasetBuilder(
+                samples=calib_config.samples,
+                max_seqlen=calib_config.max_seqlen,
+                tokenizer_name_or_path=tokenizer,
+                num_length_bins=calib_config.num_length_bins,
+                max_length_filter=int(calib_config.max_seqlen * 1.5),
+            )
+            calibration_data = builder.build_calibration_dataset()
+            print(f"Generated {len(calibration_data)} calibration samples")
+
+            # Save to cache for future runs
+            _save_cached_data(cache_path, calibration_data)

     # Initialize results
-    threshold_scale_factor: dict[str, float] = {}
     calibration_results: dict[str, Any] = {}

     # Run prefill calibration if enabled
@@ -282,13 +346,11 @@ def calibrate_sparse_attention(
         )

         prefill_calibrator = DynamicThresholdCalibrator(
-            target_sparse_ratio=target_dict,
             threshold_trials=calib_config.threshold_trials,
         )
         prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill")

-        if "scale_factor" in prefill_result:
-            threshold_scale_factor["prefill"] = prefill_result["scale_factor"]
+        if "k" in prefill_result and "p" in prefill_result:
             calibration_results["prefill"] = prefill_result
         else:
             warnings.warn("Prefill calibration did not produce valid results")
@@ -306,38 +368,57 @@ def calibrate_sparse_attention(
         )

         decode_calibrator = DynamicThresholdCalibrator(
-            target_sparse_ratio=target_dict,
             threshold_trials=calib_config.threshold_trials,
         )
         decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode")

-        if "scale_factor" in decode_result:
-            threshold_scale_factor["decode"] = decode_result["scale_factor"]
+        if "k" in decode_result and "p" in decode_result:
             calibration_results["decode"] = decode_result
         else:
             warnings.warn("Decode calibration did not produce valid results")

     # Check if any calibration succeeded
-    if not threshold_scale_factor:
+    if not calibration_results:
         warnings.warn("No calibration produced valid results")
         return {}

-    # Apply combined threshold_scale_factor dict to all modules
+    # Extract k and p for each phase
+    calibration_params: dict[str, dict[str, float]] = {}
+    for phase in ["prefill", "decode"]:
+        if phase in calibration_results:
+            result = calibration_results[phase]
+            calibration_params[phase] = {
+                "k": result["k"],
+                "p": result["p"],
+            }
+
+    # Apply calibration params to all modules
     print("\n" + "=" * 60)
     print("APPLYING CALIBRATION RESULTS")
     print("=" * 60)
-    print(f"Applying threshold_scale_factor to {len(sparse_modules)} modules:")
-    for phase, scale_factor in threshold_scale_factor.items():
-        print(f"  {phase}: {scale_factor:.6f}")
+    print(f"Applying calibration to {len(sparse_modules)} modules:")
+    for phase, params in calibration_params.items():
+        result = calibration_results[phase]
+        print(f"  {phase}:")
+        print(f"    Model: scale_factor = {params['k']:.4f} / (1 - sparsity)^{params['p']:.4f}")
+        print(f"    R-squared: {result['r_squared']:.6f}")

     for module_name, module in sparse_modules:
-        module._sparse_method_instance.threshold_scale_factor = threshold_scale_factor
+        module._sparse_method_instance.calibration_params = calibration_params
+        module._sparse_method_instance.target_sparse_ratio = target_dict

     # Print final summary
     print("\nCalibration complete!")
+    print(
+        f"Target sparsity: prefill={target_dict.get('prefill', 0):.0%}, "
+        f"decode={target_dict.get('decode', 0):.0%}"
+    )
+    print("\nTo change target sparsity at inference time, update:")
+    print("  module._sparse_method_instance.target_sparse_ratio = {'prefill': X, 'decode': Y}")
     print_sparse_attention_summary(model)

     return {
-        "threshold_scale_factor": threshold_scale_factor,
+        "calibration_params": calibration_params,
+        "target_sparse_ratio": target_dict,
         "calibration_results": calibration_results,
     }