[JAX] Support for cuDNN-backed flex attention #2985
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary
This PR adds experimental cuDNN-frontend-backed flex attention (
Confidence Score: 5/5
Safe to merge as an experimental feature; all flagged items are non-blocking quality improvements with no correctness impact.
The core forward and backward graph building, caching, FFI dispatch, and Flax plumbing are all structurally correct. Cache key stability, UID ordering, and pytree gradient structure are handled properly. The findings are race conditions that produce at worst redundant work (not wrong results) and a shutdown-order concern for the thread-local cuDNN handle that matches patterns already present elsewhere in the codebase.
transformer_engine/jax/csrc/extensions/attention.cpp (double-checked locking in GetScoreModGraph, thread-local handle destructor ordering) and transformer_engine/jax/cpp_extensions/flex_attention.py (Python-level cache lock).
Sequence Diagram
sequenceDiagram
participant User
participant fused_attn
participant ScoreMod as "_fused_attn_score_mod"
participant FlexPy as "flex_attention.py"
participant FFI as "FFI/XLA"
participant Cpp as "C++ Handler"
participant Cache as "cuDNN Graph Cache"
User->>fused_attn: "call with score_mod callback"
fused_attn->>fused_attn: "validate_fused_attn_score_mod()"
fused_attn->>FlexPy: "make_fused_attn_score_mod_config()"
fused_attn->>ScoreMod: "custom_vjp forward"
Note over ScoreMod,FlexPy: JAX Tracing Phase
ScoreMod->>FlexPy: "fused_attn_score_mod_fwd()"
FlexPy->>FlexPy: "check _score_mod_graph_cache"
alt cache miss
FlexPy->>FlexPy: "_build_score_mod_fwd_graph()"
FlexPy->>FlexPy: "store in _score_mod_graph_cache"
end
FlexPy->>FFI: "ffi.ffi_call(serialized_graph, uids)"
Note over FFI,Cache: XLA Execution Phase
FFI->>Cpp: "FusedAttnScoreModForwardFFI(stream, q, k, v)"
Cpp->>Cache: "GetScoreModGraph(stream, attrs)"
alt C++ cache miss
Cache->>Cache: "graph->deserialize(handle, data)"
Cache->>Cache: "store shared_ptr in map"
end
Cpp->>Cpp: "graph->execute(handle, variant_pack)"
Cpp-->>FFI: "output, stats, workspace"
Note over ScoreMod,FlexPy: Backward pass
ScoreMod->>FlexPy: "fused_attn_score_mod_bwd(qkv, o, dO, stats)"
FlexPy->>FFI: "ffi.ffi_call(serialized_bwd_graph)"
FFI->>Cpp: "FusedAttnScoreModBackwardFFI(...)"
Cpp-->>FFI: "dq, dk, dv"
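The two cache-miss branches in the diagram follow the same build-once pattern. The following is a minimal sketch of a lock-protected cache, not the PR's implementation: `GraphCache`, `get_or_build`, the key tuple, and the `build_graph` callable are all hypothetical names. It illustrates why an unlocked check can at worst cause redundant builds, and how re-checking under a lock avoids them.

```python
import threading
from typing import Callable, Dict, Hashable


class GraphCache:
    """Build-once cache; all names here are illustrative, not the PR's API."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._entries: Dict[Hashable, bytes] = {}
        self.builds = 0  # counts how many times a builder actually ran

    def get_or_build(self, key: Hashable, build_graph: Callable[[], bytes]) -> bytes:
        # Fast path: return a hit without taking the lock.
        # (Assumption: a single dict lookup is safe under CPython's GIL.)
        cached = self._entries.get(key)
        if cached is not None:
            return cached
        # Slow path: re-check under the lock so only one thread builds per key.
        with self._lock:
            cached = self._entries.get(key)
            if cached is None:
                cached = build_graph()
                self.builds += 1
                self._entries[key] = cached
            return cached


cache = GraphCache()
key = ("fwd", "softcap")  # hypothetical cache key
threads = [
    threading.Thread(target=cache.get_or_build, args=(key, lambda: b"serialized"))
    for _ in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Holding the lock while building keeps the sketch short but also serializes builds for unrelated keys; a per-key lock would relax that at the cost of more bookkeeping.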
Reviews (9): Last reviewed commit: "Skip softcap score-mod test before SM90"
    score_mod: Optional[Callable] = None,
    score_mod_bprop: Optional[Callable] = None,
    score_mod_tensors: Optional[Mapping[str, Any]] = None,
    score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None,
):
It looks like this is the highest-level API that score_mod has been plumbed to.
There are higher-level APIs that it also needs to be plumbed through; please take a look.
At the very least, FusedDPA and DPA.
Left some comments for this
def _reference_attention(
    query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if post_scale_bias:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:
        scores = softcap * jnp.tanh(scores / softcap)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype)
Why create your own reference instead of using the JAX-native reference that is already in the test file?
Just to clarify, I believe the idea behind this comment was to check whether the code could be reused. However, it seems the solution chosen was to move the flex attention tests into a different file altogether.
I do not think it is good practice for us to have different ways of creating the reference for different types of fused attention. It would have been best to use the reference implementation in test_fused_attn.py, tweaked for the test case we have for flex attention.
However, I'll try not to hold the PR for this, but I definitely think this should be looked into in a follow-up PR.
cc: @cyanguwa
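One way to read the reuse suggestion is a single reference routine that takes an optional score hook, so causal masking, bias, and softcap all go through the same code path. The sketch below is a plain-Python, single-head illustration under that assumption. It is not the reference in test_fused_attn.py, and `reference_attention`, `softcap_scores`, and `causal_scores` are hypothetical names.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]
ScoreFn = Callable[[float, int, int], float]


def reference_attention(
    query: List[Vector],
    key: List[Vector],
    value: List[Vector],
    scale: float,
    score_fn: Optional[ScoreFn] = None,
) -> List[Vector]:
    """Single-head attention on plain lists with an optional per-score hook."""
    output = []
    for q_idx, q_vec in enumerate(query):
        scores = []
        for kv_idx, k_vec in enumerate(key):
            score = scale * sum(a * b for a, b in zip(q_vec, k_vec))
            if score_fn is not None:
                score = score_fn(score, q_idx, kv_idx)
            scores.append(score)
        # Numerically stable softmax over the key axis.
        max_score = max(scores)
        weights = [math.exp(s - max_score) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        head_dim = len(value[0])
        output.append(
            [sum(p * v_vec[d] for p, v_vec in zip(probs, value)) for d in range(head_dim)]
        )
    return output


def softcap_scores(cap: float) -> ScoreFn:
    """Hook that applies cap * tanh(score / cap)."""
    return lambda score, q_idx, kv_idx: cap * math.tanh(score / cap)


def causal_scores(score: float, q_idx: int, kv_idx: int) -> float:
    """Hook that masks keys positioned after the query."""
    return score if q_idx >= kv_idx else -1e9


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
plain = reference_attention(q, k, v, scale=0.5)
capped = reference_attention(q, k, v, scale=0.5, score_fn=softcap_scores(0.8))
causal = reference_attention(q, k, v, scale=0.5, score_fn=causal_scores)
```

With a hook of this kind, each fused-attention variant would differ only in the function it passes in, not in how the reference itself is built.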
if score_mod_requested:
    if not enable_fused_attn:
        raise ValueError("score_mod requires fused attention, but NVTE_FUSED_ATTN=0.")
    has_fused_attn_kernel = True
Why do we force this to True?
If the user wants to perform flex attention and has enabled fused attention, shouldn't we check via is_fused_attn_kernel_available()?
That seems to be our contract API for determining whether fused attention is both available and requested, so I think we should rely on it rather than forcing the value to True here.
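A minimal sketch of the alternative being suggested, under stated assumptions: `resolve_fused_attn` is a hypothetical helper, `kernel_available` stands in for a call to is_fused_attn_kernel_available() with whatever arguments it takes in the codebase, and the second error message is invented for illustration. Whether the existing availability check can describe a score_mod configuration is an open question this sketch does not answer.

```python
from typing import Callable


def resolve_fused_attn(
    score_mod_requested: bool,
    enable_fused_attn: bool,
    kernel_available: Callable[[], bool],
) -> bool:
    """Return whether the fused path is usable; raise if score_mod cannot run."""
    # Consult the availability probe instead of assuming a kernel exists.
    has_fused_attn_kernel = enable_fused_attn and kernel_available()
    if score_mod_requested:
        if not enable_fused_attn:
            raise ValueError("score_mod requires fused attention, but NVTE_FUSED_ATTN=0.")
        if not has_fused_attn_kernel:
            # Hypothetical message: fail loudly rather than force the flag to True.
            raise ValueError("score_mod requires a fused attention kernel, but none is available.")
    return has_fused_attn_kernel
```

The difference from the hunk above is that an unsupported configuration surfaces as an explicit error at this point instead of later, inside the kernel call.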
score_mod_requested = (
    self.score_mod is not None
    or self.score_mod_bprop is not None
    or score_mod_tensors is not None
    or score_mod_bprop_tensors is not None
)
Why repeat this check in both DPA and _FusedDPA?
The expectation is that _FusedDPA is an internal class (as the leading underscore indicates), so if the checks exist in DPA I do not think we need them again in _FusedDPA. Thoughts?
The user should be exposed to DPA only, IIRC.
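The structure being proposed can be sketched as follows: the public class detects the request once and passes a flag down, and the internal class trusts it. `PublicAttention` and `_InternalAttention` are hypothetical stand-ins for DPA and _FusedDPA, and the helper name `score_mod_requested` is borrowed from the variable in the hunk above for illustration only.

```python
from typing import Any, Callable, Mapping, Optional


def score_mod_requested(
    score_mod: Optional[Callable] = None,
    score_mod_bprop: Optional[Callable] = None,
    score_mod_tensors: Optional[Mapping[str, Any]] = None,
    score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None,
) -> bool:
    """True if any score_mod-related argument was supplied."""
    return any(
        arg is not None
        for arg in (score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors)
    )


class _InternalAttention:
    """Internal layer: trusts the flag computed by its public caller."""

    def __call__(self, use_score_mod: bool) -> str:
        return "score_mod path" if use_score_mod else "default path"


class PublicAttention:
    """Public entry point: the only place where the request is detected."""

    def __init__(self, score_mod: Optional[Callable] = None) -> None:
        self.score_mod = score_mod

    def __call__(self, score_mod_tensors: Optional[Mapping[str, Any]] = None) -> str:
        requested = score_mod_requested(self.score_mod, score_mod_tensors=score_mod_tensors)
        return _InternalAttention()(requested)
```

The trade-off is that the internal class is no longer safe to call directly with inconsistent arguments, which is acceptable only if it really is never exposed to users.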
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)
Why do we scale the inputs by 0.125 here?
runner = FusedAttnRunner(
    batch,
    seqlen,
    seqlen,
    num_heads,
    num_heads,
    head_dim,
    head_dim,
    AttnBiasType.NO_BIAS,
    AttnMaskType.NO_MASK,
    AttnSoftmaxType.VANILLA_SOFTMAX,
    0.0,
    dtype,
    True,
    QKVLayout.BSHD_BSHD_BSHD,
    None,
    None,
    SeqDescFormat.Mask,
    number_of_devices=device_count,
    mesh_shape=mesh_shape,
    mesh_axes=mesh_axes,
    mesh_resource=mesh_resource,
)
runner._setup_inputs()
So it seems you are using the runner only to set up the inputs, and then following that with code that duplicates test_forward and test_backward in test_fused_attn.py.
The suggestion was to use the runner to call forward(), which does the setup through the runner and also runs the test.
The idea is that if you can integrate the non-distributed tests with the Runner infrastructure, then the distributed tests can use it here for free. The approach here seems to be somewhere between your older approach and the suggested one. You can refer to other tests in this file for reference, in case I have not explained this well.
I'm curious whether the reason was that you were unable to fully integrate the score mod tests into the FusedAttn runner? It seems to be the same reason for creating a separate test_fused_attn_score_mod.py rather than integrating the flex attention tests into the FusedAttn runner in test_fused_attn.py.
If test_fused_attn_score_mod must be created, then a Runner should be created in it too, and the distributed flex attention tests can then use that runner here (as other tests do).
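The runner idea can be sketched in miniature: one object owns input setup and the comparison against a reference, and every test, distributed or not, only supplies the implementation to check. Everything below is hypothetical (`ScoreModRunner`, its fields, and the toy `double` functions); it is not the FusedAttnRunner interface from test_fused_attn.py.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, List

Matrix = List[List[float]]


@dataclass
class ScoreModRunner:
    """Hypothetical runner: owns input setup and the impl-vs-reference check."""

    seqlen: int
    head_dim: int
    seed: int = 0
    q: Matrix = field(default_factory=list, init=False)

    def _setup_inputs(self) -> None:
        # Deterministic inputs so that every caller compares on the same data.
        rng = random.Random(self.seed)
        self.q = [
            [rng.gauss(0.0, 1.0) for _ in range(self.head_dim)] for _ in range(self.seqlen)
        ]

    def test_forward(
        self,
        impl: Callable[[Matrix], Matrix],
        reference: Callable[[Matrix], Matrix],
        atol: float = 1e-6,
    ) -> None:
        self._setup_inputs()
        got, want = impl(self.q), reference(self.q)
        for got_row, want_row in zip(got, want):
            for g, w in zip(got_row, want_row):
                assert abs(g - w) <= atol, f"mismatch: {g} vs {w}"


def double(x: Matrix) -> Matrix:
    """Toy stand-in for the reference computation."""
    return [[2.0 * v for v in row] for row in x]


def double_by_addition(x: Matrix) -> Matrix:
    """Toy stand-in for the implementation under test."""
    return [[v + v for v in row] for row in x]


# A single-device test and a distributed test would both reduce to this call,
# differing only in how `impl` is wrapped (for example, with shardings).
runner = ScoreModRunner(seqlen=4, head_dim=8)
runner.test_forward(double_by_addition, double)
```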
qkv_sharding = NamedSharding(runner.mesh, PartitionSpec(dp_axis, None, tp_axis, None))
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)

scaling_factor = runner.scaling_factor
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

def score_mod_loss(q, k, v, dout):
    out = customcall_fused_dpa(
        q,
        k,
        v,
        None,
        None,
        None,
        None,
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.NO_MASK,
        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
        softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
        scaling_factor=scaling_factor,
        dropout_probability=0.0,
        is_training=True,
        score_mod=softcap_score_mod.forward,
        score_mod_bprop=softcap_score_mod.backward,
        score_mod_tensors={"softcap": softcap},
        score_mod_bprop_tensors={"softcap": softcap},
    )
    loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
    return loss, out

def ref_loss(q, k, v, dout):
    out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
    loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
    return loss, out

jitted_score_mod = jax.jit(
    jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
    in_shardings=(
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
    ),
    out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
    jax.device_put(query, qkv_sharding),
    jax.device_put(key_tensor, qkv_sharding),
    jax.device_put(value, qkv_sharding),
    jax.device_put(doutput, qkv_sharding),
)
All of this comes for free if flex attention is integrated with the FusedAttn runner.
Description
Adds experimental JAX fused-attention score_mod support through cuDNN frontend SDPA graphs.
This introduces a score_mod(graph, score, tensors) callback path for fused_attn, plus optional score_mod_bprop(graph, dscore, tensors) support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.
The Flax API now plumbs score_mod through DotProductAttention, MultiHeadAttention, and TransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.
Users are responsible for supplying a mathematically correct score_mod_bprop for the corresponding score_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.
Current score_mod limitations:
BSHD_BSHD_BSHD Q/K/V tensors only.
Fixes #2492
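Because gradient semantics are not validated, it can help to check a forward/backward pair numerically before wiring it into a graph. The sketch below does this for the softcap transform used in the tests, with plain scalar functions. It is an illustration of the math only: the real callbacks take (graph, score, tensors) and build cuDNN graph nodes, and `softcap_forward`, `softcap_backward`, and `finite_difference` are hypothetical names.

```python
import math


def softcap_forward(score: float, softcap: float) -> float:
    """Forward score modification: softcap * tanh(score / softcap)."""
    return softcap * math.tanh(score / softcap)


def softcap_backward(dscore: float, score: float, softcap: float) -> float:
    """Chain rule: d/ds [c * tanh(s / c)] = 1 - tanh(s / c) ** 2."""
    return dscore * (1.0 - math.tanh(score / softcap) ** 2)


def finite_difference(score: float, softcap: float, eps: float = 1e-6) -> float:
    """Central-difference estimate of the forward derivative."""
    upper = softcap_forward(score + eps, softcap)
    lower = softcap_forward(score - eps, softcap)
    return (upper - lower) / (2 * eps)


# The analytic backward should agree with the numerical estimate everywhere.
for s in (-2.0, -0.3, 0.0, 0.5, 3.0):
    analytic = softcap_backward(1.0, s, softcap=0.8)
    numeric = finite_difference(s, softcap=0.8)
    assert abs(analytic - numeric) < 1e-6
```

A mismatch in a check like this points at the backward formula itself, which is cheaper to debug than a gradient discrepancy in a full attention test.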