Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ def setup_jax_extension(

# Header files
include_dirs = get_cuda_include_dirs()
cudnn_frontend_include_dir = None
for base_path in (Path(common_header_files), *Path(common_header_files).parents):
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is None:
for base_path in Path(__file__).resolve().parents:
candidate = base_path / "3rdparty" / "cudnn-frontend" / "include"
if candidate.exists():
cudnn_frontend_include_dir = candidate
break
if cudnn_frontend_include_dir is not None:
include_dirs.append(cudnn_frontend_include_dir)
include_dirs.extend(
[
common_header_files,
Expand Down
1 change: 1 addition & 0 deletions qa/L0_jax_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_score_mod.xml $TE_PATH/tests/jax/test_fused_attn_score_mod.py || test_fail "tests/jax/test_fused_attn_score_mod.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
Expand Down
147 changes: 145 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs_for_attn,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from test_fused_attn import (
FusedAttnRunner,
BiasShape,
SeqDescFormat,
customcall_fused_dpa,
)
from test_fused_attn_score_mod import (
_ScoreModSoftcap,
_has_cudnn_frontend_python,
_reference_attention,
_require_cudnn_frontend_score_mod,
)
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax import autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
Expand Down Expand Up @@ -272,6 +285,136 @@ def test_cross_attn(
runner.test_backward()


# Input shapes for the distributed score-mod self-attention test, keyed by
# "L0" / "L1" (presumably test levels — confirm against pytest_parametrize_wrapper).
# Each tuple is unpacked by the test as (batch, seqlen, num_heads, head_dim).
# The "L0" entry is an empty list, so it contributes no shapes.
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
}
Comment thread
vcherepanov-nv marked this conversation as resolved.


@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required")
class TestDistributedScoreModSelfAttn:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_softcap_score_mod_with_aux_params_backward(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
):
_require_cudnn_frontend_score_mod()
batch, seqlen, num_heads, head_dim = data_shape
Comment thread
vcherepanov-nv marked this conversation as resolved.
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource

if dp_axis is not None:
dp_size = mesh_shape[mesh_axes.index(dp_axis)]
if batch % dp_size != 0:
pytest.skip(f"{batch=} must be divisible by {dp_size=}")
if tp_axis is not None:
tp_size = mesh_shape[mesh_axes.index(tp_axis)]
if num_heads % tp_size != 0:
pytest.skip(f"{num_heads=} must be divisible by {tp_size=}")

runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_heads,
num_heads,
head_dim,
head_dim,
AttnBiasType.NO_BIAS,
AttnMaskType.NO_MASK,
AttnSoftmaxType.VANILLA_SOFTMAX,
0.0,
dtype,
True,
QKVLayout.BSHD_BSHD_BSHD,
None,
None,
SeqDescFormat.Mask,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
)
runner._setup_inputs()

Comment on lines +322 to +346
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems like you are using the runner only to set up the inputs, but then following that up with code that duplicates what test_forward and test_backward do in test_fused_attn.py.
The suggestion was to try to use the runner to call forward(), which does the setup using the runner and also runs the test.

The idea is that if you can integrate the non-distributed tests with the Runner infrastructure, then the distributed tests can use it directly here for free. The approach here seems to be somewhere in between your older approach and the suggested approach. You can refer to the other tests in this file for reference, in case I have not done a good job of explaining.

I'm curious to know whether the reason for that was that you were unable to fully integrate the score mod tests into the FusedAttn runner. It seems to be the same reason a separate test_fused_attn_score_mod.py was created, rather than integrating the flex attn tests into the FusedAttn runner in test_fused_attn.py.

If test_fused_attn_score_mod.py must be created, then a Runner should be created in it too, and the distributed flex attn tests can then use that runner here (similar to how the other tests do).

qkv_sharding = NamedSharding(runner.mesh, PartitionSpec(dp_axis, None, tp_axis, None))
query = (0.125 * runner.q).astype(dtype)
key_tensor = (0.125 * runner.k).astype(dtype)
value = (0.125 * runner.v).astype(dtype)
doutput = random.normal(random.PRNGKey(2025), data_shape, dtype=dtype)
Comment on lines +348 to +351
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this, i.e. scale the inputs by 0.125?


scaling_factor = runner.scaling_factor
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

# Loss wrapper around the fused attention custom call with the softcap score
# modification. Returns (loss, out) so that it can be differentiated with
# jax.value_and_grad(..., has_aux=True), with `out` carried as the aux value.
# `scaling_factor`, `softcap` and `softcap_score_mod` are captured from the
# enclosing test function.
def score_mod_loss(q, k, v, dout):
out = customcall_fused_dpa(
q,
k,
v,
# The next four positional arguments are passed as None in this test.
# NOTE(review): their meaning is defined by customcall_fused_dpa, which is
# not visible here — confirm against its signature.
None,
None,
None,
None,
attn_bias_type=AttnBiasType.NO_BIAS,
attn_mask_type=AttnMaskType.NO_MASK,
qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
scaling_factor=scaling_factor,
dropout_probability=0.0,
is_training=True,
# The forward and backward score-mod callables come from the same
# _ScoreModSoftcap instance, and both receive the same "softcap" value.
score_mod=softcap_score_mod.forward,
score_mod_bprop=softcap_score_mod.backward,
score_mod_tensors={"softcap": softcap},
score_mod_bprop_tensors={"softcap": softcap},
)
# Scalar loss: sum of the elementwise product of out and dout, accumulated in
# float32. With this form, dout plays the role of the upstream gradient of out.
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

# Reference counterpart of the score-mod loss: computes attention with
# _reference_attention using the same scaling_factor and softcap (both captured
# from the enclosing test function), then forms the same float32 loss.
# Returns (loss, out) for use with jax.value_and_grad(..., has_aux=True).
def ref_loss(q, k, v, dout):
out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
# Same loss form as the score-mod path so values and gradients are comparable.
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

jitted_score_mod = jax.jit(
jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
in_shardings=(
qkv_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
),
out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
jax.device_put(query, qkv_sharding),
jax.device_put(key_tensor, qkv_sharding),
jax.device_put(value, qkv_sharding),
jax.device_put(doutput, qkv_sharding),
)
Comment on lines +347 to +403
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this can come for free if the flex attn is integrated with the FusedAttn Runner

with runner.mesh, autocast(mesh_resource=mesh_resource):
(score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args)
(ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput)

assert score_mod_out.sharding == qkv_sharding
for grad in score_mod_grads:
assert grad.sharding == qkv_sharding

assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2)
assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2)
for grad, ref_grad in zip(score_mod_grads, ref_grads):
assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2)


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
Expand Down
Loading
Loading