
Commit cd0d185

[OMNIML-2852] [2/n] Add Core Sparse Attention Infrastructure (#527)
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR adds sparse attention support to ModelOpt, applying attention sparsity via the skip-softmax method to enable inference speedups for LLMs.

Key features:

- Skip-softmax support
- Sparse attention config
- Extensible method registry for future sparse attention algorithms
- HuggingFace Transformers integration
- Phase-aware thresholds (separate prefill/decode)

[Design doc](https://docs.google.com/document/d/1OgmTAKkoD4ZSWYXel-FeaQqmI5PtyNhQ4dEuhGiZAQQ/edit?tab=t.0#heading=h.dyp44woziy9x)

## Usage

```python
import torch

import modelopt.torch.sparsity.attention_sparsity as mts
from transformers import AutoModelForCausalLM

# Load model (must use eager attention for softmax patching)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",  # Required!
    torch_dtype=torch.bfloat16,
)

# Use a pre-defined configuration
from modelopt.torch.sparsity.attention_sparsity import SKIP_SOFTMAX_DEFAULT

model = mts.sparsify(model, SKIP_SOFTMAX_DEFAULT)
```

## Testing

### Unit tests

```bash
pytest tests/unit/torch/sparsity/attention_sparsity -v
pytest tests/gpu/torch/sparsity/attention_sparsity -v
pytest tests/examples/llm_sparsity/attention_sparsity -v
```

All passed.

### Accuracy benchmark: MMLU

Model: Qwen/Qwen3-4B

Command:

```bash
python mmlu.py --model_name causal --model_path Qwen/Qwen3-4B --sparse_cfg SKIP_SOFTMAX_DEFAULT
```

|                      | MMLU  |
| -------------------- | ----- |
| BF16                 | 69.96 |
| SKIP_SOFTMAX_DEFAULT | 69.86 |

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---

Signed-off-by: Kai Xu <kaix@nvidia.com>
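For illustration, here is a minimal, hedged sketch of how the backend could be overridden on top of `SKIP_SOFTMAX_DEFAULT`, mirroring the `sparsify_model()` helper in the new example script further down this page. The commented-out `threshold` key is an assumption made for illustration of the phase-aware feature, not an API confirmed by this diff:

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Start from the shipped default and override per-pattern settings,
# exactly as sparsify_model() does in the example script below.
custom_sparse_cfg = {}
for pattern, cfg in SKIP_SOFTMAX_DEFAULT["sparse_cfg"].items():
    cfg = cfg.copy()
    cfg["backend"] = "pytorch"  # the only backend exposed by this PR
    # Hypothetical phase-aware thresholds (key name assumed; see design doc):
    # cfg["threshold"] = {"prefill": 1e-4, "decode": 1e-5}
    custom_sparse_cfg[pattern] = cfg

sparse_config = SparseAttentionConfig(sparse_cfg=custom_sparse_cfg)
model = mtsa.sparsify(model, config=sparse_config)  # model loaded as shown above
```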
1 parent f731379 commit cd0d185

31 files changed: 3005 additions & 9 deletions


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -92,7 +92,7 @@ repos:
           examples/llm_eval/mmlu.py|
           examples/llm_eval/modeling.py|
           examples/llm_qat/main.py|
-          examples/llm_sparsity/finetune.py|
+          examples/llm_sparsity/weight_sparsity/finetune.py|
           examples/speculative_decoding/main.py|
           examples/speculative_decoding/medusa_utils.py|
           examples/speculative_decoding/server_generate.py|
```
Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
```python
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example script for applying sparse attention to HuggingFace models."""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

RAND_SEED = 1234

# Enable HuggingFace checkpointing support
mto.enable_huggingface_checkpointing()

# You can define custom configurations or use the default
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
}


def get_narrativeqa_samples(num_samples=3):
    """Load samples from the NarrativeQA dataset for testing.

    Args:
        num_samples: Number of samples to generate

    Raises:
        RuntimeError: If dataset loading fails
        ValueError: If no valid samples could be loaded
    """
    # Load the NarrativeQA dataset (streaming to avoid a full download)
    try:
        dataset = load_dataset("narrativeqa", split="test", streaming=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}")

    samples = []
    for i, item in enumerate(dataset):
        if i >= num_samples:
            break

        # Combine document context and question
        context = item.get("document", {}).get("text", "")
        question = item.get("question", {}).get("text", "")

        if context and question:
            # Use the full context as-is
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
            samples.append(prompt)

    if not samples:
        raise ValueError("Could not load NarrativeQA samples")

    print(f"Loaded {len(samples)} NarrativeQA samples")
    return samples


def truncate_text(text: str, tokenizer, max_length: int):
    """Truncate text from the middle to preserve the beginning and end.

    Args:
        text: Input text to truncate
        tokenizer: Tokenizer to use for encoding
        max_length: Maximum number of tokens

    Returns:
        Truncated text that fits within max_length tokens
    """
    # First tokenize to see if truncation is needed
    tokens = tokenizer.encode(text, add_special_tokens=True)

    if len(tokens) <= max_length:
        return text

    # Need to truncate - preserve beginning and end
    # Calculate actual special tokens used
    dummy_tokens = tokenizer.encode("", add_special_tokens=True)
    special_token_count = len(dummy_tokens)
    available_tokens = max_length - special_token_count

    # Split tokens roughly in half for beginning and end
    begin_tokens = available_tokens // 2
    end_tokens = available_tokens - begin_tokens

    # Decode beginning and end parts
    begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True)
    end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True)

    # Combine with ellipsis marker
    return begin_text + " [...] " + end_text


def verify_outputs(model, tokenizer, args):
    """Compare outputs between baseline and sparse attention models."""
    # Update seq_len to match calibration max_seqlen if calibration was used
    base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {})
    if "calibration" in base_config and "max_seqlen" in base_config["calibration"]:
        calib_max_seqlen = base_config["calibration"]["max_seqlen"]
        if args.seq_len != calib_max_seqlen:
            print(
                f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} "
                f"to match calibration config"
            )
            args.seq_len = calib_max_seqlen

    # Load and prepare a single test prompt
    print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
    prompts = get_narrativeqa_samples(num_samples=1)
    prompt = prompts[0]

    # Prepare inputs
    truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len)
    display_prompt = (
        truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt
    )

    inputs = tokenizer(
        truncated_prompt,
        return_tensors="pt",
        max_length=args.seq_len,
        truncation=True,
        padding=False,
    )
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    print("\n" + "=" * 60)
    print("BASELINE vs SPARSE ATTENTION COMPARISON")
    print("=" * 60)
    print(f"\nTest prompt: {display_prompt}")
    print(f"Input tokens: {inputs['input_ids'].shape[1]}")

    # Helper function to generate text
    def generate_text(model, inputs, args, tokenizer):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=args.do_sample,
                temperature=args.temperature if args.do_sample else 1.0,
                pad_token_id=tokenizer.pad_token_id,
            )
        input_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][input_length:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Find all sparse attention modules
    sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]

    # Generate baseline by temporarily disabling sparse attention
    print("\n" + "-" * 60)
    print("Generating baseline (sparse attention disabled)...")
    for module in sparse_modules:
        module.disable()
    baseline_text = generate_text(model, inputs, args, tokenizer)

    # Generate with sparse attention enabled
    print("\nGenerating with sparse attention (calibrated thresholds)...")
    for module in sparse_modules:
        module.enable()
    sparse_text = generate_text(model, inputs, args, tokenizer)

    # Display comparison
    print("\n" + "-" * 60)
    print("RESULTS:")
    baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text
    sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text

    print(f"\nBaseline: {baseline_display}")
    print(f"With Sparse: {sparse_display}")

    if baseline_text == sparse_text:
        print("\nOutputs are identical")
    else:
        print("\nOutputs differ")


def sparsify_model(model, args):
    """Apply sparse attention to the model with optional calibration."""
    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

    # Create modified config with selected backend
    modified_sparse_cfg = {}
    for pattern, cfg in base_config["sparse_cfg"].items():
        modified_cfg = cfg.copy()
        modified_cfg["backend"] = args.backend
        modified_sparse_cfg[pattern] = modified_cfg

    # Create new config with modified settings
    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)

    # Sparsify the model
    model = mtsa.sparsify(model, config=sparse_config)

    print("Sparse attention applied successfully!")

    return model


def main(args):
    """Main function to run the selected mode."""
    if not torch.cuda.is_available():
        raise OSError("GPU is required for inference.")

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    launch_memory_monitor()

    print(f"Loading model: {args.pyt_ckpt_path}")

    # Load model and tokenizer
    # Note: attn_implementation="eager" is required for calibration to work properly
    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    # Set pad token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # Apply sparse attention to the model (with calibration if configured)
    model = sparsify_model(model, args)

    # Verify outputs if requested (compares baseline vs calibrated sparse model)
    if args.verify_output:
        verify_outputs(model, tokenizer, args)

    # Export if requested
    if args.export_dir:
        print(f"\nExporting model to: {args.export_dir}")
        export_dir = Path(args.export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)

        with torch.inference_mode():
            export_hf_checkpoint(model, export_dir=export_dir)

        tokenizer.save_pretrained(export_dir)
        print(f"Model exported successfully to: {export_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)

    # Model arguments
    parser.add_argument(
        "--pyt_ckpt_path",
        type=str,
        required=True,
        help="Specify where the PyTorch checkpoint path is",
    )
    parser.add_argument(
        "--sparse_attn",
        type=str,
        default="skip_softmax",
        choices=list(SPARSE_ATTN_CFG_CHOICES.keys()),
        help="Sparse attention configuration to apply.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="pytorch",
        choices=["pytorch"],
        help="Backend for sparse attention (default: pytorch). More backends coming soon.",
    )

    # Sequence length arguments
    parser.add_argument(
        "--seq_len",
        type=int,
        default=2048,
        help="Maximum sequence length for input prompts (will be truncated if longer)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=3,
        help="Number of samples to use from NarrativeQA dataset",
    )

    # Generation arguments
    parser.add_argument(
        "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate"
    )
    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")

    # Operation arguments
    parser.add_argument(
        "--verify_output",
        action="store_true",
        help="Verify that sparse attention outputs match baseline",
    )
    parser.add_argument(
        "--export_dir",
        type=str,
        default=None,
        help="Directory to export the model with sparse attention applied",
    )

    args = parser.parse_args()
    main(args)
```
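For reference, a typical end-to-end invocation of this script might look like the following; the file's path is not shown in this view, so `hf_sparse_attn.py` is a placeholder name, and all flags correspond to the argument parser above:

```bash
# Placeholder script name; substitute the actual path of the new file above.
python hf_sparse_attn.py \
    --pyt_ckpt_path meta-llama/Llama-2-7b-hf \
    --sparse_attn skip_softmax \
    --backend pytorch \
    --seq_len 2048 \
    --max_new_tokens 50 \
    --verify_output
```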
5 files renamed without changes.

examples/llm_sparsity/finetune.py renamed to examples/llm_sparsity/weight_sparsity/finetune.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -1,6 +1,5 @@
-# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py
-
-# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py
+
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
2 files renamed without changes.
