
Commit 1ec931c

[2/3][Feat]: Offline DFlash training (#1343)
### What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:

- **[1/3] #1296**: File reorg + deprecate `ParallelDraft`
- **[2/3] this PR**: Offline DFlash training (depends on #1296)
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:

- Add `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; deletes base model layers to save memory.
- Add Pydantic validators on `DFlashConfig`:
  - `_derive_dflash_offline` — auto-derive `dflash_offline` from `data_args.offline_data_path` in validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
  - `_resolve_mask_token_id` — auto-detect `dflash_mask_token_id` from `tokenizer.mask_token_id`.
  - `_check_mask_token_id` — fail fast if unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when offline; pick `_base_model_lm_head` device when no base layers present; drop base-model `layers` module.
- `HFDFlashModel.forward()`: add offline branch — consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`, and when `dflash_self_logit_distillation` is enabled with `base_model_logits` absent, recomputes logits from `base_model_hidden_states` via `_base_model_lm_head`. Raises a clear error from the non-training / `pseudo_speculative_generate` paths when `dflash_offline=True`, since base-model layers have been deleted.
- `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with `from_offline_dict` classmethod) to unify online/offline output shapes. `aux_hidden_states` is required in `from_offline_dict` so missing keys fail fast at the entry point rather than deeper in the forward.
- `examples/speculative_decoding/main.py`: replace inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`.

### Silent bug fix — `add_generation_template` → `add_generation_prompt`

The pre-refactor `compute_hidden_states_hf.py` passed `add_generation_template=False` to `tokenizer.apply_chat_template`. This kwarg does not exist on HF `apply_chat_template` and was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new `tokenize_with_loss_mask` helper in `examples/speculative_decoding/collect_hidden_states/common.py` uses the correct `add_generation_prompt=False`.

**This is a real behavior change** for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.

### Testing

- New tests:
  - `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` — CPU unit tests for the convert path (online keeps base layers, offline deletes them; `num_orig_hidden_layers` drives `target_layer_ids` in offline mode) and the `DFlashConfig._derive_dflash_offline` validator.
  - `TestDFlashOfflineForwardGPU` in `tests/gpu/torch/speculative/plugins/test_hf_dflash.py` — GPU forward smoke test with precomputed `base_model_outputs`, plus the `dflash_self_logit_distillation` logit-recompute path.
- Training test:

  <img width="454" height="317" alt="image" src="https://github.com/user-attachments/assets/79b92790-4d15-4313-bb9b-f35665b012e6" />
  <img width="456" height="310" alt="image" src="https://github.com/user-attachments/assets/4558559f-9c35-49ed-b36e-82fbc99eab23" />

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ — additive `dflash_offline` flag defaulting to `False`; validators fall through when no context is provided.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: ✅ — see the Testing section above.
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅

### TODO (follow-up)

- [x] Update `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py` to support DFlash offline data. Current scripts are Eagle-specific — they hardcode the `[2, N/2, N-3]` aux-layer selection and emit `{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
  - Aux layer indices driven by `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a configurable list), not the Eagle triplet.
  - A `base_model_hidden_states` key (last-layer hidden states) so `DFlashBaseModelOutput.from_offline_dict` and the `dflash_self_logit_distillation` recompute path can consume it.
  - An optional `base_model_logits` dump so offline training can skip the self-distillation logit recomputation when logits are available.

### Additional Information

Base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Offline DFlash speculative-decoding training from precomputed base-model hidden states
  * Answer-only-loss training with persisted loss masks and optional chat-template support
  * Flexible auxiliary-layer selection via CLI and an exposed default aux-layer helper
  * Auto-derived offline flag in config and automatic memory optimization during offline conversion
* **Documentation**
  * Updated guides for the offline pipeline, aux-layer selection, and loss-masking options
* **Tests**
  * New unit, GPU, and regression tests covering offline conversion, training, and config derivation

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
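For orientation (illustrative, not part of the diff): a minimal sketch of the new validation entry point described above. The tokenizer name, data path, and the explicit mask-token value are placeholders; in the example training script this call happens inside `train()` in `examples/speculative_decoding/main.py`.

```python
# Hypothetical standalone use of the new DFlashConfig validators.
import argparse

from transformers import AutoTokenizer

from modelopt.torch.speculative.config import DFlashConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder model
data_args = argparse.Namespace(offline_data_path="/data/dumped_hidden_states")  # placeholder path

# dflash_mask_token_id is auto-detected from tokenizer.mask_token_id when omitted;
# a placeholder value is set here only so the sketch also works with tokenizers
# that define no mask token. dflash_offline is derived from
# data_args.offline_data_path and overrides any user-supplied value.
dflash_cfg = {"dflash_mask_token_id": 151643}

dflash_cfg = DFlashConfig.model_validate(
    dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
).model_dump()
print(dflash_cfg["dflash_offline"], dflash_cfg["dflash_mask_token_id"])
```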
1 parent 7c80d85 commit 1ec931c

15 files changed

Lines changed: 704 additions & 78 deletions


CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@ Changelog
 **New Features**

 - Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.
+- Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.

 0.44 (2026-05-xx)
 ^^^^^^^^^^^^^^^^^
@@ -34,7 +35,7 @@ Changelog
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
-- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.kernels.quantization.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
+- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.

 **Backward Breaking Changes**

examples/speculative_decoding/collect_hidden_states/common.py

Lines changed: 171 additions & 0 deletions

@@ -0,0 +1,171 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared helpers for ``compute_hidden_states_*`` dump scripts.

Groups two concerns used by both the HF and vLLM dump entry points:

- Aux-layer selection via the ``--aux-layers`` flag (``"eagle"`` / ``"dflash"``
  / explicit comma-separated list). Returned values are **0-based transformer
  layer IDs**; callers indexing into HuggingFace's ``outputs.hidden_states``
  tuple must add ``+1`` because ``hidden_states[0]`` is the embedding output.
- Answer-only-loss support: registering ``--answer-only-loss`` /
  ``--chat-template`` flags, loading a chat template file, verifying the
  template contains ``{% generation %}`` tags, and computing per-conversation
  ``loss_mask`` from the tokenizer's ``assistant_masks``.
"""

import argparse
from pathlib import Path

import torch

_DFLASH_DEFAULT_NUM_DRAFT_LAYERS = 5


def add_aux_layers_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--aux-layers`` flag on ``parser``."""
    parser.add_argument(
        "--aux-layers",
        type=str,
        default="eagle",
        help=(
            "Aux layer indices to capture. One of: "
            "'eagle' (EAGLE-3 default from modelopt), "
            f"'dflash' ({_DFLASH_DEFAULT_NUM_DRAFT_LAYERS}-layer DFlash default from modelopt), "
            "or a comma-separated list like '2,5,8' to override. Default: eagle."
        ),
    )


def resolve_aux_layers(args: argparse.Namespace, num_hidden_layers: int) -> list[int]:
    """Resolve ``args.aux_layers`` to a sorted, de-duped list of 0-based layer IDs."""
    value = args.aux_layers.strip().lower()
    if value == "eagle":
        from modelopt.torch.speculative.plugins.hf_eagle import default_eagle_aux_layer_ids

        return default_eagle_aux_layer_ids(num_hidden_layers)
    if value == "dflash":
        from modelopt.torch.speculative.plugins.modeling_dflash import build_target_layer_ids

        return sorted(
            set(build_target_layer_ids(num_hidden_layers, _DFLASH_DEFAULT_NUM_DRAFT_LAYERS))
        )
    try:
        indices = [int(tok) for tok in args.aux_layers.split(",") if tok.strip()]
    except ValueError as e:
        raise ValueError(
            f"--aux-layers must be 'eagle', 'dflash', or a comma-separated int list, "
            f"got: {args.aux_layers!r}"
        ) from e
    if not indices:
        raise ValueError(f"--aux-layers int list is empty: {args.aux_layers!r}")
    for i in indices:
        if not 0 <= i < num_hidden_layers:
            raise ValueError(f"--aux-layers index {i} out of range [0, {num_hidden_layers})")
    return sorted(set(indices))


def add_answer_only_loss_args(parser: argparse.ArgumentParser) -> None:
    """Register ``--answer-only-loss`` and ``--chat-template`` flags on ``parser``."""
    parser.add_argument(
        "--answer-only-loss",
        action="store_true",
        help=(
            "If set, compute an assistant-token mask via the tokenizer's "
            "{% generation %} tags and save it as 'loss_mask' in each .pt file. "
            "Downstream offline training uses this to apply loss only on "
            "assistant-produced tokens."
        ),
    )
    parser.add_argument(
        "--chat-template",
        type=Path,
        default=None,
        help=(
            "Path to a Jinja chat template file that overrides tokenizer.chat_template. "
            "Required with --answer-only-loss if the model's default template lacks "
            "{% generation %} / {% endgeneration %} tags."
        ),
    )


def load_chat_template(path: Path | None) -> str | None:
    """Read a Jinja chat template from ``path``, or return ``None`` if not provided."""
    if path is None:
        return None
    with open(path) as f:
        return f.read()


def verify_generation_tags(chat_template: str | None) -> None:
    """Raise if ``chat_template`` lacks ``{% generation %}`` / ``{% endgeneration %}`` tags.

    These tags are required for ``apply_chat_template(..., return_assistant_tokens_mask=True)``
    to return the assistant-token mask needed for answer-only-loss training.
    """
    if chat_template and "generation" in chat_template and "endgeneration" in chat_template:
        return
    raise ValueError(
        "--answer-only-loss requires {% generation %} / {% endgeneration %} tags in the "
        "chat template, but the current template does not have them.\n\n"
        "To fix, pass --chat-template pointing to a template with generation tags:\n"
        " 1. Copy the model's chat_template from tokenizer_config.json\n"
        " 2. Wrap assistant content with {% generation %} / {% endgeneration %}\n"
        "See https://huggingface.co/docs/transformers/en/chat_templating"
        "#train-on-completions-only for details."
    )


def tokenize_with_loss_mask(
    tokenizer,
    conversations: list,
    answer_only_loss: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Tokenize one conversation and derive its loss mask from the same call.

    Uses a single ``apply_chat_template`` invocation so ``input_ids`` and
    ``loss_mask`` are guaranteed to come from the same tokenization — this
    eliminates the risk of argument drift between two separate calls.

    Returns:
        input_ids: ``LongTensor`` of shape ``(1, seq_len)``.
        loss_mask: ``LongTensor`` of shape ``(seq_len,)``. All-ones when
            ``answer_only_loss=False``; the assistant-token mask from the
            tokenizer when ``answer_only_loss=True`` (requires ``{% generation %}``
            tags in the chat template — verify beforehand).
    """
    out = tokenizer.apply_chat_template(
        conversations,
        return_tensors="pt",
        return_dict=True,
        return_assistant_tokens_mask=answer_only_loss,
        add_generation_prompt=False,
    )
    input_ids = out["input_ids"]
    seq_len = input_ids.shape[-1]
    if answer_only_loss:
        mask = out["assistant_masks"]
        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.long)
        loss_mask = mask.squeeze(0).to(torch.long)
        if loss_mask.shape[0] != seq_len:
            raise RuntimeError(
                f"assistant_masks length {loss_mask.shape[0]} does not match "
                f"input_ids length {seq_len}"
            )
    else:
        loss_mask = torch.ones(seq_len, dtype=torch.long)
    return input_ids, loss_mask
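As a usage illustration only (not part of the diff): a minimal sketch of wiring these helpers into a dump script, assuming it runs from `examples/speculative_decoding/collect_hidden_states/` with this PR's `modelopt` installed. The tokenizer name, layer count, and conversation are placeholders; the real wiring is in `compute_hidden_states_hf.py` below.

```python
# Hypothetical wiring of the helpers above; tokenizer, layer count, and the
# conversation are placeholders. Run from collect_hidden_states/ so `common`
# is importable.
import argparse

from common import (
    add_answer_only_loss_args,
    add_aux_layers_args,
    load_chat_template,
    resolve_aux_layers,
    tokenize_with_loss_mask,
    verify_generation_tags,
)
from transformers import AutoTokenizer

parser = argparse.ArgumentParser()
add_aux_layers_args(parser)
add_answer_only_loss_args(parser)
args = parser.parse_args(["--aux-layers", "dflash"])

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
override_template = load_chat_template(args.chat_template)
if override_template is not None:
    tokenizer.chat_template = override_template
if args.answer_only_loss:
    verify_generation_tags(tokenizer.chat_template)

# 0-based transformer layer IDs; add +1 when indexing outputs.hidden_states.
selected_layer_ids = resolve_aux_layers(args, num_hidden_layers=28)

conversation = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
input_ids, loss_mask = tokenize_with_loss_mask(tokenizer, conversation, args.answer_only_loss)
print(selected_layer_ids, tuple(input_ids.shape), tuple(loss_mask.shape))
```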

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 33 additions & 24 deletions
@@ -20,6 +20,14 @@
 from pathlib import Path

 import torch
+from common import (
+    add_answer_only_loss_args,
+    add_aux_layers_args,
+    load_chat_template,
+    resolve_aux_layers,
+    tokenize_with_loss_mask,
+    verify_generation_tags,
+)
 from datasets import load_dataset
 from tqdm import tqdm as tqdm
 from transformers import AutoModel, AutoTokenizer
@@ -90,6 +98,8 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Set trust_remote_code for Huggingface models and tokenizers",
     )
+    add_aux_layers_args(parser)
+    add_answer_only_loss_args(parser)

     return parser.parse_args()

@@ -138,12 +148,20 @@ def keep_conversation(entry):
         args.model, dtype="auto", device_map="auto", trust_remote_code=args.trust_remote_code
     )
     num_hidden_layers = getattr(model.config, "num_hidden_layers", None)
+    if num_hidden_layers is None:
+        raise ValueError(f"model.config has no 'num_hidden_layers' attribute: {model.config}")
+    selected_layer_ids = resolve_aux_layers(args, num_hidden_layers)

     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+    override_template = load_chat_template(args.chat_template)
+    if override_template is not None:
+        tokenizer.chat_template = override_template
     if tokenizer.chat_template is not None:
         tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if args.answer_only_loss:
+        verify_generation_tags(tokenizer.chat_template)

     output_dir = args.output_dir
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -152,29 +170,21 @@ def keep_conversation(entry):
     num_success = 0
     pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")

-    async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Tensor):
+    async def dump_hidden_states(
+        idx: int,
+        conversation_id: int,
+        input_ids: torch.Tensor,
+        loss_mask: torch.Tensor,
+    ):
         nonlocal num_success
-        nonlocal num_hidden_layers

         # Get hidden states
         with torch.inference_mode():
             outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
-            if num_hidden_layers is None:
-                num_hidden_layers = len(outputs.hidden_states) - 1
-            else:
-                assert num_hidden_layers + 1 == len(outputs.hidden_states), (
-                    f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}."
-                )
-        # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
+        # outputs.hidden_states[0] is the embedding output; layer k output is at index k+1.
         hidden_states = outputs.hidden_states
-        selected_layer_indices = [
-            2,
-            max(0, num_hidden_layers // 2),
-            max(1, num_hidden_layers - 3),
-        ]
-        selected_layer_indices = sorted(set(selected_layer_indices))
         aux_hidden_states = torch.cat(
-            [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
+            [hidden_states[lid + 1].squeeze(0).cpu() for lid in selected_layer_ids], dim=-1
         )
         output_hidden_states = hidden_states[-1].squeeze(0).cpu()
         output_file = output_dir / f"{conversation_id}.pt"
@@ -185,6 +195,7 @@ async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Te
                     "input_ids": input_ids.squeeze(0).cpu(),
                     "hidden_states": output_hidden_states,
                     "aux_hidden_states": aux_hidden_states,
+                    "loss_mask": loss_mask,
                     "conversation_id": conversation_id,
                 },
                 f,
@@ -206,19 +217,17 @@ async def submit_generates():
                 num_invalid += 1
                 continue

-            # Tokenize and check length
-            # return_dict=True ensures BatchEncoding is returned on all transformers
-            # versions: in <5.0 the default is False (returns raw tensor), in 5.0+
-            # the default changed to True (returns BatchEncoding).
-            input_ids = tokenizer.apply_chat_template(
-                conversations, return_tensors="pt", return_dict=True, add_generation_template=False
-            )["input_ids"]
+            # Single apply_chat_template call produces both input_ids and loss_mask,
+            # guaranteeing they come from the same tokenization.
+            input_ids, loss_mask = tokenize_with_loss_mask(
+                tokenizer, conversations, args.answer_only_loss
+            )
             num_input_tokens = input_ids.shape[1]
             if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                 num_skipped_too_long += 1
                 continue

-            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+            tasks.append(dump_hidden_states(idx, conversation_id, input_ids, loss_mask))
             # Increment only for valid conversations to match dump file index
             idx += 1
         await asyncio.gather(*tasks)
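As a quick sanity check (illustrative, not part of the diff): loading one record written by the updated script. The path is a placeholder; the expected keys and shapes follow the `torch.save` call shown above.

```python
# Illustrative check of one dumped record; the path is a placeholder.
import torch

record = torch.load("dumped_hidden_states/0.pt", weights_only=True)
print(sorted(record))
# ['aux_hidden_states', 'conversation_id', 'hidden_states', 'input_ids', 'loss_mask']

seq_len = record["input_ids"].shape[0]
# Last-layer hidden states: (seq_len, hidden_size).
assert record["hidden_states"].shape[0] == seq_len
# Selected aux layers concatenated on the last dim: (seq_len, hidden_size * num_aux_layers).
assert record["aux_hidden_states"].shape[0] == seq_len
# All ones unless --answer-only-loss was set at dump time.
assert record["loss_mask"].shape == (seq_len,)
```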

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,7 @@
 from pathlib import Path

 import torch
+from common import add_aux_layers_args, resolve_aux_layers
 from datasets import load_dataset
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
@@ -122,6 +123,7 @@ def parse_args() -> argparse.Namespace:
         default=None,
         help="""moe_cluster_parallel_size for TRTLLM.""",
     )
+    add_aux_layers_args(parser)

     return parser.parse_args()

@@ -194,7 +196,7 @@ def keep_conversation(entry):
         "output_directory": str(args.output_dir),
         "write_interval": 1,
         "file_prefix": f"dp_{args.dp_rank}",
-        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
+        "eagle3_layers_to_capture": set(resolve_aux_layers(args, num_hidden_layers)),
     }
     sampling_params = SamplingParams(max_tokens=32, temperature=0)

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def make_speculative_data_module(
         raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
     if data_args.sample_size > 0:
         dumped_files = dumped_files[: data_args.sample_size]
-    train_dataset = OfflineSupervisedDataset(dumped_files)
+    train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
     data_collator = EagleOfflineDataCollator(train_len=train_len)

     return {

examples/speculative_decoding/main.py

Lines changed: 4 additions & 13 deletions
@@ -49,7 +49,7 @@

 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
-from modelopt.torch.speculative.config import EagleConfig
+from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
 from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
 from modelopt.torch.utils import print_rank_0

@@ -318,18 +318,9 @@ def train():
             model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
             print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
     elif training_args.mode == "dflash":
-        # Auto-detect mask_token_id from tokenizer if not set
-        if not dflash_cfg.get("dflash_mask_token_id"):
-            if tokenizer.mask_token_id is not None:
-                dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
-                print_rank_0(
-                    f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
-                )
-            else:
-                raise ValueError(
-                    "mask_token_id not found in tokenizer and not set in config. "
-                    "Set dflash.dflash_mask_token_id in the training YAML."
-                )
+        dflash_cfg = DFlashConfig.model_validate(
+            dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
+        ).model_dump()
         mtsp.convert(model, [("dflash", dflash_cfg)])
     else:
         raise Exception(f"{training_args.mode} is not supported!")
