-
Notifications
You must be signed in to change notification settings - Fork 740
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
84a8e11
0d30f41
c214b36
24a7d2c
195b72c
1b03f3e
580af29
4474eda
e57ca7d
927fab4
e08e9e8
b8fe919
379242b
a264de1
5063b39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -394,6 +394,7 @@ def _get_max_segments_per_sequence(self): | |
|
|
||
| def _check_configs(self): | ||
| # TODO(rewang): probably adds this in is_fused_attn_available | ||
| # TODO(KshitijLakhani): probably add/move this to is_fused_attn_available | ||
| if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): | ||
| pytest.skip("THD format requires padding masks.") | ||
|
|
||
|
|
@@ -417,11 +418,59 @@ def _check_configs(self): | |
| pytest.skip( | ||
| "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" | ||
| ) | ||
| compute_capability = get_device_compute_capability(0) | ||
| cudnn_version = get_cudnn_version() | ||
| # D=256 bprop on SM10x (cuDNN FE 1.24 / BE 9.23+) uses the deterministic algorithm path only, | ||
| # which rejects dBias, dropout, and ALiBi. It supports vanilla type of softmax only and allows SWA | ||
| # together with a causal mask only. | ||
| is_sm10x = 100 <= compute_capability < 110 | ||
| if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256): | ||
| if self.head_dim_qk != 256 or self.head_dim_v != 256: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell only supports d_qk == d_v == 256;" | ||
| f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}." | ||
| ) | ||
| if cudnn_version < 92300: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell requires cuDNN 9.23 or newer;" | ||
| f" got cuDNN {cudnn_version}." | ||
| ) | ||
| # Non-learnable bias is fine (bias is allowed as an input); only dBias is | ||
| # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] | ||
| # (see test_backward), so gate on that. | ||
| unsupported = None | ||
| if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | ||
| unsupported = "pre-scale bias" | ||
| elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: | ||
| unsupported = ( | ||
| "bias gradients (dBias); frozen/non-learnable bias inputs" | ||
| " (i.e. non-1HSS bias shapes) are supported" | ||
| ) | ||
| elif self.dropout_prob != 0.0: | ||
| unsupported = "dropout" | ||
| elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: | ||
| unsupported = "non-vanilla softmax" | ||
| if unsupported is not None: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD" | ||
| f" kernel which does not support {unsupported}." | ||
| ) | ||
| if self.window_size is not None and self.window_size != (-1, -1): | ||
| if not self.attn_mask_type.is_causal(): | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel" | ||
| " which requires window_size=(-1, -1) for non-causal masks." | ||
| ) | ||
| if self.window_size[1] not in (-1, 0): | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell only supports right window -1 or 0" | ||
| " for causal masks." | ||
| ) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't these checks duplicate to the checks we added on the C++ side? Would the call
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So if we are just interested in the gating effect, you are right. However, the reason for this to be here is to give a meaningful reason as to why a test is being skipped, as compared to a generic "Unsupported inputs combination or device compute capability." message which does not qualify the reason for the skip. Unfortunately, on the JAX attn side we do not log the reason for disabling fused attn in the feature code like we do on the PyTorch side. I'd suggest we leave this in here for now. When your PR for generating log messages at the C++ level when selecting the attn backend is ready, I can plumb it through onto the JAX side and then, as part of that cleanup, get rid of all the skip messages in check_configs() |
||
| if get_device_compute_capability(0) >= 100 and self.is_training: | ||
| if compute_capability >= 100 and self.is_training: | ||
| if FusedAttnHelper.is_non_deterministic_allowed() and ( | ||
| (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) | ||
| or get_cudnn_version() < 90700 | ||
| or cudnn_version < 90700 | ||
| ): | ||
| pytest.skip( | ||
| "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" | ||
|
|
@@ -430,7 +479,7 @@ def _check_configs(self): | |
| if not FusedAttnHelper.is_non_deterministic_allowed() and ( | ||
| self.dropout_prob != 0.0 | ||
| or self.attn_bias_type != AttnBiasType.NO_BIAS | ||
| or get_cudnn_version() < 91801 | ||
| or cudnn_version < 91801 | ||
| ): | ||
| pytest.skip( | ||
| "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" | ||
|
|
@@ -1474,6 +1523,36 @@ def test_backward( | |
| QKVLayout.THD_THD_THD, | ||
| id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", | ||
| ), | ||
| # D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel | ||
| # (cuDNN FE 1.24 / BE 9.23+). | ||
| pytest.param( | ||
| 4, | ||
| 128, | ||
| 128, | ||
| 16, | ||
| 16, | ||
| 256, | ||
| 256, | ||
| jnp.float16, | ||
| QKVLayout.BSHD_BS2HD, | ||
| id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", | ||
| ), | ||
| pytest.param( | ||
| 4, | ||
| 128, | ||
| 128, | ||
| 16, | ||
| 16, | ||
| 256, | ||
| 256, | ||
| jnp.float16, | ||
| QKVLayout.THD_T2HD, | ||
| id="4-128-128-16-16-256-256-FP16-SELF-RAGGED_KV_PACKED", | ||
| marks=pytest.mark.xfail( | ||
| reason="cuDNN 9.23 D=256 BWD currently does not build a THD execution plan.", | ||
| strict=True, | ||
| ), | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): | |||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| # cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), | ||||||||
| # via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, | ||||||||
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment block ends with a repeated phrase: line 383 (
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like some editing glitch :)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 379242b |
||||||||
| model_configs_fused_hdim256 = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) | ||||||||
| "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), | ||||||||
| "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), | ||||||||
| # SWA is allowed only together with a causal mask on the D=256 bprop kernel. | ||||||||
| "fused_hd256_causal_swa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) | ||||||||
| ), | ||||||||
| # GQA variant (num_gqa_groups < num_heads). | ||||||||
| "fused_hd256_padding_causal_gqa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" | ||||||||
| ), | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| @pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") | ||||||||
| @pytest.mark.skipif( | ||||||||
| device_compute_capability not in ((10, 0), (10, 3)), | ||||||||
| reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", | ||||||||
| ) | ||||||||
| @pytest.mark.parametrize("dtype", param_types) | ||||||||
| @pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) | ||||||||
| @pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) | ||||||||
| def test_dpa_fused_attn_hdim256(dtype, model_configs, model): | ||||||||
| """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" | ||||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| model_configs_fa4_mla = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) | ||||||||
| "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |
| (head_dim_qk <= 256 && head_dim_v <= 256 && | ||
| ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || | ||
| (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || | ||
| // 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged | ||
| (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && | ||
| sm_arch_ < 110 && cudnn_runtime_version >= 92300 && | ||
| layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && | ||
| // The FE forces this path onto the deterministic bprop algorithm, which on | ||
| // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). | ||
| bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && | ||
| softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && | ||
| // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. | ||
| ((window_size_left == -1 && window_size_right == -1) || | ||
| ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && | ||
| (window_size_right == -1 || window_size_right == 0)))) || | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could these changes be moved to before "// bias type" just so it's following an increasing order of the cuDNN version?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RE: THD support
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Fixed in a264de1
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The new condition only excludes
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This support is forward-looking, i.e., support has been added for THD and BSHD in the TE common fused attn backend checking code; however, the PR is still waiting on cuDNN to fix support for THD.
|
||
| // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 | ||
| (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && | ||
| layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in
`fused_attn.cpp` requires `bias_type == NVTE_NO_BIAS` for the new D=256 BWD path, meaning any config with `attn_bias_type != NO_BIAS && bias_shape != _1HSS` will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in b8fe919