Changes from 11 commits
85 changes: 82 additions & 3 deletions tests/jax/test_fused_attn.py
@@ -394,6 +394,7 @@ def _get_max_segments_per_sequence(self):

def _check_configs(self):
- # TODO(rewang): probably adds this in is_fused_attn_available
+ # TODO(KshitijLakhani): probably add/move this to is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")

@@ -417,11 +418,59 @@ def _check_configs(self):
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
compute_capability = get_device_compute_capability(0)
cudnn_version = get_cudnn_version()
# D=256 bprop on SM10x (cuDNN FE 1.24 / BE 9.23+) uses the deterministic algorithm path only,
# which rejects dBias, dropout, and ALiBi. It supports vanilla type of softmax only and allows SWA
# together with a causal mask only.
is_sm10x = 100 <= compute_capability < 110
if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256):
if self.head_dim_qk != 256 or self.head_dim_v != 256:
pytest.skip(
"D=256 BWD on Blackwell only supports d_qk == d_v == 256;"
f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}."
)
if cudnn_version < 92300:
pytest.skip(
"D=256 BWD on Blackwell requires cuDNN 9.23 or newer;"
f" got cuDNN {cudnn_version}."
)
# Non-learnable bias is fine (bias is allowed as an input); only dBias is
# unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s]
# (see test_backward), so gate on that.
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
unsupported = (
"bias gradients (dBias); frozen/non-learnable bias inputs"
" (i.e. non-1HSS bias shapes) are supported"
)
Contributor

P2 JAX skip logic diverges from C++ backend gate for non-1HSS bias

The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.

Collaborator Author

Fixed in b8fe919
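
To make the divergence raised in this thread concrete, here is a minimal Python sketch. It is not the Transformer Engine code: `jax_test_allows` and `backend_allows` are hypothetical names, the enums are reduced to plain strings, and the list of bias shapes is illustrative. The only facts taken from the thread are that the test (before the fix) skipped pre-scale bias and `_1HSS` bias, while the C++ gate requires `NVTE_NO_BIAS`.

```python
from itertools import product

# Simplified, illustrative stand-ins for the real enums (assumption).
BIAS_TYPES = ("NO_BIAS", "PRE_SCALE_BIAS", "POST_SCALE_BIAS")
BIAS_SHAPES = ("1HSS", "B1SS", "BHSS", "11SS")


def jax_test_allows(bias_type: str, bias_shape: str) -> bool:
    """Bias part of the JAX skip block as it appears in the diff above."""
    if bias_type == "PRE_SCALE_BIAS":
        return False
    if bias_type != "NO_BIAS" and bias_shape == "1HSS":
        return False  # dBias would be requested
    return True


def backend_allows(bias_type: str) -> bool:
    """Bias part of the C++ gate: only NO_BIAS reaches the D=256 BWD path."""
    return bias_type == "NO_BIAS"


# Configs the test lets through but the backend gate rejects.
diverging = [
    (bias_type, bias_shape)
    for bias_type, bias_shape in product(BIAS_TYPES, BIAS_SHAPES)
    if jax_test_allows(bias_type, bias_shape) and not backend_allows(bias_type)
]
print(diverging)
```

Every entry in `diverging` is a post-scale bias with a non-1HSS shape: those configs run without being skipped, but under the backend gate they would not exercise the new kernel.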

elif self.dropout_prob != 0.0:
unsupported = "dropout"
elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
unsupported = "non-vanilla softmax"
if unsupported is not None:
pytest.skip(
"D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD"
f" kernel which does not support {unsupported}."
)
if self.window_size is not None and self.window_size != (-1, -1):
if not self.attn_mask_type.is_causal():
pytest.skip(
"D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel"
" which requires window_size=(-1, -1) for non-causal masks."
)
if self.window_size[1] not in (-1, 0):
pytest.skip(
"D=256 BWD on Blackwell only supports right window -1 or 0"
" for causal masks."
)

Collaborator

Aren't these checks duplicates of the ones we added on the C++ side? Would calling FusedAttnHelper().get_fused_attn_backend() give you the same gating effect?

Collaborator Author

@KshitijLakhani KshitijLakhani Jun 5, 2026

If we are only interested in the gating effect, you are right: get_fused_attn_backend() will return NVTE_No_Backend, and a catch-all at the end then skips the test because no fused attention backend is available.

However, these checks are here to give a meaningful reason for each skip, rather than the generic "Unsupported inputs combination or device compute capability." message, which does not say why the test was skipped. Unfortunately, on the JAX attention side we do not log the reason for disabling fused attention in the feature code, as we do on the PyTorch side in d_p_a/utils.py, so the user has no way to know why a test was skipped. Hence, on the JAX side we need to rely on the test code to log this.

I'd suggest we leave this in for now. When your PR for generating log messages at the C++ level during attention backend selection is ready, I can plumb it through to the JAX side and, as part of that cleanup, get rid of all the skip messages in check_configs().
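
As a rough illustration of the "descriptive skip reason" idea discussed in this thread, the sketch below collects the D=256 BWD restrictions from the diff into one function that returns either a reason string or `None`. It is a simplified model under stated assumptions: `Config` and `d256_bwd_skip_reason` are hypothetical names, enums are plain strings, the bias checks are left out, and the SM10x and `is_training` preconditions are assumed to hold already.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Config:
    """Hypothetical stand-in for the fields the test runner inspects."""
    head_dim_qk: int = 256
    head_dim_v: int = 256
    dropout_prob: float = 0.0
    softmax_type: str = "VANILLA_SOFTMAX"
    is_causal: bool = False
    window_size: Optional[Tuple[int, int]] = None
    cudnn_version: int = 92300


def d256_bwd_skip_reason(cfg: Config) -> Optional[str]:
    """Return why a config cannot use the D=256 BWD path, or None if it can."""
    if cfg.head_dim_qk != 256 or cfg.head_dim_v != 256:
        return "requires d_qk == d_v == 256"
    if cfg.cudnn_version < 92300:
        return "requires cuDNN 9.23 or newer"
    if cfg.dropout_prob != 0.0:
        return "dropout is not supported"
    if cfg.softmax_type != "VANILLA_SOFTMAX":
        return "non-vanilla softmax is not supported"
    if cfg.window_size is not None and cfg.window_size != (-1, -1):
        if not cfg.is_causal:
            return "sliding window requires a causal mask"
        if cfg.window_size[1] not in (-1, 0):
            return "right window must be -1 or 0"
    return None


print(d256_bwd_skip_reason(Config(window_size=(128, 0))))
print(d256_bwd_skip_reason(Config(is_causal=True, window_size=(128, 0))))
```

A test could pass the returned string straight to `pytest.skip(...)`. The generic catch-all would then only fire for combinations this function does not know about.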

- if get_device_compute_capability(0) >= 100 and self.is_training:
+ if compute_capability >= 100 and self.is_training:
if FusedAttnHelper.is_non_deterministic_allowed() and (
(self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
- or get_cudnn_version() < 90700
+ or cudnn_version < 90700
):
pytest.skip(
"For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
@@ -430,7 +479,7 @@ def _check_configs(self):
if not FusedAttnHelper.is_non_deterministic_allowed() and (
self.dropout_prob != 0.0
or self.attn_bias_type != AttnBiasType.NO_BIAS
- or get_cudnn_version() < 91801
+ or cudnn_version < 91801
):
pytest.skip(
"For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
@@ -1474,6 +1523,36 @@ def test_backward(
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
# D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+).
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED",
),
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.THD_T2HD,
id="4-128-128-16-16-256-256-FP16-SELF-RAGGED_KV_PACKED",
marks=pytest.mark.xfail(
reason="cuDNN 9.23 D=256 BWD currently does not build a THD execution plan.",
strict=True,
),
),
],
)
@pytest.mark.parametrize(
32 changes: 32 additions & 0 deletions tests/pytorch/attention/test_attention.py
@@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24),
# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only,
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.
Contributor

P2 Duplicated comment fragment

The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.

Suggested change
- # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
- # (for non-causal masks) full-window attention.
+ # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.


Collaborator

Seems like some editing glitch :)

Collaborator Author

Fixed in 379242b

model_configs_fused_hdim256 = {
# test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256)
"fused_hd256_no_mask": ModelConfig(2, 512, 16, 256),
"fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"),
# SWA is allowed only together with a causal mask on the D=256 bprop kernel.
"fused_hd256_causal_swa": ModelConfig(
2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0)
),
# GQA variant (num_gqa_groups < num_heads).
"fused_hd256_padding_causal_gqa": ModelConfig(
2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal"
),
}


@pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.")
@pytest.mark.skipif(
device_compute_capability not in ((10, 0), (10, 3)),
reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256])
@pytest.mark.parametrize("model", model_configs_fused_hdim256.keys())
def test_dpa_fused_attn_hdim256(dtype, model_configs, model):
"""Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
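
The two test files in this PR compare cuDNN versions in different forms: the JAX test uses an integer (`cudnn_version < 92300`), while the PyTorch test uses a tuple (`get_cudnn_version() < (9, 23, 0)`). The sketch below converts between the two. It assumes the cuDNN 9.x encoding `major * 10000 + minor * 100 + patch`, which matches the pairs that appear in the diff (`91801` for 9.18.1, `92300` for 9.23); the helper names are hypothetical and not part of Transformer Engine.

```python
from typing import Tuple


def cudnn_int_to_tuple(version: int) -> Tuple[int, int, int]:
    """Split an integer such as 92300 into (major, minor, patch).

    Assumes the cuDNN 9.x scheme: major * 10000 + minor * 100 + patch.
    """
    major, rest = divmod(version, 10000)
    minor, patch = divmod(rest, 100)
    return (major, minor, patch)


def cudnn_tuple_to_int(version: Tuple[int, int, int]) -> int:
    """Inverse of cudnn_int_to_tuple under the same assumption."""
    major, minor, patch = version
    return major * 10000 + minor * 100 + patch


print(cudnn_int_to_tuple(92300))
print(cudnn_tuple_to_int((9, 18, 1)))
```

Tuples compare lexicographically in Python, so `(9, 18, 1) < (9, 23, 0)` gives the same result as `91801 < 92300`.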


model_configs_fa4_mla = {
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
15 changes: 15 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged
(head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
sm_arch_ < 110 && cudnn_runtime_version >= 92300 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
// The FE forces this path onto the deterministic bprop algorithm, which on
// Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only).
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX &&
// Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks.
((window_size_left == -1 && window_size_right == -1) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(window_size_right == -1 || window_size_right == 0)))) ||
Collaborator

Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?

Collaborator

Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only.

Collaborator Author

RE: THD support
I tested BSHD and BSHD+CP: both passed on the JAX side, and the CI for the PyT side did not fail either, so I think that works.
My testing revealed that THD support is not yet available (a bwd plan compilation issue), so I've filed a bug and shared a reproducer with the cuDNN team: NVIDIA/cudnn-frontend#276

Collaborator Author

Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?

Fixed in a264de1

Contributor

P1 D=256 BWD condition includes THD layout, causing a hard runtime exception

The new condition only excludes NVTE_Paged_KV_HD_HD_HD but does not exclude THD-format layouts. NVTE_THD_T2HD maps to layout_group = NVTE_HD_2HD and qkv_format = NVTE_THD, both of which pass all guards here and in the outer flag_arb checks (the qkv_format check at line 417 allows THD when sm_arch_ >= 90, which is true for SM10x). So nvte_get_fused_attn_backend returns NVTE_F16_arbitrary_seqlen for full-window THD + D=256 + SM10x + cuDNN ≥ 9.23, claiming support — but cuDNN 9.23 fails to build an execution plan for this layout, and NVTE_CHECK_CUDNN_FE on lines 421–422 of fused_attn_f16_arbitrary_seqlen.cu will throw a hard exception. The JAX xfail test documents the failure, but any production user with THD + D=256 training will hit an unrecoverable runtime error rather than a graceful backend fallback. Adding qkv_format != NVTE_QKV_Format::NVTE_THD to this condition would fix the backend selector; the JAX xfail test would then SKIP instead of XFAIL (which could be separately handled if you want to preserve the sentinel behaviour).

Collaborator Author

This support is forward-looking: THD and BSHD are both enabled in the TE common fused attention backend selection code, but the PR is still waiting on cuDNN to fix THD support.
The current PR will not be merged as is. One of two things will happen:

  1. cuDNN will fix THD support and only then will this PR be merged (most likely), after changing the XFAIL for THD cases to skips for a specific cuDNN version
  2. cuDNN will not fix this soon, in which case I will switch the support to BSHD only prior to merging this PR
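
The gate under discussion can be modelled in a few lines of Python to show what the proposed THD exclusion would change. This is a simplified sketch, not the C++ selector: `d256_bwd_gate` and its `exclude_thd` flag are hypothetical, enums are plain strings, and the outer `flag_arb` checks mentioned in the review are not modelled.

```python
CAUSAL_MASKS = {
    "CAUSAL",
    "PADDING_CAUSAL",
    "CAUSAL_BOTTOM_RIGHT",
    "PADDING_CAUSAL_BOTTOM_RIGHT",
}


def d256_bwd_gate(
    *,
    head_dim_qk,
    head_dim_v,
    is_training,
    sm_arch,
    cudnn_version,
    layout_group,
    qkv_format,
    bias_type="NO_BIAS",
    dropout=0.0,
    softmax_type="VANILLA",
    mask="NO_MASK",
    window=(-1, -1),
    exclude_thd=False,  # hypothetical switch for the reviewer's proposal
):
    """Simplified model of the new D=256 BWD clause in the backend selector."""
    full_window = window == (-1, -1)
    causal_swa = mask in CAUSAL_MASKS and window[1] in (-1, 0)
    ok = (
        head_dim_qk == 256
        and head_dim_v == 256
        and is_training
        and 100 <= sm_arch < 110
        and cudnn_version >= 92300
        and layout_group != "PAGED_KV_HD_HD_HD"
        and bias_type == "NO_BIAS"
        and dropout == 0.0
        and softmax_type == "VANILLA"
        and (full_window or causal_swa)
    )
    if exclude_thd:
        ok = ok and qkv_format != "THD"
    return ok


# THD_T2HD: layout group HD_2HD, format THD, full window, SM100, cuDNN 9.23.
thd = dict(
    head_dim_qk=256, head_dim_v=256, is_training=True, sm_arch=100,
    cudnn_version=92300, layout_group="HD_2HD", qkv_format="THD",
)
print(d256_bwd_gate(**thd))
print(d256_bwd_gate(**thd, exclude_thd=True))
```

With `exclude_thd=False` the model accepts the THD config, which is the case the review flags as reaching a failing execution plan. With `exclude_thd=True` the config is rejected at selection time, so a caller could fall back instead.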

// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||