Add the getter and setter of skip_fp8_weight_update_tensor#3015
Add the getter and setter of skip_fp8_weight_update_tensor#3015xrennvidia wants to merge 4 commits into
Conversation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR restores the
Confidence Score: 5/5Safe to merge. The change is a small, targeted restoration of two accessor methods with no logic changes to the underlying tensor lifecycle. Both call sites in graph.py are updated to use the new setter, and the setter itself is an exact extraction of the logic that was previously inlined — no behavioral difference. The getter simply exposes what was already a public dataclass field. The only loose end is three existing callers in module files that still access the field directly, but those predate this PR and are not regressed by it. No files require special attention. module/linear.py, module/layernorm_mlp.py, and module/layernorm_linear.py still access the internal field directly, but that is pre-existing and out of scope for this PR. Important Files Changed
Sequence DiagramsequenceDiagram
participant MCore as MCore CudaGraph
participant FP8GSM as FP8GlobalStateManager
participant QState as FP8GlobalState (quantization_state)
MCore->>FP8GSM: set_skip_fp8_weight_update_tensor(True/False)
alt tensor is None
FP8GSM->>QState: create skip_fp8_weight_update_tensor (CUDA float32)
end
FP8GSM->>QState: fill_(skip)
MCore->>FP8GSM: get_skip_fp8_weight_update_tensor()
FP8GSM->>QState: read skip_fp8_weight_update_tensor
QState-->>MCore: Optional[torch.Tensor]
Note over FP8GSM,QState: graph.py also calls set_skip_fp8_weight_update_tensor during graphed callable setup and forward pass
Reviews (3): Last reviewed commit: "Merge branch 'main' into xren/fix_skip_f..." | Re-trigger Greptile |
return type fix Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
ptrendx
left a comment
There was a problem hiding this comment.
I believe there could be a reason why Pawel removed those functions from this object and we may need to change MCore instead in order to have this be compatible with torch.compile. Setting 'request changes' status for now until @pggPL reviews it.
Description
The getter and setter of
skip_fp8_weight_update_tensorwere deleted in @pggPL 's PR2759, but MCore local Cuda Graph implementation still needs it (like here), so create this PR to recover it back.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: