-
Notifications
You must be signed in to change notification settings - Fork 735
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7b2e9f8
a78c4c1
546efcf
226156c
d177ecf
fe1bbd0
793ee2f
45b332f
e317f99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -444,6 +444,56 @@ def _check_configs(self): | |
| "is either BSHD_BSHD_BSHD or THD_THD_THD" | ||
| ) | ||
|
|
||
| # D=256 bprop on SM10.x uses cuDNN's dedicated SDPA bprop kernel | ||
| # (cuDNN FE 1.24 / BE 9.23+). FE forces this path onto the deterministic algorithm path, | ||
| # which rejects dBias, dropout, and ALiBi. It supports vanilla softmax only and allows SWA | ||
| # together with a causal mask only. | ||
| compute_capability = get_device_compute_capability(0) | ||
| is_sm10x = 100 <= compute_capability < 110 | ||
| if self.is_training and is_sm10x and (self.head_dim_qk == 256 or self.head_dim_v == 256): | ||
| if self.head_dim_qk != 256 or self.head_dim_v != 256: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell only supports d_qk == d_v == 256;" | ||
| f" got d_qk={self.head_dim_qk}, d_v={self.head_dim_v}." | ||
| ) | ||
| cudnn_version = get_cudnn_version() | ||
| if cudnn_version < 92300: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell requires cuDNN 9.23 or newer;" | ||
| f" got cuDNN {cudnn_version}." | ||
| ) | ||
| # Non-learnable bias is fine (bias is allowed as an input); only dBias is | ||
| # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] | ||
| # (see test_backward), so gate on that. | ||
| unsupported = None | ||
| if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | ||
| unsupported = "pre-scale bias" | ||
| elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: | ||
| unsupported = ( | ||
| "bias gradients (dBias); frozen/non-learnable bias inputs" | ||
| " (i.e. non-1HSS bias shapes) are supported" | ||
| ) | ||
| elif self.dropout_prob != 0.0: | ||
| unsupported = "dropout" | ||
| elif self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: | ||
| unsupported = "non-vanilla softmax" | ||
| if unsupported is not None: | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell uses the deterministic SM100 D=256 SDPA BWD" | ||
| f" kernel which does not support {unsupported}." | ||
| ) | ||
| if self.window_size is not None and self.window_size != (-1, -1): | ||
| if not self.attn_mask_type.is_causal(): | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell uses the SM100 D=256 SDPA BWD kernel" | ||
| " which requires window_size=(-1, -1) for non-causal masks." | ||
| ) | ||
| if self.window_size[1] not in (-1, 0): | ||
| pytest.skip( | ||
| "D=256 BWD on Blackwell only supports right window -1 or 0" | ||
| " for causal masks." | ||
| ) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't these checks duplicate to the checks we added on the C++ side? Would the call |
||
| self.backend = FusedAttnHelper( | ||
| self.is_training, | ||
| self.dtype, | ||
|
|
@@ -1474,6 +1524,21 @@ def test_backward( | |
| QKVLayout.THD_THD_THD, | ||
| id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", | ||
| ), | ||
| # D=256 deterministic backward on the SM100 dedicated SDPA bprop kernel | ||
| # (cuDNN FE 1.24 / BE 9.23+). Unsupported configs (e.g. dBias, non-256 head dims) | ||
| # are skipped by FusedAttnRunner._check_configs. | ||
| pytest.param( | ||
| 4, | ||
| 128, | ||
| 128, | ||
| 16, | ||
| 16, | ||
| 256, | ||
| 256, | ||
| jnp.float16, | ||
| QKVLayout.BSHD_BS2HD, | ||
| id="4-128-128-16-16-256-256-FP16-SELF-KV_PACKED", | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -377,6 +377,38 @@ def test_dpa_fa4_hdim256(dtype, model_configs, model): | |||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| # cuDNN FusedAttention D=256 bprop is supported on sm10x from cuDNN 9.23 (FE 1.24), | ||||||||
| # via the dedicated deterministic SDPA bprop kernel, which supports d_qk == d_v == 256 only, | ||||||||
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | ||||||||
| # (for non-causal masks) full-window attention. | ||||||||
|
Comment on lines
+382
to
+383
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment block ends with a repeated phrase: line 383 (
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like some editing glitch :) |
||||||||
| model_configs_fused_hdim256 = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk) -> head_dim_v defaults to head_dim_qk (256) | ||||||||
| "fused_hd256_no_mask": ModelConfig(2, 512, 16, 256), | ||||||||
| "fused_hd256_padding": ModelConfig(2, 512, 16, 256, attn_mask_type="padding"), | ||||||||
| # SWA is allowed only together with a causal mask on the D=256 bprop kernel. | ||||||||
| "fused_hd256_causal_swa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, attn_mask_type="causal", window_size=(128, 0) | ||||||||
| ), | ||||||||
| # GQA variant (num_gqa_groups < num_heads). | ||||||||
| "fused_hd256_padding_causal_gqa": ModelConfig( | ||||||||
| 2, 1024, 16, 256, num_gqa_groups=4, attn_mask_type="padding_causal" | ||||||||
| ), | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| @pytest.mark.skipif(get_cudnn_version() < (9, 23, 0), reason="cuDNN 9.23+ is required.") | ||||||||
| @pytest.mark.skipif( | ||||||||
| device_compute_capability not in ((10, 0), (10, 3)), | ||||||||
| reason="cuDNN FusedAttention head_dim=256 backward is Blackwell server (SM100/SM103) only.", | ||||||||
| ) | ||||||||
| @pytest.mark.parametrize("dtype", param_types) | ||||||||
| @pytest.mark.parametrize("model_configs", [model_configs_fused_hdim256]) | ||||||||
| @pytest.mark.parametrize("model", model_configs_fused_hdim256.keys()) | ||||||||
| def test_dpa_fused_attn_hdim256(dtype, model_configs, model): | ||||||||
| """Test DotProductAttention with cuDNN FusedAttention: head_dim=256 backward on Blackwell""" | ||||||||
| test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) | ||||||||
|
|
||||||||
|
|
||||||||
| model_configs_fa4_mla = { | ||||||||
| # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) | ||||||||
| "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -315,6 +315,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |
| (head_dim_qk <= 256 && head_dim_v <= 256 && | ||
| ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || | ||
| (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || | ||
| // 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged | ||
| (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && | ||
| sm_arch_ < 110 && cudnn_runtime_version >= 92300 && | ||
| layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && | ||
| // The FE forces this path onto the deterministic bprop algorithm, which on | ||
| // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). | ||
| bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && | ||
| softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && | ||
| // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. | ||
| ((window_size_left == -1 && window_size_right == -1) || | ||
| ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && | ||
| (window_size_right == -1 || window_size_right == 0)))) || | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only. |
||
| // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 | ||
| (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && | ||
| layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in
fused_attn.cpprequiresbias_type == NVTE_NO_BIASfor the new D=256 BWD path, meaning any config withattn_bias_type != NO_BIAS && bias_shape != _1HSSwill silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.