
Commit 179e8dd

Address review feedback
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent: 031ae91

24 files changed

Lines changed: 683 additions & 1253 deletions


examples/llm_sparsity/attention_sparsity/README.md

Lines changed: 6 additions & 6 deletions
@@ -1,6 +1,6 @@
 # Attention Sparsity for HuggingFace Models
 
-In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
+In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
 
 ## Getting Started
 
@@ -63,7 +63,7 @@ pip install nvidia-modelopt[hf]
 If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first:
 
 ```bash
-bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh
+bash ./download_ruler_data.sh
 ```
 
 This downloads the Paul Graham essays dataset used for generating calibration samples.
@@ -75,7 +75,7 @@ This downloads the Paul Graham essays dataset used for generating calibration samples.
 Apply sparse attention with a fixed threshold:
 
 ```bash
-python examples/llm_sparsity/attention_sparsity/hf_sa.py \
+python hf_sa.py \
     --pyt_ckpt_path Qwen/Qwen3-8B \
     --sparse_attn skip_softmax
 ```
@@ -85,7 +85,7 @@ python examples/llm_sparsity/attention_sparsity/hf_sa.py \
 Apply sparse attention with calibrated thresholds for optimal sparsity:
 
 ```bash
-python examples/llm_sparsity/attention_sparsity/hf_sa.py \
+python hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib
 ```
@@ -121,7 +121,7 @@ The script automatically compares outputs before and after applying sparse attention
 Export the sparsified model to a HuggingFace checkpoint:
 
 ```bash
-python examples/llm_sparsity/attention_sparsity/hf_sa.py \
+python hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib \
    --export_dir ./exported_sparse_model
@@ -161,5 +161,5 @@ model = mtsa.sparsify(model, config=custom_config)
 
 ## References
 
-- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/)
+- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
 - [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)
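Taken together, the README changes assume the commands are run from `examples/llm_sparsity/attention_sparsity/` rather than the repository root. For readers who prefer the Python API that `hf_sa.py` wraps, here is a minimal sketch built only from names visible in this commit (`mtsa.sparsify` from the hunk context, the `"sparse_cfg"` section from `hf_sa.py`); the import alias and the config values are assumptions, not the repository's canonical example:

```python
# Minimal sketch of the Python-level flow behind hf_sa.py, using only names
# visible in this commit. The module alias and config contents are
# assumptions, not the repository's canonical example.
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa  # alias assumed

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")

# The README's hunk context shows `mtsa.sparsify(model, config=custom_config)`;
# hf_sa.py shows configs carrying a "sparse_cfg" section. The key and value
# below are purely illustrative.
custom_config = {"sparse_cfg": {"threshold": 1e-4}}  # hypothetical schema/values
model = mtsa.sparsify(model, config=custom_config)
```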

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def main(args):
     print(f"\nApplying sparse attention: {args.sparse_attn}")
     sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
 
-    # Override target_sparse_ratio if provided via CLI
+    # Override calibration options if provided via CLI
     if args.target_sparse_ratio is not None:
         sparse_config = copy.deepcopy(sparse_config)
         sparse_cfg = sparse_config.get("sparse_cfg", {})
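The reworded comment reflects what the surrounding code already does: presets from `SPARSE_ATTN_CFG_CHOICES` are shared module-level dicts, so they are deep-copied before any CLI override is applied. A minimal sketch of that pattern (the function name, the stand-in presets, and the `setdefault` variant are mine, not the script's):

```python
# Sketch of the copy-before-mutate override pattern used in hf_sa.py.
# `resolve_config` and the preset contents are illustrative only.
import copy

SPARSE_ATTN_CFG_CHOICES = {  # stand-in presets; the real ones live in the example
    "skip_softmax": {"sparse_cfg": {}},
    "skip_softmax_calib": {"sparse_cfg": {"calibrate": True}},
}


def resolve_config(choice: str, target_sparse_ratio: float | None) -> dict:
    sparse_config = SPARSE_ATTN_CFG_CHOICES[choice]
    if target_sparse_ratio is not None:
        # Deep-copy first so the shared preset dict is never mutated.
        sparse_config = copy.deepcopy(sparse_config)
        sparse_cfg = sparse_config.setdefault("sparse_cfg", {})
        sparse_cfg["target_sparse_ratio"] = target_sparse_ratio
    return sparse_config
```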

modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 from .calibrate import calibrate_sparse_attention
 from .calibrator import DynamicThresholdCalibrator
-from .dataset import RulerDatasetBuilder
+from .ruler_dataset import RulerDatasetBuilder
 
 __all__ = [
     "DynamicThresholdCalibrator",

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 23 additions & 83 deletions
@@ -15,70 +15,29 @@
 
 """Calibration functions for sparse attention."""
 
-import hashlib
-import json
 import warnings
 from collections.abc import Callable
-from pathlib import Path
 from typing import Any
 
 import torch
 import torch.nn as nn
 from transformers import AutoTokenizer
 
+from modelopt.torch.utils import get_module_device
+
 from ..config import CalibrationConfig
 from ..conversion import print_sparse_attention_summary
-from ..sparse_attention import SparseAttentionModule
+from ..utils import get_named_sparse_attention_modules
 from .calibrator import DynamicThresholdCalibrator
-from .dataset import RulerDatasetBuilder
-
+from .ruler_dataset import RulerDatasetBuilder
 
-def _get_cache_path(
-    tokenizer_path: str, samples: int, max_seqlen: int, cache_dir: str | None = None
-) -> Path:
-    """Generate cache file path based on calibration parameters.
 
-    Args:
-        tokenizer_path: Path to tokenizer (used in hash)
-        samples: Number of calibration samples
-        max_seqlen: Maximum sequence length
-        cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/sparse_attention/
-    """
-    # Create a hash of the parameters for the cache filename
-    key = f"{tokenizer_path}_{samples}_{max_seqlen}"
-    hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
-    filename = f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json"
-
-    if cache_dir:
-        base_dir = Path(cache_dir)
-    else:
-        base_dir = Path.home() / ".cache" / "modelopt" / "sparse_attention"
-
-    return base_dir / filename
-
-
-def _load_cached_data(cache_path: Path) -> list[dict[str, Any]] | None:
-    """Load calibration data from cache if it exists."""
-    if cache_path.exists():
-        try:
-            with open(cache_path) as f:
-                data = json.load(f)
-            print(f"Loaded {len(data)} cached calibration samples from {cache_path}")
-            return data
-        except Exception as e:
-            print(f"Warning: Failed to load cache: {e}")
-    return None
-
-
-def _save_cached_data(cache_path: Path, data: list[dict[str, Any]]) -> None:
-    """Save calibration data to cache."""
-    try:
-        cache_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(cache_path, "w") as f:
-            json.dump(data, f)
-        print(f"Saved calibration samples to cache: {cache_path}")
-    except Exception as e:
-        print(f"Warning: Failed to save cache: {e}")
+def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
+    """Load tokenizer and ensure pad_token is set."""
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
 
 
 def _extract_tokenizer_from_model(model: nn.Module) -> str:
@@ -147,12 +106,10 @@ def create_calibration_forward_loop(
     Returns:
         Forward loop function that takes model as argument
     """
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-    if not tokenizer.pad_token:
-        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer = _load_tokenizer(tokenizer_name_or_path)
 
     def forward_loop(model: nn.Module) -> None:
-        device = next(model.parameters()).device
+        device = get_module_device(model)
 
         for sample in calibration_data:
            inputs = tokenizer(
@@ -205,12 +162,10 @@ def create_decode_calibration_forward_loop(
     Returns:
         Forward loop function that takes model as argument
     """
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-    if not tokenizer.pad_token:
-        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer = _load_tokenizer(tokenizer_name_or_path)
 
     def forward_loop(model: nn.Module) -> None:
-        device = next(model.parameters()).device
+        device = get_module_device(model)
 
         for sample in calibration_data:
            inputs = tokenizer(
@@ -291,9 +246,7 @@ def calibrate_sparse_attention(
         return {}
 
     # Get sparse attention modules
-    sparse_modules = [
-        (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule)
-    ]
+    sparse_modules = get_named_sparse_attention_modules(model)
 
     if not sparse_modules:
         print("No sparse attention modules found for calibration")
@@ -306,29 +259,16 @@ def calibrate_sparse_attention(
     calibration_data = None
 
     if calibrate_prefill or calibrate_decode:
-        # Try to load from cache first
-        cache_path = _get_cache_path(
-            tokenizer,
-            calib_config.samples,
-            calib_config.max_seqlen,
+        builder = RulerDatasetBuilder(
+            samples=calib_config.samples,
+            max_seqlen=calib_config.max_seqlen,
+            tokenizer_name_or_path=tokenizer,
+            num_length_bins=calib_config.num_length_bins,
+            max_length_filter=int(calib_config.max_seqlen * 1.5),
             cache_dir=calib_config.cache_dir,
+            data_dir=calib_config.data_dir,
         )
-        calibration_data = _load_cached_data(cache_path)
-
-        # Generate if not cached
-        if calibration_data is None:
-            builder = RulerDatasetBuilder(
-                samples=calib_config.samples,
-                max_seqlen=calib_config.max_seqlen,
-                tokenizer_name_or_path=tokenizer,
-                num_length_bins=calib_config.num_length_bins,
-                max_length_filter=int(calib_config.max_seqlen * 1.5),
-            )
-            calibration_data = builder.build_calibration_dataset()
-            print(f"Generated {len(calibration_data)} calibration samples")
-
-            # Save to cache for future runs
-            _save_cached_data(cache_path, calibration_data)
+        calibration_data = builder.build_calibration_dataset()
 
     # Initialize results
     calibration_results: dict[str, Any] = {}
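One detail worth calling out: `get_module_device` replaces `next(model.parameters()).device`, which raises `StopIteration` on modules that expose no parameters (for example, wrapper or fully offloaded modules). A minimal sketch of what such a helper typically does; the actual implementation in `modelopt.torch.utils` is not shown in this commit and may differ:

```python
# Minimal sketch of a device-lookup helper in the spirit of
# modelopt.torch.utils.get_module_device; the real implementation may differ.
import torch
import torch.nn as nn


def get_module_device_sketch(module: nn.Module) -> torch.device:
    """Best-effort device lookup that tolerates parameter-less modules."""
    # `next(module.parameters()).device` raises StopIteration when a module
    # has no parameters; checking buffers gives a second chance before
    # falling back to CPU.
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    return torch.device("cpu")  # fallback assumption
```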

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@
 from scipy.optimize import curve_fit
 from tqdm import tqdm
 
-from ..sparse_attention import SparseAttentionModule
 from ..stats_manager import SparseAttentionStatsManager
+from ..utils import get_sparse_attention_modules
 
 
 class DynamicThresholdCalibrator:
@@ -113,7 +113,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict:
         Dict with calibration results including a, b, r_squared, and num_data_points
         """
         # Extract attention modules
-        attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]
+        attention_modules = get_sparse_attention_modules(model)
 
         if not attention_modules:
             raise ValueError("No sparse attention modules found for calibration")
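Both call sites now route through two small helpers in the package's `utils` module. Their behavior can be read off the list comprehensions they replace; a plausible sketch of that module's contents (the repository's actual `utils.py` is not shown in this commit):

```python
# Sketch of the shared helpers inferred from the comprehensions this commit
# replaces; intended to live inside the attention_sparsity package, where the
# relative import resolves. The actual code may differ in details.
import torch.nn as nn

from ..sparse_attention import SparseAttentionModule  # as in the old imports


def get_sparse_attention_modules(model: nn.Module) -> list[nn.Module]:
    """All SparseAttentionModule instances in the model."""
    return [m for m in model.modules() if isinstance(m, SparseAttentionModule)]


def get_named_sparse_attention_modules(model: nn.Module) -> list[tuple[str, nn.Module]]:
    """(qualified name, module) pairs for every SparseAttentionModule."""
    return [
        (name, m)
        for name, m in model.named_modules()
        if isinstance(m, SparseAttentionModule)
    ]
```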
