Skip to content

Add MXFP8 attention unit test with linear and rope layers#3033

Open
layalir wants to merge 8 commits into
NVIDIA:mainfrom
layalir:add_linear_mxfp8_unit_test
Open

Add MXFP8 attention unit test with linear and rope layers#3033
layalir wants to merge 8 commits into
NVIDIA:mainfrom
layalir:add_linear_mxfp8_unit_test

Conversation

@layalir
Copy link
Copy Markdown

@layalir layalir commented May 22, 2026

Add a DSv3-shaped MXFP8 attention unit test covering the training path:

  • Adds MLA RoPE utilities for the DSv3 671B attention shape.
  • Adds an end-to-end MXFP8 path: Linear(QKV) -> MLA RoPE -> DotProductAttention -> Linear(out).
  • Exercises MXFP8 forward and backward through TE’s real DotProductAttention wrapper.
  • Keeps BF16 reference comparison optional with RUN_BF16_REFERENCE=1.

Validation

Local checks:

  • python -m py_compile tests/pytorch/attention/test_linear_mxfp8_attention.py tests/pytorch/attention/mla_rope_utils.py
  • git diff --check

GB300 dlcluster validation, container:
gitlab-master.nvidia.com/dl/mlperf/optimized:deepseekv3_671b.pytorch.52025679

  • Job: 1062811
  • GPU: NVIDIA GB300
  • CUDA capability: (10, 3)
  • cuDNN: (9, 21, 1)
  • MXFP8 available: (True, '')
  • Command: python -m pytest tests/pytorch/attention/test_linear_mxfp8_attention.py -v -s
  • Result: 3 passed

Perf output:

[PERF] b=1 s=4096:
  BF16 fprop:  7.219 ms  (567397 tok/s)
  BF16 bprop:  15.179 ms  (269844 tok/s)
  MXFP8 fprop: 4.718 ms  (868181 tok/s)
  MXFP8 bprop: 9.215 ms  (444492 tok/s)
  Fprop speedup: 1.53x
  Bprop speedup: 1.65x

layalir and others added 6 commits May 21, 2026 07:56
Forward/backward Triton kernels for DSv3 671B MLA RoPE, ported from
Megatron-LM fused_mla_yarn_rope_apply.py. Falls back to PyTorch when
Triton is unavailable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Tests: Linear(QKV, MXFP8) -> MLA-RoPE -> DotProductAttention(MXFP8) -> Linear(out, MXFP8)
against a BF16 baseline for accuracy, backward correctness, and performance.

Dimensions: hidden=16384, heads=128, dqk=192 (nope=128+rope=64), dv=128, s=4096, b=1.

Weight quantization is amortized via is_first_microbatch caching
(pre-quantized weights reused each iteration).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR adds a DSv3-shaped MXFP8 attention unit test covering the training path: Linear(QKV) → MLA RoPE → DotProductAttention → Linear(out). It also introduces Triton-based (with PyTorch fallback) MLA RoPE forward/backward kernels for the DSv3 671B attention shape.

  • mla_rope_utils.py: Triton kernels for in-place Q rotation and KV decomposition/rotation, with a pure-PyTorch fallback. restore_value is correctly applied to Q/DO buffers that are modified in-place during autotuning, while write-only output buffers (dKV, dEMB) correctly omit it.
  • test_linear_mxfp8_attention.py: Three pytest tests — test_accuracy (forward sanity + optional BF16 comparison), test_backward (gradient flow), and test_performance (timing benchmark). Only BF16 comparison is opt-in via RUN_BF16_REFERENCE=1; the performance benchmark itself runs unconditionally despite the PR description claiming it is gated by RUN_BENCHMARK_TESTS=1.

Confidence Score: 5/5

Safe to merge; all changes are test-only additions with no impact on production code paths.

Both files are purely additive test infrastructure. The RoPE kernel math (forward rotation, backward Jacobian, accumulation of shared k_pos_emb gradient across heads) is correct. The restore_value usage in autotuning is handled correctly. The only discrepancy is that test_performance runs unconditionally in CI while the PR description claims it requires RUN_BENCHMARK_TESTS=1, which may add latency to test runs but does not affect correctness.

No files require special attention for correctness. test_linear_mxfp8_attention.py is worth a second look if CI run times are a concern, specifically the test_performance method.

Important Files Changed

Filename Overview
tests/pytorch/attention/mla_rope_utils.py New file: Triton-based (with PyTorch fallback) MLA RoPE forward/backward kernels for DSv3 671B shape. Math is correct; restore_value is correctly applied only where in-place modifications occur during autotuning.
tests/pytorch/attention/test_linear_mxfp8_attention.py New file: MXFP8 end-to-end attention unit test. test_performance has no RUN_BENCHMARK_TESTS guard despite the PR description claiming opt-in gating, so it runs unconditionally in CI.

Sequence Diagram

sequenceDiagram
    participant X as Input (BF16)
    participant QKV as te.Linear (QKV, MXFP8)
    participant ROPE as apply_mla_rope (Triton/PyTorch)
    participant DPA as te.DotProductAttention (MXFP8)
    participant OUT as te.Linear (out, MXFP8)
    participant Y as Output

    X->>QKV: x [s,b,hidden]
    Note over QKV: fp8_autocast context 1
    QKV->>ROPE: "qkv [s,b,h*(2*dqk+dv)]"
    ROPE->>ROPE: split Q/K/V
    ROPE->>ROPE: rotate Q nope+rope in-place
    ROPE->>ROPE: broadcast k_pos_emb to all heads
    ROPE->>DPA: q,k [s,b,h,dqk], v [s,b,h,dv]
    Note over DPA: fp8_autocast context 2
    DPA->>OUT: attn_out [s,b,h,dv]
    Note over OUT: fp8_autocast context 3
    OUT->>Y: out [s,b,hidden]
Loading

Reviews (2): Last reviewed commit: "Run MXFP8 attention benchmark by default" | Re-trigger Greptile

Comment on lines +118 to +119
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_x_nheads + qk_head_dim
mask = x_off < head_num * stride_x_nheads
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Mask ignores block offset for partial last head-block

The bound check x_off < head_num * stride_x_nheads compares only the intra-block head index i against the total head count, but the pointer Q has already been advanced by pid_head * BLOCK_H * stride_x_nheads. For the last block when head_num % BLOCK_H != 0, the absolute head index pid_head * BLOCK_H + i may exceed head_num while the mask still evaluates to True, causing an out-of-bounds load/store. The same pattern appears in rotary_bwd_q_kernel (line 178) and in the accumulation loop inside rotary_bwd_kv_kernel (line 332). For DSv3's 128 heads all BLOCK_H candidates (1–128) evenly divide 128, so the current test is unaffected, but any future caller with a non-aligned head count would silently corrupt memory.

Comment on lines +189 to +202
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
qkv = qkv_linear(x, is_first_microbatch=is_first_microbatch)

q, k, v = _split_qkv(qkv)
q, k, v = apply_mla_rope(q, k, v)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
attn_out = dpa(q, k, v, qkv_format="sbhd")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
out = out_linear(
attn_out.view(x.shape[0], x.shape[1], HIDDEN_SIZE),
is_first_microbatch=is_first_microbatch,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Three separate fp8_autocast scopes may reset per-layer FP8 statistics

The forward pipeline is split into three independent fp8_autocast contexts (QKV linear, DPA, out linear). For current MXFP8 block scaling the effect is benign because scales are computed per-block and don't depend on cross-layer statistics, but TE's FP8GlobalStateManager maintains per-forward-pass amax history used by some recipes. Exiting and re-entering the context between layers means each layer is treated as a separate forward pass for bookkeeping purposes, so any inter-layer scale propagation is lost. A single surrounding fp8_autocast context covering all three layers would match the standard training usage and would more faithfully exercise the end-to-end path.

Comment on lines +374 to +378
with torch.no_grad():
_run_forward_mxfp8(mxfp8_modules, x, fp8_recipe, is_first_microbatch=True)

mxfp8_fprop_ms, mxfp8_bprop_ms = _benchmark_training_step(
_run_forward_mxfp8, mxfp8_modules, x, fp8_recipe, False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Warmup backward may invalidate the is_first_microbatch weight cache before timed iterations

The weight cache is seeded with is_first_microbatch=True inside torch.no_grad(), but the 10 warmup iterations inside _benchmark_training_step call backward() on a new computation graph. TE invalidates or refreshes the cached quantized weight buffer after a backward pass (the cache is keyed to the current "microbatch"). Consequently the timed iterations that pass is_first_microbatch=False may miss the cache on every call and silently fall back to per-iteration weight quantization, making the benchmark measure a different workload than described in the docstring.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant