-
Notifications
You must be signed in to change notification settings - Fork 732
[JAX] Support for cuDNN-backed flex attention #2985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f967a26
6b05328
1a96352
3bf9e97
29bbac7
c597af5
f8bd844
2c01c5e
ba6a1a7
deebf8e
9a92dd2
2198a79
ffc8c79
8856323
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,13 +7,26 @@ | |
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import random | ||
| from jax.sharding import NamedSharding, PartitionSpec | ||
| from distributed_test_base import ( | ||
| generate_configs, | ||
| generate_context_parallel_configs_for_attn, | ||
| generate_collectives_count, | ||
| ) | ||
| from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat | ||
| from utils import pytest_parametrize_wrapper | ||
| from test_fused_attn import ( | ||
| FusedAttnRunner, | ||
| BiasShape, | ||
| SeqDescFormat, | ||
| customcall_fused_dpa, | ||
| ) | ||
| from test_fused_attn_score_mod import ( | ||
| _ScoreModSoftcap, | ||
| _has_cudnn_frontend_python, | ||
| _reference_attention, | ||
| _require_cudnn_frontend_score_mod, | ||
| ) | ||
| from utils import assert_allclose, pytest_parametrize_wrapper | ||
| from transformer_engine.jax import autocast | ||
| from transformer_engine.jax.attention import ( | ||
| is_fused_attn_kernel_available, | ||
| AttnBiasType, | ||
|
|
@@ -272,6 +285,136 @@ def test_cross_attn( | |
| runner.test_backward() | ||
|
|
||
|
|
||
| DISTRIBUTED_SCORE_MOD_DATA_SHAPES = { | ||
| "L0": [], | ||
| "L1": [(4, 16, 4, 64)], | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") | ||
| class TestDistributedScoreModSelfAttn: | ||
| @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) | ||
| @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES) | ||
| @pytest.mark.parametrize("dtype", DTYPES) | ||
| def test_softcap_score_mod_with_aux_params_backward( | ||
| self, | ||
| device_count, | ||
| mesh_shape, | ||
| mesh_axes, | ||
| mesh_resource, | ||
| data_shape, | ||
| dtype, | ||
| ): | ||
| _require_cudnn_frontend_score_mod() | ||
| batch, seqlen, num_heads, head_dim = data_shape | ||
|
vcherepanov-nv marked this conversation as resolved.
|
||
| dp_axis = mesh_resource.dp_resource | ||
| tp_axis = mesh_resource.tpsp_resource | ||
|
|
||
| if dp_axis is not None: | ||
| dp_size = mesh_shape[mesh_axes.index(dp_axis)] | ||
| if batch % dp_size != 0: | ||
| pytest.skip(f"{batch=} must be divisible by {dp_size=}") | ||
| if tp_axis is not None: | ||
| tp_size = mesh_shape[mesh_axes.index(tp_axis)] | ||
| if num_heads % tp_size != 0: | ||
| pytest.skip(f"{num_heads=} must be divisible by {tp_size=}") | ||
|
|
||
| runner = FusedAttnRunner( | ||
| batch, | ||
| seqlen, | ||
| seqlen, | ||
| num_heads, | ||
| num_heads, | ||
| head_dim, | ||
| head_dim, | ||
| AttnBiasType.NO_BIAS, | ||
| AttnMaskType.NO_MASK, | ||
| AttnSoftmaxType.VANILLA_SOFTMAX, | ||
| 0.0, | ||
| dtype, | ||
| True, | ||
| QKVLayout.BSHD_BSHD_BSHD, | ||
| None, | ||
| None, | ||
| SeqDescFormat.Mask, | ||
| number_of_devices=device_count, | ||
| mesh_shape=mesh_shape, | ||
| mesh_axes=mesh_axes, | ||
| mesh_resource=mesh_resource, | ||
| ) | ||
| runner._setup_inputs() | ||
|
|
||
|
Comment on lines
+322
to
+346
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So it seems like you are using the runner to only setup the inputs but then are following that up with "duplicate" code that The idea is if you can integrate the non distributed tests with the Runner infrastrucutre then the distributed tests can directly use it here for free. The approach here seems to be somewhere in between your older approach and the suggested approach. You can refer to other tests in this file for reference, incase I've not done a good job explaining i'm curiosu to know if the reason for that was because you were unable to fully integrate the score mod tests into the Fused Attn runner ? Because it seems like it is the same reason for creating a separate test_fused_attn_score_mod.py as compared to integrating the flex attn tests in the Fused attn runner in test_fused_attn.py If test_fused_attn_score_mod must be created then a Runner should be ceated in the too and the distributed flex attn tests can then use that runner in here (similar to how other tests do) |
||
| qkv_sharding = NamedSharding(runner.mesh, PartitionSpec(dp_axis, None, tp_axis, None)) | ||
| query = (0.125 * runner.q).astype(dtype) | ||
| key_tensor = (0.125 * runner.k).astype(dtype) | ||
| value = (0.125 * runner.v).astype(dtype) | ||
| doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype) | ||
|
Comment on lines
+348
to
+351
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we do this ? 0.125* |
||
|
|
||
| scaling_factor = runner.scaling_factor | ||
| softcap = 0.8 | ||
| softcap_score_mod = _ScoreModSoftcap() | ||
|
|
||
| def score_mod_loss(q, k, v, dout): | ||
| out = customcall_fused_dpa( | ||
| q, | ||
| k, | ||
| v, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| attn_bias_type=AttnBiasType.NO_BIAS, | ||
| attn_mask_type=AttnMaskType.NO_MASK, | ||
| qkv_layout=QKVLayout.BSHD_BSHD_BSHD, | ||
| softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, | ||
| scaling_factor=scaling_factor, | ||
| dropout_probability=0.0, | ||
| is_training=True, | ||
| score_mod=softcap_score_mod.forward, | ||
| score_mod_bprop=softcap_score_mod.backward, | ||
| score_mod_tensors={"softcap": softcap}, | ||
| score_mod_bprop_tensors={"softcap": softcap}, | ||
| ) | ||
| loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) | ||
| return loss, out | ||
|
|
||
| def ref_loss(q, k, v, dout): | ||
| out = _reference_attention(q, k, v, scaling_factor, softcap=softcap) | ||
| loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32)) | ||
| return loss, out | ||
|
|
||
| jitted_score_mod = jax.jit( | ||
| jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True), | ||
| in_shardings=( | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| qkv_sharding, | ||
| ), | ||
| out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)), | ||
| ) | ||
| jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)) | ||
|
|
||
| sharded_args = ( | ||
| jax.device_put(query, qkv_sharding), | ||
| jax.device_put(key_tensor, qkv_sharding), | ||
| jax.device_put(value, qkv_sharding), | ||
| jax.device_put(doutput, qkv_sharding), | ||
| ) | ||
|
Comment on lines
+347
to
+403
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All of this can come for free if the flex attn is integrated with the FusedAttn Runner |
||
| with runner.mesh, autocast(mesh_resource=mesh_resource): | ||
| (score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args) | ||
| (ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput) | ||
|
|
||
| assert score_mod_out.sharding == qkv_sharding | ||
| for grad in score_mod_grads: | ||
| assert grad.sharding == qkv_sharding | ||
|
|
||
| assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) | ||
| assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) | ||
| for grad, ref_grad in zip(score_mod_grads, ref_grads): | ||
| assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) | ||
|
|
||
|
|
||
| DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [ | ||
| pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), | ||
| pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.