Skip to content
65 changes: 65 additions & 0 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,56 @@ def _check_configs(self):
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)

# D=256 bprop on SM10.x uses cuDNN's dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+). FE forces this path onto the deterministic algorithm path,
# which rejects dBias, dropout, and ALiBi. It supports vanilla softmax only and allows SWA
# together with a causal mask only.
compute_capability = get_device_compute_capability(0)
is_sm10x = 100 <= compute_capability < 110
if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256):
if self.head_dim_qk != 256 or self.head_dim_v != 256:
pytest.skip(
"D=256 BWD on Blackwell only supports d_qk == d_v == 256;"
f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}."
)
cudnn_version = get_cudnn_version()
if cudnn_version < 92300:
pytest.skip(
"D=256 BWD on Blackwell requires cuDNN 9.23 or newer;"
f" got cuDNN {cudnn_version}."
)
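# Non-learnable bias is fine (bias is allowed as an input); only dBias is
# unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s]
# (see test_backward), so gate on that.
# NOTE(review): the backend gate in fused_attn.cpp additionally requires
# bias_type == NVTE_NO_BIAS for the D=256 bprop path, so configs with a non-1HSS bias
# are not skipped here but presumably fall back to a different backend instead of
# exercising the D=256 kernel -- TODO confirm.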
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
unsupported = (
"bias gradients (dBias); frozen/non-learnable bias inputs"
" (i.e. non-1HSS bias shapes) are supported"
)
Comment on lines +465 to +475
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 JAX skip logic diverges from C++ backend gate for non-1HSS bias

The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.

elif self.dropout_prob != 0.0:
unsupported = "dropout"
elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
unsupported = "non-vanilla softmax"
if unsupported is not None:
pytest.skip(
"D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD"
f" kernel which does not support {unsupported}."
)
if self.window_size is not None and self.window_size != (-1, -1):
if not self.attn_mask_type.is_causal():
pytest.skip(
"D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel"
" which requires window_size=(-1, -1) for non-causal masks."
)
if self.window_size[1] not in (-1, 0):
pytest.skip(
"D=256 BWD on Blackwell only supports right window -1 or 0"
" for causal masks."
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these checks duplicates of the checks we added on the C++ side? Would calling FusedAttnHelper().get_fused_attn_backend() give you the same gating effect?

self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
Expand Down Expand Up @@ -1474,6 +1524,21 @@ def test_backward(
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
# D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel
# (cuDNN FE 1.24 / BE 9.23+). Unsupported configs (e.g. dBias, non-256 head dims)
# are skipped by FusedAttnRunner._check_configs.
pytest.param(
4,
128,
128,
16,
16,
256,
256,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED",
),
],
)
@pytest.mark.parametrize(
Expand Down
32 changes: 32 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24),
# via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only,
# vanilla softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
Comment on lines +382 to +383
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Duplicated comment fragment

The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.

Suggested change
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like some editing glitch :)

# Model configurations for the cuDNN FusedAttention head_dim=256 backward tests.
# ModelConfig positional arguments are (b, sq, hq, dqk); head_dim_v is not passed,
# so it defaults to head_dim_qk (256).
model_configs_fused_hdim256 = dict(
    # No mask type passed to ModelConfig.
    fused_hd256_no_mask=ModelConfig(2, 512, 16, 256),
    # Padding mask.
    fused_hd256_padding=ModelConfig(2, 512, 16, 256, attn_mask_type="padding"),
    # Sliding-window attention: the D=256 bprop kernel allows SWA only together
    # with a causal mask.
    fused_hd256_causal_swa=ModelConfig(
        2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0)
    ),
    # GQA variant (num_gqa_groups < num_heads).
    fused_hd256_padding_causal_gqa=ModelConfig(
        2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal"
    ),
)


@pytest.mark.skipif(
    get_cudnn_version() < (9, 23, 0),
    reason="cuDNN 9.23+ is required.",
)
@pytest.mark.skipif(
    device_compute_capability not in ((10, 0), (10, 3)),
    reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.",
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256])
@pytest.mark.parametrize("model", list(model_configs_fused_hdim256))
def test_dpa_fused_attn_hdim256(dtype, model_configs, model):
    """Run DotProductAttention through cuDNN FusedAttention with head_dim=256.

    Covers every entry of ``model_configs_fused_hdim256`` for each dtype in
    ``param_types``. The test is skipped unless cuDNN is at least 9.23 and the
    device compute capability is (10, 0) or (10, 3).

    Args:
        dtype: Data type under test, taken from ``param_types``.
        model_configs: Mapping of config name to ModelConfig
            (always ``model_configs_fused_hdim256``).
        model: Name of the configuration inside ``model_configs`` to run.
    """
    # Fixed trailing positional arguments forwarded to test_dot_product_attention;
    # their meaning is defined by that function's signature (not visible here).
    fixed_args = (False, True, None, False, False)
    test_dot_product_attention(dtype, model_configs, model, *fixed_args)


model_configs_fa4_mla = {
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged
(head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
sm_arch_ < 110 && cudnn_runtime_version >= 92300 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
// The FE forces this path onto the deterministic bprop algorithm, which on
// Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only).
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX &&
// Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks.
((window_size_left == -1 && window_size_right == -1) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(window_size_right == -1 || window_size_right == 0)))) ||
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only.

// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
Expand Down
Loading