Commit 5887410
### Type of change
- [x] Bug fix (non-breaking change which fixes an issue)
### Description
Fixes #1088 — `RuntimeError: one of the variables needed for gradient
computation has been modified by an inplace operation:
IndexPutBackward0` when training with `eagle_mix_hidden_states=True`.
**Root cause:** In `HFEagleModel._eagle_training_forward`, the indexed
assignment at lines 991–994 modifies `eagle_input_hiddens` in-place while
it is still part of the autograd computation graph.
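The failure pattern can be reproduced standalone (a minimal sketch with toy tensors, not the actual model code; which backward node the error message names depends on the op whose saved tensor was touched):

```python
import torch

def inplace_write_then_backward():
    """Return the autograd error message, or None if backward succeeds."""
    x = torch.randn(4, requires_grad=True)
    y = x.exp()   # exp's backward re-uses its output, so autograd saves y
    z = y.sum()
    y[0] = 0.0    # indexed in-place write bumps y's version counter
    try:
        z.backward()
        return None
    except RuntimeError as err:
        return str(err)

msg = inplace_write_then_backward()
print(msg)  # "...has been modified by an inplace operation..."
```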
**Fix:** Clone the tensor before the in-place assignment. This is the
same pattern already used in the Megatron backend at
`megatron_eagle.py:1201-1202`:
```python
# Clone to avoid inplace modification of view created in no_grad mode
eagle_module_input_hidden_states = eagle_module_input_hidden_states.clone()
```
The HF backend was missing this clone.
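Why the clone is sufficient, sketched on toy tensors (a hypothetical example, not the model code): the indexed write lands on fresh storage, so the version counter of the tensor that autograd saved is never bumped:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x.exp()    # autograd saves y for exp's backward pass
z = y.sum()
y = y.clone()  # fresh storage; later writes no longer touch the saved tensor
y[0] = 0.0     # safe: only the clone's version counter changes
z.backward()   # succeeds; without the clone this raises RuntimeError
```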
### Usage
```python
config["eagle_mix_hidden_states"] = True
config["eagle_ttt_steps"] = 2
mtsp.convert(model, mode=[("eagle", config)])
model.train()
outputs = model(input_ids=input_ids, labels=labels)
outputs.loss.backward() # no longer crashes
```
### Testing
Added `test_eagle_mix_hidden_states_backward` parametrized over
`eagle_ttt_steps` [1, 2] that:
- Converts a tiny LLaMA to EAGLE with `eagle_mix_hidden_states=True`
- Runs forward + backward pass
- Asserts loss is not None and gradients flow to `eagle_module`
```shell
pytest tests/unit/torch/speculative/plugins/test_hf_speculative.py::test_eagle_mix_hidden_states_backward -v
```
### Checklist
- [x] I have read the [contributor guidelines](CONTRIBUTING.md) and
signed my commits
- [x] I have followed the [security best practices](SECURITY.md)
- [x] This change is backward compatible
- [x] I have followed third-party code and dependency guidelines
- [x] I have added tests that prove my fix is effective
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Fixed gradient computation issue in speculative decoding during model
training to ensure proper autograd behavior.
* **Tests**
* Added regression test to validate gradient computation in speculative
decoding scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
1 parent 0a1ca5d · commit 5887410
2 files changed: 39 additions & 0 deletions
File tree
- modelopt/torch/speculative/plugins
- tests/unit/torch/speculative/plugins