Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs_for_attn,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from test_fused_attn import (
FusedAttnRunner,
BiasShape,
SeqDescFormat,
_ScoreModSoftcap,
_has_cudnn_frontend_python,
_reference_attention,
_require_cudnn_frontend_score_mod,
)
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax import autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
Expand All @@ -25,6 +36,7 @@
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
fused_attn,
Comment thread
vcherepanov-nv marked this conversation as resolved.
Outdated
)


Expand Down Expand Up @@ -272,6 +284,114 @@ def test_cross_attn(
runner.test_backward()


# Input shapes for the distributed score_mod test, keyed by the level names that
# `pytest_parametrize_wrapper` receives. Each tuple is unpacked by the test as
# (batch, seqlen, num_heads, head_dim).
# NOTE(review): "L0"/"L1"/"L2" are presumably test-tier selectors, with the empty
# "L0" list meaning no cases run at that tier — confirm against
# `pytest_parametrize_wrapper` in utils.
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [(4, 16, 4, 64)],
}
Comment thread
vcherepanov-nv marked this conversation as resolved.


@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required")
class TestDistributedScoreModSelfAttn:
    """Sharded self-attention checks for ``fused_attn`` driven by a score_mod callback."""

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_softcap_score_mod_with_aux_params_backward(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
    ):
        """Compare sharded softcap score_mod attention with the reference, forward and backward.

        Q/K/V and the upstream gradient are sharded over the mesh as
        ``PartitionSpec(dp_axis, None, tp_axis, None)``. The test differentiates
        ``sum(out * dout)`` with respect to Q, K and V through ``fused_attn``
        (with ``_ScoreModSoftcap`` forward/backward callbacks) and through
        ``_reference_attention``, then asserts that:

        * the fused output and all three gradients keep the Q/K/V sharding, and
        * output, loss value and gradients match the reference within 7e-2.

        The case is skipped when the batch size is not divisible by the
        data-parallel mesh extent, or the head count is not divisible by the
        tensor-parallel mesh extent.
        """
        # NOTE(review): presumably skips when the installed cuDNN frontend lacks
        # score_mod support — confirm in test_fused_attn.
        _require_cudnn_frontend_score_mod()

        batch, _seqlen, num_heads, head_dim = data_shape

        def axis_extent(axis_name):
            # Size of the mesh dimension that carries `axis_name`.
            return mesh_shape[mesh_axes.index(axis_name)]

        dp_axis, tp_axis = mesh_resource.dp_resource, mesh_resource.tpsp_resource

        # Sharding needs an even split of batch (over dp) and heads (over tp).
        if dp_axis is not None:
            dp_size = axis_extent(dp_axis)
            if batch % dp_size:
                pytest.skip(f"{batch=} must be divisible by {dp_size=}")
        if tp_axis is not None:
            tp_size = axis_extent(tp_axis)
            if num_heads % tp_size:
                pytest.skip(f"{num_heads=} must be divisible by {tp_size=}")

        device_grid = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(device_grid, mesh_axes)
        qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None))

        rng_q, rng_k, rng_v, rng_dout = random.split(random.PRNGKey(2025), 4)

        def scaled_normal(rng):
            # Normal samples scaled by 0.125, cast back to the test dtype.
            return (0.125 * random.normal(rng, data_shape, dtype=dtype)).astype(dtype)

        query, key_tensor, value = (scaled_normal(rng) for rng in (rng_q, rng_k, rng_v))
        doutput = random.normal(rng_dout, data_shape, dtype=dtype)

        scaling_factor = head_dim**-0.5
        softcap = 0.8
        softcap_score_mod = _ScoreModSoftcap()

        def weighted_sum(out, dout):
            # Scalar loss: float32 inner product of the output with the upstream gradient.
            return jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))

        def score_mod_loss(q, k, v, dout):
            # NOTE(review): the three positional ``None`` values and the trailing
            # ``0.0, True`` are presumably bias, sequence descriptor, seed,
            # dropout probability and is_training — confirm against the
            # fused_attn signature.
            out = fused_attn(
                (q, k, v),
                None,
                None,
                None,
                AttnBiasType.NO_BIAS,
                AttnMaskType.NO_MASK,
                QKVLayout.BSHD_BSHD_BSHD,
                AttnSoftmaxType.VANILLA_SOFTMAX,
                scaling_factor,
                0.0,
                True,
                score_mod=softcap_score_mod.forward,
                score_mod_bprop=softcap_score_mod.backward,
                score_mod_tensors={"softcap": softcap},
                score_mod_bprop_tensors={"softcap": softcap},
            )
            return weighted_sum(out, dout), out

        def ref_loss(q, k, v, dout):
            out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
            return weighted_sum(out, dout), out

        # Differentiate w.r.t. q, k, v only; the attention output rides along as aux.
        grad_options = {"argnums": (0, 1, 2), "has_aux": True}
        jitted_score_mod = jax.jit(
            jax.value_and_grad(score_mod_loss, **grad_options),
            in_shardings=(qkv_sharding,) * 4,
            # Layout mirrors the return value: ((loss, out), (dq, dk, dv)).
            out_shardings=((None, qkv_sharding), (qkv_sharding,) * 3),
        )
        jitted_ref = jax.jit(jax.value_and_grad(ref_loss, **grad_options))

        sharded_args = tuple(
            jax.device_put(tensor, qkv_sharding)
            for tensor in (query, key_tensor, value, doutput)
        )
        with mesh, autocast(mesh_resource=mesh_resource):
            (score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args)
            (ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput)

        # The output and every gradient must come back with the Q/K/V sharding.
        for sharded_result in (score_mod_out, *score_mod_grads):
            assert sharded_result.sharding == qkv_sharding

        tolerance = {"rtol": 7e-2, "atol": 7e-2}
        comparisons = [
            (score_mod_out, ref_out),
            (score_mod_value, ref_value),
            *zip(score_mod_grads, ref_grads),
        ]
        for actual, expected in comparisons:
            assert_allclose(actual, expected, **tolerance)


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
Expand Down
Loading
Loading