Add MXFP8 attention unit test with linear and rope layers#3033
Conversation
Forward/backward Triton kernels for DSv3 671B MLA RoPE, ported from Megatron-LM fused_mla_yarn_rope_apply.py. Falls back to PyTorch when Triton is unavailable. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Tests: Linear(QKV, MXFP8) -> MLA-RoPE -> DotProductAttention(MXFP8) -> Linear(out, MXFP8) against a BF16 baseline for accuracy, backward correctness, and performance. Dimensions: hidden=16384, heads=128, dqk=192 (nope=128+rope=64), dv=128, s=4096, b=1. Weight quantization is amortized via is_first_microbatch caching (pre-quantized weights reused each iteration). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds a DSv3-shaped MXFP8 attention unit test covering the training path:
Confidence Score: 5/5Safe to merge; all changes are test-only additions with no impact on production code paths. Both files are purely additive test infrastructure. The RoPE kernel math (forward rotation, backward Jacobian, accumulation of shared k_pos_emb gradient across heads) is correct. The restore_value usage in autotuning is handled correctly. The only discrepancy is that test_performance runs unconditionally in CI while the PR description claims it requires RUN_BENCHMARK_TESTS=1, which may add latency to test runs but does not affect correctness. No files require special attention for correctness. test_linear_mxfp8_attention.py is worth a second look if CI run times are a concern, specifically the test_performance method. Important Files Changed
Sequence DiagramsequenceDiagram
participant X as Input (BF16)
participant QKV as te.Linear (QKV, MXFP8)
participant ROPE as apply_mla_rope (Triton/PyTorch)
participant DPA as te.DotProductAttention (MXFP8)
participant OUT as te.Linear (out, MXFP8)
participant Y as Output
X->>QKV: x [s,b,hidden]
Note over QKV: fp8_autocast context 1
QKV->>ROPE: "qkv [s,b,h*(2*dqk+dv)]"
ROPE->>ROPE: split Q/K/V
ROPE->>ROPE: rotate Q nope+rope in-place
ROPE->>ROPE: broadcast k_pos_emb to all heads
ROPE->>DPA: q,k [s,b,h,dqk], v [s,b,h,dv]
Note over DPA: fp8_autocast context 2
DPA->>OUT: attn_out [s,b,h,dv]
Note over OUT: fp8_autocast context 3
OUT->>Y: out [s,b,hidden]
Reviews (2): Last reviewed commit: "Run MXFP8 attention benchmark by default" | Re-trigger Greptile |
| x_off = tl.arange(0, BLOCK_H)[:, None] * stride_x_nheads + qk_head_dim | ||
| mask = x_off < head_num * stride_x_nheads |
There was a problem hiding this comment.
Mask ignores block offset for partial last head-block
The bound check x_off < head_num * stride_x_nheads compares only the intra-block head index i against the total head count, but the pointer Q has already been advanced by pid_head * BLOCK_H * stride_x_nheads. For the last block when head_num % BLOCK_H != 0, the absolute head index pid_head * BLOCK_H + i may exceed head_num while the mask still evaluates to True, causing an out-of-bounds load/store. The same pattern appears in rotary_bwd_q_kernel (line 178) and in the accumulation loop inside rotary_bwd_kv_kernel (line 332). For DSv3's 128 heads all BLOCK_H candidates (1–128) evenly divide 128, so the current test is unaffected, but any future caller with a non-aligned head count would silently corrupt memory.
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| qkv = qkv_linear(x, is_first_microbatch=is_first_microbatch) | ||
|
|
||
| q, k, v = _split_qkv(qkv) | ||
| q, k, v = apply_mla_rope(q, k, v) | ||
|
|
||
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| attn_out = dpa(q, k, v, qkv_format="sbhd") | ||
|
|
||
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| out = out_linear( | ||
| attn_out.view(x.shape[0], x.shape[1], HIDDEN_SIZE), | ||
| is_first_microbatch=is_first_microbatch, | ||
| ) |
There was a problem hiding this comment.
Three separate
fp8_autocast scopes may reset per-layer FP8 statistics
The forward pipeline is split into three independent fp8_autocast contexts (QKV linear, DPA, out linear). For current MXFP8 block scaling the effect is benign because scales are computed per-block and don't depend on cross-layer statistics, but TE's FP8GlobalStateManager maintains per-forward-pass amax history used by some recipes. Exiting and re-entering the context between layers means each layer is treated as a separate forward pass for bookkeeping purposes, so any inter-layer scale propagation is lost. A single surrounding fp8_autocast context covering all three layers would match the standard training usage and would more faithfully exercise the end-to-end path.
| with torch.no_grad(): | ||
| _run_forward_mxfp8(mxfp8_modules, x, fp8_recipe, is_first_microbatch=True) | ||
|
|
||
| mxfp8_fprop_ms, mxfp8_bprop_ms = _benchmark_training_step( | ||
| _run_forward_mxfp8, mxfp8_modules, x, fp8_recipe, False |
There was a problem hiding this comment.
Warmup backward may invalidate the
is_first_microbatch weight cache before timed iterations
The weight cache is seeded with is_first_microbatch=True inside torch.no_grad(), but the 10 warmup iterations inside _benchmark_training_step call backward() on a new computation graph. TE invalidates or refreshes the cached quantized weight buffer after a backward pass (the cache is keyed to the current "microbatch"). Consequently the timed iterations that pass is_first_microbatch=False may miss the cache on every call and silently fall back to per-iteration weight quantization, making the benchmark measure a different workload than described in the docstring.
Add a DSv3-shaped MXFP8 attention unit test covering the training path:
Linear(QKV) -> MLA RoPE -> DotProductAttention -> Linear(out).DotProductAttentionwrapper.RUN_BF16_REFERENCE=1.Validation
Local checks:
python -m py_compile tests/pytorch/attention/test_linear_mxfp8_attention.py tests/pytorch/attention/mla_rope_utils.pygit diff --checkGB300 dlcluster validation, container:
gitlab-master.nvidia.com/dl/mlperf/optimized:deepseekv3_671b.pytorch.520256791062811(10, 3)(9, 21, 1)(True, '')python -m pytest tests/pytorch/attention/test_linear_mxfp8_attention.py -v -s3 passedPerf output: