Skip to content

[JAX] Support for cuDNN-backed flex attention#2985

Open
vcherepanov-nv wants to merge 14 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax
Open

[JAX] Support for cuDNN-backed flex attention#2985
vcherepanov-nv wants to merge 14 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Copy link
Copy Markdown
Collaborator

@vcherepanov-nv vcherepanov-nv commented May 13, 2026

Description

Adds experimental JAX fused-attention score_mod support through cuDNN frontend SDPA graphs.

This introduces a score_mod(graph, score, tensors) callback path for fused_attn, plus optional score_mod_bprop(graph, dscore, tensors) support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.

The Flax API now plumbs score_mod through DotProductAttention, MultiHeadAttention, and TransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.

Users are responsible for supplying a mathematically correct score_mod_bprop for the corresponding score_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.

Current score_mod limitations:

  • Requires fused attention to be enabled.
  • Supports separate rank-4 BSHD_BSHD_BSHD Q/K/V tensors only.
  • Supports FP16/BF16 Q/K/V tensors.
  • Mutually exclusive with attention bias, masks, sequence descriptors, dropout, sliding-window attention, packed/ragged metadata, context parallelism, and non-vanilla softmax/softmax offset.
  • Requires matching cuDNN frontend Python package and C++ headers.

Fixes # (issue)
#2492

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new score_mod code path for the JAX FusedAttention backend
  • cuDNN frontend graph serialization and JAX FFI execution for score_mod forward/backward
  • Flax plumbing for DotProductAttention, MultiHeadAttention, and TransformerLayer
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds experimental cuDNN-frontend-backed flex attention (score_mod) to the JAX backend, including graph serialization/deserialization, a Python-level graph cache, new C++ FFI handlers for forward and backward, and Flax plumbing through DotProductAttention, MultiHeadAttention, and TransformerLayer.

  • Core path: fused_attn short-circuits to a new _fused_attn_score_mod custom_vjp primitive when score_mod is provided; the Python side builds and serializes cuDNN frontend graphs at trace time (cached by shape/dtype/config key), then passes the serialized bytes + UID maps as static FFI attributes to C++ handlers that deserialize and execute them.
  • C++ side: Two new FFI handlers deserialize graphs on demand into a process-lifetime unordered_map guarded by a mutex, with a thread-local cuDNN handle cache; the current double-checked locking leaves a window for redundant concurrent deserialization.
  • Flax plumbing: Packed and KV-packed layouts are transparently converted to separate BSHD tensors before the score_mod path; score_mod_tensors / score_mod_bprop_tensors are forwarded as call-time arguments to keep tensor operands in the JAX computation graph.

Confidence Score: 5/5

Safe to merge as an experimental feature; all flagged items are non-blocking quality improvements with no correctness impact.

The core forward and backward graph building, caching, FFI dispatch, and Flax plumbing are all structurally correct. Cache key stability, UID ordering, and pytree gradient structure are handled properly. The findings are race conditions that produce at worst redundant work (not wrong results) and a shutdown-order concern for the thread-local cuDNN handle that matches patterns already present elsewhere in the codebase.

transformer_engine/jax/csrc/extensions/attention.cpp (double-checked locking in GetScoreModGraph, thread-local handle destructor ordering) and transformer_engine/jax/cpp_extensions/flex_attention.py (Python-level cache lock).

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/flex_attention.py New 967-line file implementing cuDNN frontend score_mod graph building, caching, and FFI dispatch; implements a stable cache-key scheme and separates tensor vs. scalar operands cleanly.
transformer_engine/jax/csrc/extensions/attention.cpp Adds 251 lines for C++ cuDNN graph deserialization, thread-local handle cache, and two new FFI handlers (forward/backward); double-checked locking leaves redundant deserializations possible under thread contention.
transformer_engine/jax/attention.py Adds custom_vjp wrapper for score_mod path with correct residual propagation, early-return before the deprecated sequence_descriptor path, and proper validation delegation.
transformer_engine/jax/flax/transformer.py Plumbs score_mod/score_mod_bprop through DotProductAttention, MultiHeadAttention, and TransformerLayer; handles packed/kvpacked layout unpacking correctly before the score_mod path.
tests/jax/test_fused_attn_score_mod.py New 671-line test suite covering causal masking, post-scale bias, softcap (forward/backward), and Flax layer integration, with reference implementations for correctness comparison.

Sequence Diagram

sequenceDiagram
    participant User
    participant fused_attn
    participant ScoreMod as "_fused_attn_score_mod"
    participant FlexPy as "flex_attention.py"
    participant FFI as "FFI/XLA"
    participant Cpp as "C++ Handler"
    participant Cache as "cuDNN Graph Cache"

    User->>fused_attn: "call with score_mod callback"
    fused_attn->>fused_attn: "validate_fused_attn_score_mod()"
    fused_attn->>FlexPy: "make_fused_attn_score_mod_config()"
    fused_attn->>ScoreMod: "custom_vjp forward"

    Note over ScoreMod,FlexPy: JAX Tracing Phase
    ScoreMod->>FlexPy: "fused_attn_score_mod_fwd()"
    FlexPy->>FlexPy: "check _score_mod_graph_cache"
    alt cache miss
        FlexPy->>FlexPy: "_build_score_mod_fwd_graph()"
        FlexPy->>FlexPy: "store in _score_mod_graph_cache"
    end
    FlexPy->>FFI: "ffi.ffi_call(serialized_graph, uids)"

    Note over FFI,Cache: XLA Execution Phase
    FFI->>Cpp: "FusedAttnScoreModForwardFFI(stream, q, k, v)"
    Cpp->>Cache: "GetScoreModGraph(stream, attrs)"
    alt C++ cache miss
        Cache->>Cache: "graph->deserialize(handle, data)"
        Cache->>Cache: "store shared_ptr in map"
    end
    Cpp->>Cpp: "graph->execute(handle, variant_pack)"
    Cpp-->>FFI: "output, stats, workspace"

    Note over ScoreMod,FlexPy: Backward pass
    ScoreMod->>FlexPy: "fused_attn_score_mod_bwd(qkv, o, dO, stats)"
    FlexPy->>FFI: "ffi.ffi_call(serialized_bwd_graph)"
    FFI->>Cpp: "FusedAttnScoreModBackwardFFI(...)"
    Cpp-->>FFI: "dq, dk, dv"
Loading

Reviews (9): Last reviewed commit: "Skip softcap score-mod test before SM90" | Re-trigger Greptile

Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/jax/test_fused_attn.py Outdated
vcherepanov-nv and others added 2 commits May 15, 2026 03:35
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/attention.py Outdated
vcherepanov-nv and others added 2 commits May 18, 2026 23:54
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
vcherepanov-nv and others added 2 commits May 19, 2026 00:32
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +1531 to 1535
score_mod: Optional[Callable] = None,
score_mod_bprop: Optional[Callable] = None,
score_mod_tensors: Optional[Mapping[str, Any]] = None,
score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None,
):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is the highest API that score_mod has been plumbed to.
There are higher APIs that would need to be plumbed to as well - please do take a look
At the very least FusedDPA and DPA

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments for this

Comment thread transformer_engine/jax/attention.py
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Comment on lines +191 to +208
def _reference_attention(
query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None
):
scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
if causal:
q_pos = jnp.arange(query.shape[1])[:, None]
kv_pos = jnp.arange(key.shape[1])[None, :]
scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
if post_scale_bias:
q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
scores = scores + q_pos - kv_pos
if softcap is not None:
scores = softcap * jnp.tanh(scores / softcap)
probs = jax.nn.softmax(scores, axis=-1)
return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype)


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WHy create your own reference and not use the jax native reference in the test file already ?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, I believe the idea behind this comment was to check the possibility of code reuse. However, it seems like the solutions chosen is to move the contents of the flex attention tests to a different file altogether.

I do not think it is a good practice for us to have different ways of creating the reference for different types of fused attn. It would have been best to use the reference implementation in test_fused_attn.py tweaked for the test case we have fro flex attention.

However, I'll try to not hold the PR for this but I definitely think this should be looked into a follow up PR
cc: @cyanguwa

Comment thread tests/jax/test_fused_attn.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/attention.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@vcherepanov-nv vcherepanov-nv removed the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
if score_mod_requested:
if not enable_fused_attn:
raise ValueError("score_mod requires fused attention, but NVTE_FUSED_ATTN=0.")
has_fused_attn_kernel = True
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we force this to True ?
If user wants to perform flex attn and has enabled fused attn, then shouldn't we check via is_fused_attn_kernel_available() right ?
This seems to be our contract API to determine whether fused is available+requested so I think we should rely on that API rather than force it to True here

Comment on lines +748 to 754
score_mod_requested = (
self.score_mod is not None
or self.score_mod_bprop is not None
or score_mod_tensors is not None
or score_mod_bprop_tensors is not None
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why repeat this check in DPA and _FusedDPA ?
The expectation is that _FusedDPA is an internal class (with the leading underscore indicating the same) and so if the checks exist in DPA I donot think we need them again in _FusedDPA. Thoughts ?
The user should be exposed to DPA only IIRC

Comment on lines +1531 to 1535
score_mod: Optional[Callable] = None,
score_mod_bprop: Optional[Callable] = None,
score_mod_tensors: Optional[Mapping[str, Any]] = None,
score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None,
):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments for this

Comment thread tests/jax/test_fused_attn.py Outdated
Comment on lines +191 to +208
def _reference_attention(
query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None
):
scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
if causal:
q_pos = jnp.arange(query.shape[1])[:, None]
kv_pos = jnp.arange(key.shape[1])[None, :]
scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
if post_scale_bias:
q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
scores = scores + q_pos - kv_pos
if softcap is not None:
scores = softcap * jnp.tanh(scores / softcap)
probs = jax.nn.softmax(scores, axis=-1)
return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype)


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, I believe the idea behind this comment was to check the possibility of code reuse. However, it seems like the solutions chosen is to move the contents of the flex attention tests to a different file altogether.

I do not think it is a good practice for us to have different ways of creating the reference for different types of fused attn. It would have been best to use the reference implementation in test_fused_attn.py tweaked for the test case we have fro flex attention.

However, I'll try to not hold the PR for this but I definitely think this should be looked into a follow up PR
cc: @cyanguwa

Comment on lines +348 to +351
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this ? 0.125*

Comment on lines +322 to +346
runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_heads,
num_heads,
head_dim,
head_dim,
AttnBiasType.NO_BIAS,
AttnMaskType.NO_MASK,
AttnSoftmaxType.VANILLA_SOFTMAX,
0.0,
dtype,
True,
QKVLayout.BSHD_BSHD_BSHD,
None,
None,
SeqDescFormat.Mask,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
)
runner._setup_inputs()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems like you are using the runner to only setup the inputs but then are following that up with "duplicate" code that test_forward and test_backward in test_fused_attn.py
The suggestions was to try and use the runner to call forward(), which does the setup using the runner and also runs the test.

The idea is if you can integrate the non distributed tests with the Runner infrastrucutre then the distributed tests can directly use it here for free. The approach here seems to be somewhere in between your older approach and the suggested approach. You can refer to other tests in this file for reference, incase I've not done a good job explaining

i'm curiosu to know if the reason for that was because you were unable to fully integrate the score mod tests into the Fused Attn runner ? Because it seems like it is the same reason for creating a separate test_fused_attn_score_mod.py as compared to integrating the flex attn tests in the Fused attn runner in test_fused_attn.py

If test_fused_attn_score_mod must be created then a Runner should be ceated in the too and the distributed flex attn tests can then use that runner in here (similar to how other tests do)

Comment on lines +347 to +403
qkv_sharding = NamedSharding(runner.mesh, PartitionSpec(dp_axis, None, tp_axis, None))
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)

scaling_factor = runner.scaling_factor
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

def score_mod_loss(q, k, v, dout):
out = customcall_fused_dpa(
q,
k,
v,
None,
None,
None,
None,
attn_bias_type=AttnBiasType.NO_BIAS,
attn_mask_type=AttnMaskType.NO_MASK,
qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
scaling_factor=scaling_factor,
dropout_probability=0.0,
is_training=True,
score_mod=softcap_score_mod.forward,
score_mod_bprop=softcap_score_mod.backward,
score_mod_tensors={"softcap": softcap},
score_mod_bprop_tensors={"softcap": softcap},
)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

def ref_loss(q, k, v, dout):
out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

jitted_score_mod = jax.jit(
jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
in_shardings=(
qkv_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
),
out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
jax.device_put(query, qkv_sharding),
jax.device_put(key_tensor, qkv_sharding),
jax.device_put(value, qkv_sharding),
jax.device_put(doutput, qkv_sharding),
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this can come for free if the flex attn is integrated with the FusedAttn Runner

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants