Skip to content

[JAX] Support for cuDNN-backed flex attention#2985

Open
vcherepanov-nv wants to merge 12 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax
Open

[JAX] Support for cuDNN-backed flex attention#2985
vcherepanov-nv wants to merge 12 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Copy link
Copy Markdown
Collaborator

@vcherepanov-nv vcherepanov-nv commented May 13, 2026

Description

Adds experimental JAX fused-attention score_mod support through cuDNN frontend SDPA graphs.

This introduces a score_mod(graph, score, tensors) callback path for fused_attn, plus optional score_mod_bprop(graph, dscore, tensors) support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.

The Flax API now plumbs score_mod through DotProductAttention, MultiHeadAttention, and TransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.

Users are responsible for supplying a mathematically correct score_mod_bprop for the corresponding score_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.

Current score_mod limitations:

  • Requires fused attention to be enabled.
  • Supports separate rank-4 BSHD_BSHD_BSHD Q/K/V tensors only.
  • Supports FP16/BF16 Q/K/V tensors.
  • Mutually exclusive with attention bias, masks, sequence descriptors, dropout, sliding-window attention, packed/ragged metadata, context parallelism, and non-vanilla softmax/softmax offset.
  • Requires matching cuDNN frontend Python package and C++ headers.

Fixes # (issue)
#2492

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new score_mod code path for the JAX FusedAttention backend
  • cuDNN frontend graph serialization and JAX FFI execution for score_mod forward/backward
  • Flax plumbing for DotProductAttention, MultiHeadAttention, and TransformerLayer
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds an experimental cuDNN frontend score_mod path for JAX fused attention, allowing users to inject arbitrary score modifications (causal masks, softcapping, ALiBi-style biases) into SDPA via a Python callback that builds a cuDNN graph at trace time. The Python side serializes and caches the graph; the C++ XLA FFI handler deserializes and caches it per device, then executes through a new variant_pack dispatch.

  • New flex_attention.py: builds, serializes, and caches cuDNN frontend forward/backward SDPA graphs; handles auxiliary tensor operands and pass-by-value scalars; wires into JAX via ffi_call with stable attribute-based graph hashes.
  • C++ FFI handlers (FusedAttnScoreModForward/BackwardFFI): deserialize cuDNN graphs per-device using thread-local handles, assemble the variant pack from positional and variadic buffers, and execute; double-checked locking protects the C++ graph cache.
  • Flax plumbing: propagates score_mod/score_mod_bprop/score_mod_tensors through DotProductAttention, MultiHeadAttention, and TransformerLayer; packed/KV-packed layouts are split to BSHD_BSHD_BSHD when score_mod is active.

Confidence Score: 5/5

This PR is safe to merge — the core execution path is correct with no functional bugs found.

Variant_pack UID-to-buffer mappings are correct, BHSD stride reinterpretation of BSHD tensors is accurate, packed-QKV unpacking covers all three layout branches, and the custom_vjp residuals carry exactly what the backward rule consumes. Version matching between Python and C++ cuDNN frontend is enforced at both graph-build and execution time. All findings are non-blocking style and observability suggestions.

build_tools/jax.py — the silent omission of the cudnn-frontend include dir conflicts with the unconditional #include in attention.cpp and would produce confusing build errors.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/flex_attention.py New 967-line module: cuDNN graph build/serialize/cache logic, version matching, and FFI call wrappers for forward and backward.
transformer_engine/jax/csrc/extensions/attention.cpp Adds 243 lines: C++ ScoreModGraph cache with thread-local cuDNN handles and double-checked locking, FFI handlers, scalar pass-by-value packing.
transformer_engine/jax/attention.py Wires score_mod through a new custom_vjp and adds score_mod/bprop/tensors params to fused_attn(); makes sequence_descriptor Optional.
transformer_engine/jax/flax/transformer.py Plumbs score_mod through DotProductAttention/MultiHeadAttention/TransformerLayer; unpacks packed QKV to BSHD_BSHD_BSHD when score_mod is requested.
tests/jax/test_fused_attn_score_mod.py Comprehensive new tests: causal, softcap, post-scale-bias variants with forward/backward numerical checks, cache key stability, config splitting, and Flax plumbing.
tests/jax/test_distributed_fused_attn.py Adds distributed score_mod test with data/tensor parallel sharding and auxiliary parameter gradient checks.
build_tools/jax.py Optional cudnn-frontend header discovery; silent omission when not found conflicts with unconditional #include in attention.cpp.

Reviews (7): Last reviewed commit: "Address JAX score_mod review feedback" | Re-trigger Greptile

Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/jax/test_fused_attn.py Outdated
vcherepanov-nv and others added 2 commits May 15, 2026 03:35
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/attention.py Outdated
vcherepanov-nv and others added 2 commits May 18, 2026 23:54
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
vcherepanov-nv and others added 2 commits May 19, 2026 00:32
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/attention.py
Comment thread transformer_engine/jax/attention.py
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/attention.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.16.0 community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants