Commit 5887410

fix: EAGLE mix_hidden_states in-place op crash (#1088) (#1104)
### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)

### Description

Fixes #1088 — `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: IndexPutBackward0` when training with `eagle_mix_hidden_states=True`.

**Root cause:** In `HFEagleModel._eagle_training_forward`, the indexed assignment at lines 991–994 modifies `eagle_input_hiddens` in-place while the tensor is still part of the autograd computation graph.

**Fix:** Clone the tensor before the in-place assignment. This is the same pattern already used in the Megatron backend at `megatron_eagle.py:1201-1202`:

```python
# Clone to avoid inplace modification of view created in no_grad mode
eagle_module_input_hidden_states = eagle_module_input_hidden_states.clone()
```

The HF backend was missing this clone.

### Usage

```python
config["eagle_mix_hidden_states"] = True
config["eagle_ttt_steps"] = 2
mtsp.convert(model, mode=[("eagle", config)])

model.train()
outputs = model(input_ids=input_ids, labels=labels)
outputs.loss.backward()  # no longer crashes
```

### Testing

Added `test_eagle_mix_hidden_states_backward`, parametrized over `eagle_ttt_steps` values `[1, 2]`, which:

- Converts a tiny LLaMA to EAGLE with `eagle_mix_hidden_states=True`
- Runs a forward + backward pass
- Asserts the loss is not None and gradients flow to `eagle_module`

```
pytest tests/unit/torch/speculative/plugins/test_hf_speculative.py::test_eagle_mix_hidden_states_backward -v
```

### Checklist

- [x] I have read the [contributor guidelines](CONTRIBUTING.md) and signed my commits
- [x] I have followed the [security best practices](SECURITY.md)
- [x] This change is backward compatible
- [x] I have followed third-party code and dependency guidelines
- [x] I have added tests that prove my fix is effective

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed a gradient computation issue in speculative decoding during model training to ensure proper autograd behavior.
* **Tests**
  * Added a regression test to validate gradient computation in speculative decoding scenarios.

Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
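The failure mode described above can be reproduced outside the model entirely. The sketch below is a toy stand-in for the EAGLE hidden-state mixing (the `mix_hidden` helper and its shapes are illustrative assumptions, not the actual model code): an op saves a tensor for backward, a later indexed in-place write invalidates it, and cloning first avoids the crash.

```python
import torch


def mix_hidden(h: torch.Tensor, clone_first: bool) -> torch.Tensor:
    # Hypothetical stand-in for the hidden-state mixing step.
    loss = (h * h).sum()  # autograd saves `h` for the multiply's backward
    if clone_first:
        h = h.clone()     # later writes target a copy, not the saved tensor
    h[0] = 0.0            # in-place indexed assignment (index_put_)
    return loss


x = torch.randn(2, 4, requires_grad=True)

try:
    mix_hidden(x * 2, clone_first=False).backward()
    crashed = False
except RuntimeError:  # "...modified by an inplace operation"
    crashed = True

mix_hidden(x * 2, clone_first=True).backward()  # succeeds with the clone
```

The clone is cheap relative to the forward pass and keeps the saved tensor untouched, which is why the same one-line fix works in both backends.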
1 parent 0a1ca5d commit 5887410

2 files changed

Lines changed: 39 additions & 0 deletions


modelopt/torch/speculative/plugins/transformers.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1080,6 +1080,8 @@ def forward(
                 batch_size, seq_len_s, device=eagle_input_hiddens.device
             ).argsort(dim=1)[:, :num_to_replace]

+            # Clone to avoid inplace modification that breaks autograd
+            eagle_input_hiddens = eagle_input_hiddens.clone()
             batch_indices = torch.arange(batch_size)[:, None]
             eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
                 batch_indices, rand_indices
```

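Why the clone works can be seen through autograd's saved-tensor version counter. This sketch uses the private `_version` attribute purely for illustration (an assumption about PyTorch internals, not part of this change):

```python
import torch

x = torch.randn(3, requires_grad=True)
h = x * 2
v0 = h._version  # 0: no in-place ops on `h` yet
h[0] = 0.0       # index_put_ bumps the version counter
v1 = h._version  # 1: backward through ops that saved `h` would now fail
c = h.clone()    # fresh tensor with its own counter
c[0] = 1.0       # in-place writes to the clone leave `h`'s counter alone
```

Autograd compares a saved tensor's version at backward time against the version recorded at save time; cloning before the write keeps the two in agreement.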
tests/unit/torch/speculative/plugins/test_hf_speculative.py

Lines changed: 37 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 from copy import deepcopy

 import pytest
+import torch
 from _test_utils.torch.transformers_models import (
     get_tiny_llama,
     tf_modelopt_state_and_output_tester,
@@ -48,3 +49,39 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config):
     model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model")
     assert isinstance(model_test, mtsp.plugins.HFEagleModel)
     tf_modelopt_state_and_output_tester(model_ref, model_test)
+
+
+@pytest.mark.parametrize("eagle_config", [EAGLE3_DEFAULT_CFG])
+@pytest.mark.parametrize("eagle_ttt_steps", [1, 2])
+def test_eagle_mix_hidden_states_backward(eagle_config, eagle_ttt_steps):
+    """Regression test for GitHub issue #1088.
+
+    Verifies that the EAGLE training forward+backward pass does not crash with
+    ``eagle_mix_hidden_states=True`` due to an in-place tensor modification
+    breaking autograd.
+    """
+    model = get_tiny_llama(num_hidden_layers=8)
+
+    config = deepcopy(eagle_config["config"])
+    config["eagle_architecture_config"].update(
+        {
+            "draft_vocab_size": model.config.vocab_size,
+            "hidden_size": model.config.hidden_size,
+        }
+    )
+    config["eagle_mix_hidden_states"] = True
+    config["eagle_ttt_steps"] = eagle_ttt_steps
+    config["eagle_use_torch_compile"] = False
+
+    mtsp.convert(model, mode=[("eagle", config)])
+    model.train()
+
+    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
+    labels = input_ids.clone()
+
+    outputs = model(input_ids=input_ids, labels=labels)
+    assert outputs.loss is not None
+    outputs.loss.backward()
+
+    eagle_grads = [p.grad for p in model.eagle_module.parameters() if p.grad is not None]
+    assert len(eagle_grads) > 0, "Expected gradients to flow to eagle_module"
```
