
[PyTorch] Support for cuDNN-backed flex attention#2984

Open
vcherepanov-nv wants to merge 20 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-3

Conversation

Collaborator

@vcherepanov-nv vcherepanov-nv commented May 13, 2026

Description

Adds experimental PyTorch support for cuDNN-backed flex attention in DotProductAttention via a new score_mod callback path.

Users can pass:

  • score_mod(graph, score, tensors) -> score for forward score modification
  • optional score_mod_bprop(graph, dP, tensors) -> dP for backward
  • optional runtime tensor dictionaries for forward/backward score-mod graph inputs

When score_mod_bprop is supplied, it is the user's responsibility to make it mathematically consistent with score_mod. TE forwards this callback to cuDNN as provided and does not derive or validate the backward score transformation automatically.
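
To illustrate what "mathematically consistent" means here, the following is a minimal pure-Python sketch using a softcap transform as the example. It does not use the cuDNN graph API; `softcap`, `softcap_bprop`, `finite_difference`, and the `CAP` value are hypothetical names chosen for illustration. The idea is that the backward callback should scale the incoming gradient by the derivative of the forward transform, which can be sanity-checked numerically before wiring the pair into TE.

```python
import math

CAP = 30.0  # hypothetical softcap value, chosen only for illustration


def softcap(score: float, cap: float = CAP) -> float:
    """Forward transform: squash a raw score into (-cap, cap)."""
    return cap * math.tanh(score / cap)


def softcap_bprop(d_out: float, score: float, cap: float = CAP) -> float:
    """Backward transform: chain rule through cap * tanh(score / cap)."""
    t = math.tanh(score / cap)
    return d_out * (1.0 - t * t)


def finite_difference(fn, x: float, eps: float = 1e-6) -> float:
    """Central-difference estimate of d fn / d x."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


# The analytic backward should agree with the numeric derivative of the forward.
for s in (-50.0, -1.0, 0.0, 2.5, 40.0):
    analytic = softcap_bprop(1.0, s)
    numeric = finite_difference(softcap, s)
    assert abs(analytic - numeric) < 1e-6, (s, analytic, numeric)
```

A check like this only validates the scalar math; the graph-level callbacks still need to express the same transform with cuDNN graph operations.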

Supported score_mod configuration

The current cuDNN-backed Flex Attention path supports:

  • PyTorch DotProductAttention / FusedAttention
  • FP16 or BF16 unquantized torch.Tensor Q/K/V inputs
  • SBHD or BSHD Q/K/V layouts
  • cuDNN F16/BF16 arbitrary-seqlen fused attention backend
  • attn_mask_type="no_mask"
  • core_attention_bias_type="no_bias" with no explicit bias tensor
  • vanilla softmax
  • attention_dropout=0.0
  • num_splits=1

The path is currently not supported with FP8, fp8_output, THD format, explicit cu_seqlens inputs, pad_between_seqs, attention masks, attention bias, ALiBi, sliding-window attention, sink attention, dropout, KV cache, context parallelism, CUDA graph capture, checkpointed core attention, or return_max_logit.
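
As a rough illustration of how the supported configuration above could be checked up front, here is a small sketch. `ScoreModRequest` and `unsupported_reasons` are hypothetical names, not TE APIs, and the sketch covers only a subset of the restrictions listed in this description.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ScoreModRequest:
    """Hypothetical summary of the settings that matter for the score_mod path."""

    dtype: str = "bf16"
    qkv_format: str = "bshd"
    attn_mask_type: str = "no_mask"
    core_attention_bias_type: str = "no_bias"
    softmax_type: str = "vanilla"
    attention_dropout: float = 0.0
    num_splits: int = 1
    context_parallel: bool = False


def unsupported_reasons(req: ScoreModRequest) -> List[str]:
    """Return human-readable reasons why a request falls outside the supported set."""
    reasons = []
    if req.dtype not in ("fp16", "bf16"):
        reasons.append(f"dtype {req.dtype} (only FP16/BF16)")
    if req.qkv_format not in ("sbhd", "bshd"):
        reasons.append(f"qkv_format {req.qkv_format} (only SBHD/BSHD)")
    if req.attn_mask_type != "no_mask":
        reasons.append("attention masks")
    if req.core_attention_bias_type != "no_bias":
        reasons.append("attention bias")
    if req.softmax_type != "vanilla":
        reasons.append("non-vanilla softmax")
    if req.attention_dropout != 0.0:
        reasons.append("dropout")
    if req.num_splits != 1:
        reasons.append("num_splits != 1")
    if req.context_parallel:
        reasons.append("context parallelism")
    return reasons


print(unsupported_reasons(ScoreModRequest(qkv_format="thd", attention_dropout=0.1)))
```

Collecting all reasons at once, rather than failing on the first, tends to give users a more actionable message.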

For deterministic execution, TE passes the deterministic setting through backend selection and forwards it to cuDNN Frontend sdpa_backward as use_deterministic_algorithm. The score_mod forward sdpa call does not take a separate deterministic flag.

Fixes #2492.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Adds FusedAttentionWithScoreModFunc, a cuDNN frontend Python graph path for SDPA forward/backward with score_mod and score_mod_bprop.
  • Extends DotProductAttention / FusedAttention APIs with score_mod, score_mod_bprop, score_mod_tensors, and score_mod_bprop_tensors.
  • Adds backend-selection filtering so score_mod only selects supported cuDNN fused attention configurations.
  • Adds execution-plan caching for forward and backward score-mod graphs, keyed by tensor metadata, layout, scale, callback topology, and runtime tensor metadata.
  • Supports explicit score_mod_graph_cache_key() for stateful callbacks, while leaving unsafe unkeyed bound methods uncached.
  • Executes cuDNN graphs on PyTorch's current CUDA stream and preserves SBHD/BSHD layouts without extra BHSD copies.
  • Adds validation for unsupported combinations including FP8, context parallelism, THD, KV cache, explicit masks/biases, dropout, non-vanilla softmax, CUDA graph capture, and checkpointed core attention.
  • Adds tests for cache-key behavior, unsafe callback caching, runtime tensor version checking, and CUDA correctness cases covering causal masking, softcap, and post-scale-bias-style score modification.
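
The execution-plan caching described above can be pictured with a short sketch. This is not the PR's implementation: `tensor_meta`, `make_key`, `get_or_build`, and the stand-in `object()` for a built graph are all assumptions made for illustration. The point is only that every input that changes the compiled graph is part of the key, so a change in any of them forces a rebuild.

```python
from typing import Any, Callable, Dict, Hashable, Tuple

GraphKey = Tuple[Hashable, ...]

_graph_cache: Dict[GraphKey, Any] = {}
build_count = 0  # counts cache misses, for illustration only


def tensor_meta(shape, stride, dtype: str) -> Tuple:
    """Metadata that affects the compiled graph: shape, stride, and dtype."""
    return (tuple(shape), tuple(stride), dtype)


def make_key(q_meta, k_meta, v_meta, layout: str, scale: float,
             callback_key: Hashable) -> GraphKey:
    """Combine everything that influences graph construction into one key."""
    return (q_meta, k_meta, v_meta, layout, scale, callback_key)


def get_or_build(key: GraphKey, build: Callable[[], Any]) -> Any:
    """Return the cached entry for key, building it on the first request."""
    global build_count
    if key not in _graph_cache:
        build_count += 1
        _graph_cache[key] = build()
    return _graph_cache[key]


meta = tensor_meta((2, 128, 16, 64), (131072, 1024, 64, 1), "bf16")
callback_key = ("function", "my_module", "softcap")  # hypothetical callback key
key_a = make_key(meta, meta, meta, "bshd", 0.125, callback_key)
key_b = make_key(meta, meta, meta, "bshd", 0.25, callback_key)

first = get_or_build(key_a, lambda: object())
second = get_or_build(key_a, lambda: object())  # hit: the same entry is returned
third = get_or_build(key_b, lambda: object())   # miss: the scale differs
```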

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds experimental PyTorch support for cuDNN-backed flex attention in DotProductAttention via a new score_mod callback path, routing through a new FusedAttentionWithScoreModFunc autograd function that builds and caches cuDNN frontend Python graphs for both forward and backward passes.

  • New flex_attention.py implements cuDNN graph building, a two-level execution-plan cache keyed on tensor metadata and callback topology, and FusedAttentionWithScoreModFunc with version-counter-based in-place mutation detection for runtime tensors.
  • Backend selection in utils.get_attention_backend filters out Flash Attention and Unfused backends for the score_mod path and gates FusedAttention on the F16_arbitrary_seqlen sub-backend, with explicit unsupported-combination logging.
  • DotProductAttention and FusedAttention are extended with four new parameters (score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors) guarded by thorough input validation; a new AttentionRuntimeFlags dataclass decouples these runtime-only flags from the static AttentionParams used for backend caching.

Confidence Score: 5/5

The change is well-structured and the experimental feature is gated behind multiple layers of validation; the one open gap is that CPU tensors in score_mod_tensors reach cuDNN before being rejected.

The score_mod path is thoroughly guarded at backend selection and at the FusedAttention assertion layer. The cache key design correctly handles lambdas, bound methods, and stateful callbacks. The autograd function correctly separates training/inference save paths and uses version counters for mutation detection. The only finding is a missing device check on score_mod_tensors dict values that would surface as a cryptic cuDNN error rather than a clear assertion failure — this affects error clarity but not correctness or safety.

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py (score_mod_tensors device validation) and tests/pytorch/attention/test_flex_attention.py (CPU tensors in CUDA correctness test fixtures)

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/flex_attention.py | New file implementing cuDNN-backed flex attention with forward/backward graph caching, score_mod callback wrapping, and FusedAttentionWithScoreModFunc autograd function. Cache key design is thoughtful; main open concern is the backward graph using torch.empty_like allocations purely for shape extraction. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds score_mod parameters to FusedAttention.forward and routes the score_mod path to FusedAttentionWithScoreModFunc, with validation asserting required backend, no-mask, no-bias, no-dropout, and no-CP constraints. |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Adds score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors parameters to DotProductAttention.forward; validation checks types and dependencies, but does not validate that tensors reside on CUDA. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Introduces AttentionRuntimeFlags dataclass and get_attention_backend score_mod filter that disables Flash/Unfused backends and gates FusedAttention on F16_arbitrary_seqlen sub-backend; also fixes fp8_meta None-dict comparison. |
| tests/pytorch/attention/test_flex_attention.py | New test file covering cache-key semantics, unsafe unkeyed callback caching, version-counter backward regression, and CUDA correctness tests; score_mod runtime tensors in CUDA correctness tests are created on CPU and will cause device-mismatch errors at cuDNN execution time. |

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention.forward
    participant BAB as get_attention_backend
    participant FA as FusedAttention.forward
    participant FASMF as FusedAttentionWithScoreModFunc
    participant Cache as _cudnn_score_mod_graph_cache
    participant cuDNN as cuDNN Frontend

    User->>DPA: "forward(q, k, v, score_mod=fn, score_mod_tensors={...})"
    DPA->>DPA: Validate score_mod inputs
    DPA->>BAB: "AttentionParams(runtime_flags.has_score_mod=True)"
    BAB->>BAB: "Disable FlashAttn & Unfused backends"
    BAB->>BAB: Check unsupported combos (FP8, THD, dropout, etc.)
    BAB->>BAB: Filter to F16_arbitrary_seqlen only
    BAB-->>DPA: "use_fused_attention=True, backend=F16_arbitrary_seqlen"
    DPA->>FA: "forward(score_mod=fn, score_mod_tensors={...})"
    FA->>FA: Assert no FP8/mask/bias/dropout
    FA->>FASMF: apply(is_training, q, k, v, score_mod, ...)
    FASMF->>Cache: _get_cudnn_score_mod_fwd_graph(key)
    alt Cache miss
        Cache->>cuDNN: Build pygraph (q, k, v, score_mod callback)
        cuDNN-->>Cache: _CudnnScoreModFwdGraphEntry
    end
    Cache-->>FASMF: entry
    FASMF->>cuDNN: execute(variant_pack, workspace)
    cuDNN-->>FASMF: output tensor
    FASMF->>FASMF: "ctx.save_for_backward(q, k, v, out, stats, *tensors)"
    FASMF-->>User: output
    User->>FASMF: backward(d_out)
    FASMF->>Cache: _get_cudnn_score_mod_bwd_graph(key)
    alt Cache miss
        Cache->>cuDNN: Build sdpa_backward graph (score_mod + score_mod_bprop)
        cuDNN-->>Cache: _CudnnScoreModBwdGraphEntry
    end
    FASMF->>cuDNN: execute(dq/dk/dv variant_pack)
    cuDNN-->>FASMF: dq, dk, dv
    FASMF-->>User: (None, dq, dk, dv, ...)

Reviews (10): Last reviewed commit: "Skip softcap flex attention tests before..." | Re-trigger Greptile

Comment on lines +1273 to +1281
def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
    """Create a stable cache key for a score_mod callable."""
    if callback is None:
        return None
    self_obj = getattr(callback, "__self__", None)
    func_obj = getattr(callback, "__func__", None)
    if self_obj is not None and func_obj is not None:
        return ("bound_method", id(self_obj), id(func_obj))
    return ("callable", id(callback))
Contributor

P1 id()-based cache key is unsafe for parameterized bound-method score_mods

id(self_obj) identifies a Python object by its memory address. When a bound-method instance is garbage-collected, Python may immediately reuse that memory for a new instance. If the new instance belongs to the same class (same id(func_obj)), the cache key is identical, so _get_cudnn_score_mod_fwd_graph returns the old compiled graph even though the new instance might construct a structurally different computation — e.g., a score_mod class whose forward loops self.n_layers times. The wrong graph is executed without any error, silently producing incorrect attention outputs.

For stateless module-level functions this is fine (they're never GC'd), but any stateful class-based score_mod where different instances produce different graph topologies can hit this bug in long-running programs. Consider using type(self_obj) and a per-class sequence counter, or requiring callers to provide an explicit cache key.
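
The "explicit cache key" option mentioned above could look roughly like the sketch below. It assumes that the object owning the callback exposes a `score_mod_graph_cache_key()` method (the method name appears in this PR's change list; where it lives and what it returns are assumptions here). `explicit_cache_key`, `RepeatedNeg`, and `Unkeyed` are hypothetical names.

```python
from typing import Any, Callable, Optional, Tuple


def explicit_cache_key(callback: Callable) -> Optional[Tuple[Any, ...]]:
    """Return a key derived from user-declared state, or None to skip caching."""
    owner = getattr(callback, "__self__", None)
    key_fn = getattr(owner, "score_mod_graph_cache_key", None)
    if callable(key_fn):
        # Keyed on the class and declared state, never on id(owner), so a
        # recycled memory address cannot alias two different graph topologies.
        return ("explicit", type(owner).__qualname__, key_fn())
    return None  # in this sketch, everything else is left uncached


class RepeatedNeg:
    """Hypothetical stateful score_mod whose graph depends on n_steps."""

    def __init__(self, n_steps: int):
        self.n_steps = n_steps

    def score_mod_graph_cache_key(self):
        return ("n_steps", self.n_steps)

    def score_mod(self, graph, score, tensors):
        for _ in range(self.n_steps):
            score = graph.neg(input=score)
        return score


class Unkeyed:
    """Hypothetical stateful score_mod that declares no cache key."""

    def score_mod(self, graph, score, tensors):
        return score


print(explicit_cache_key(RepeatedNeg(2).score_mod))
print(explicit_cache_key(Unkeyed().score_mod))
```

Two instances with the same declared state share a key and can reuse one compiled graph, while instances that differ in `n_steps` cannot collide.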

Comment on lines +1556 to +1563
fused_attention_backend = tex.get_fused_attn_backend(
    self.training,
    q_type,
    q_type,
    dpa_utils.QKVLayout["bshd_bshd_bshd"],
    dpa_utils.AttnBiasType["no_bias"],
    dpa_utils.AttnMaskType["no_mask"],
    dpa_utils.SoftmaxType["vanilla"],
Contributor

P2 get_fused_attn_backend availability check always uses bshd_bshd_bshd regardless of actual format

The score_mod path hard-codes dpa_utils.QKVLayout["bshd_bshd_bshd"] for the backend probe, even when the user passes qkv_format="sbhd". The result is only used to gate on NVTE_No_Backend, so in practice it likely works today because backend availability for a given dtype is layout-independent. However, if a future cuDNN version makes SBHD/BSHD support diverge, this probe would give a false-positive (accepts sbhd even though no backend supports it) or false-negative (rejects sbhd when it is actually supported). Using the real layout for the probe would make the check self-documenting and future-proof.

)

if context_parallel:
if score_mod is not None:
Collaborator

I think this should be in the else branch, because it doesn't support context parallelism. Something like this:
if context_parallel:
    ...
elif score_mod is not None:
    ...
else:
    ...

raise ValueError(
"score_mod requires a cuDNN FusedAttention backend, but no fused "
"attention backend supports the provided inputs."
)
Collaborator

For the score_mod path, I don't think we need to call tex.get_fused_attn_backend() and check if it's supported or not. If anything, we should add graph.validate() -> .... graph.build_plans() to dpa_utils.get_attention_backend(attention_params), but if that's too heavy-handed, we can only do the checks you had above (the asserts). Once those checks were added to dpa_utils.get_attention_backend, whether FusedAttention backend is run or not will be controlled by the following logic (just like with non-score_mod cases):

(
                        use_flash_attention,
                        flash_attention_backend,
                        use_fused_attention,
                        fused_attention_backend,
                        use_unfused_attention,
                        _,
                    ) = dpa_utils.get_attention_backend(attention_params)


def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device):
    """Create a cuDNN frontend Python graph for F16/BF16 SDPA."""
    import cudnn  # pylint: disable=import-outside-toplevel
Collaborator

Can you import the cudnn from 3rdparty/cudnn-frontend, instead of from the environment/system-wide installation? We have control over the version in 3rdparty/cudnn-frontend, but not the system one.

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("scalar_loss", [False, True])
def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss):
Collaborator

Would @pytest.mark.parametrize("score_mod", ["causal", "softcap", "post_scale_bias"]) simplify the tests a bit, so that we don't have 3 separate tests with a lot of repeated code?

score_mod: Callable,
score_mod_tensors: Optional[Dict[str, torch.Tensor]],
output_layer: torch.Tensor,
stats_bhs1: Optional[torch.Tensor],
Collaborator

I think we can just call this stats, even though it might only support bhs1 shape right now. On the C++ side, cuDNN does support th1 (for THD format) as well. Could we leave the name generic for now in case we want to add more support to it in the future?

return output.contiguous()


def _bhsd_dim_stride(
Collaborator

We have a lot of small utility functions here - is there a way to pack them up a bit or group them in some way, so the code is easier to read? I know this is Python and we probably do need more than 2 functions (fwd+bwd) but could you please have a look into this? Thanks.

Collaborator

I agree with this; it was my first thought too.
At the very least, we should group these functions into a couple of classes that can sit in this file.

However, I think even that is not the right approach. We should have a separate flex_attention.py file, similar to context_parallel.py, and backends.py can import it the same way it imports the CP functions right now.
I strongly recommend this for two reasons:

  1. When we refactored attention as a whole early last year, the idea was to modularize attention. That was the reason CP was moved out of attention. With Flex attention's functionality and code in here being fairly decoupled from vanilla DPA, it should be easier to move it out. Leaving this code in here would add ~1000 lines of code that are not related to vanilla DPA and would practically undo the refactoring work we did early last year. The same reasoning that moved CP to its own file should also apply to Flex attention.
  2. A developer/user of TE PyT DPA should not have to worry about the details of flex attention. Similarly, someone modifying flex should not be bogged down by the details of vanilla fused attention. Hence, decoupling is important to aid debugging as well.

)


def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors):
Collaborator

We can just call this "post_scale_bias" to be consistent with our nomenclature elsewhere.

vcherepanov-nv and others added 3 commits May 15, 2026 00:48
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +1368 to +1373
if (
    inspect.isfunction(callback)
    and callback.__closure__ is None
    and "<locals>" not in callback.__qualname__
):
    return ("function", callback.__module__, callback.__qualname__)
Contributor

P1 Module-level lambdas all share the same __qualname__ = "<lambda>", so two different lambdas defined at module scope in the same file (e.g., sm1 = lambda g, s, t: s and sm2 = lambda g, s, t: g.neg(input=s)) would produce the identical cache key ("function", module, "<lambda>"). The second lambda would silently reuse the compiled graph from the first, computing wrong attention scores with no error. Named module-level functions are safe because their qualnames are unique, but lambdas are not. Excluding <lambda> from the cacheable path makes them _SCORE_MOD_UNCACHEABLE, which builds a fresh graph every call — the same safe fallback already used for closures and nested functions.

Suggested change
- if (
-     inspect.isfunction(callback)
-     and callback.__closure__ is None
-     and "<locals>" not in callback.__qualname__
- ):
-     return ("function", callback.__module__, callback.__qualname__)
+ if (
+     inspect.isfunction(callback)
+     and callback.__closure__ is None
+     and "<locals>" not in callback.__qualname__
+     and "<lambda>" not in callback.__qualname__
+ ):
+     return ("function", callback.__module__, callback.__qualname__)
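
The lambda collision is easy to reproduce in isolation. The sketch below uses a hypothetical helper, `function_cache_key`, that mirrors the suggested condition; the two lambdas and `named_score_mod` are illustrative callbacks, not code from this PR.

```python
import inspect

# Two different lambdas defined in the same scope share one __qualname__.
sm_identity = lambda g, s, t: s  # noqa: E731
sm_tensors = lambda g, s, t: t   # noqa: E731


def named_score_mod(g, s, t):
    return s


def function_cache_key(callback):
    """Key plain named functions by module and qualname; return None otherwise."""
    if (
        inspect.isfunction(callback)
        and callback.__closure__ is None
        and "<locals>" not in callback.__qualname__
        and "<lambda>" not in callback.__qualname__
    ):
        return ("function", callback.__module__, callback.__qualname__)
    return None  # treated as uncacheable in this sketch


print(sm_identity.__qualname__ == sm_tensors.__qualname__)
print(function_cache_key(sm_identity), function_cache_key(named_score_mod))
```

Without the `<lambda>` exclusion, both lambdas would map to the same `("function", module, "<lambda>")` key even though they compute different things.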

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +1968 to +1973
score_mod_kwargs = {
    "score_mod": _score_mod_causal,
    "score_mod_bprop": _score_mod_causal_bprop,
    "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)},
    "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)},
}
Contributor

P1 The neg_inf and zero tensors are created on CPU (torch.full defaults to CPU), but the attention computation runs on CUDA. When cuDNN executes the graph it calls into CUDA kernels and expects all variant-pack tensors to reside on the compute device. Passing CPU tensors here will produce a device-mismatch error at graph execution time, causing both the "causal" test cases to fail.

Suggested change
- score_mod_kwargs = {
-     "score_mod": _score_mod_causal,
-     "score_mod_bprop": _score_mod_causal_bprop,
-     "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)},
-     "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)},
- }
+ score_mod_kwargs = {
+     "score_mod": _score_mod_causal,
+     "score_mod_bprop": _score_mod_causal_bprop,
+     "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9, device="cuda")},
+     "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0, device="cuda")},
+ }
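
A device check at the API boundary would turn this into a clear error instead of a cuDNN failure. The following is only a sketch: `check_tensors_on_cuda` is a hypothetical helper, and it relies solely on the `is_cuda` attribute that `torch.Tensor` provides. To keep the example runnable without a GPU, `SimpleNamespace` objects stand in for tensors.

```python
from types import SimpleNamespace
from typing import Any, Mapping, Optional


def check_tensors_on_cuda(name: str, tensors: Optional[Mapping[str, Any]]) -> None:
    """Raise a descriptive error if any runtime tensor is not on a CUDA device."""
    if tensors is None:
        return
    bad = sorted(k for k, t in tensors.items() if not getattr(t, "is_cuda", False))
    if bad:
        raise ValueError(
            f"{name} entries must be CUDA tensors; non-CUDA tensors found for keys: {bad}"
        )


# Stand-ins for torch tensors, used only to keep the sketch self-contained.
cpu_like = SimpleNamespace(is_cuda=False)
cuda_like = SimpleNamespace(is_cuda=True)

check_tensors_on_cuda("score_mod_tensors", {"neg_inf": cuda_like})  # passes
try:
    check_tensors_on_cuda("score_mod_tensors", {"neg_inf": cpu_like})
except ValueError as err:
    print(err)
```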

Comment on lines +269 to +272
score_mod : bool, default = False
    Whether a score_mod callback was provided.
score_mod_bprop : bool, default = False
    Whether a score_mod bprop callback was provided.
Collaborator

nit: If this is a bool, to match has_attention_mask, consider has_score_mod and has_score_mod_bprop instead ?

logger.debug("Disabling all backends for max_logit with FP8 attention")

# Filter: score_mod
if score_mod_bprop and not score_mod:
Collaborator

What happens (is expected to happen) if score_mod_bprop=False and score_mod=True ?

Collaborator Author

It's a perfectly legal case, if, for instance, score_mod is used only for masking.
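
One way to see why a masking-only score_mod can work without a custom backward is the pure-Python sketch below (an illustration of the softmax math, not of the cuDNN path; all function names are hypothetical). Masked positions get probability zero, and the softmax backward multiplies each gradient by that probability, so nothing flows back through the masked entries while unmasked scores pass through unchanged.

```python
import math

NEG_INF = -1e9  # same stand-in for -inf as the test fixtures in this PR


def causal_score_mod(scores):
    """Replace entries with k > q by NEG_INF; leave the rest untouched."""
    return [
        [s if k <= q else NEG_INF for k, s in enumerate(row)]
        for q, row in enumerate(scores)
    ]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(p, d_p):
    """dS_j = P_j * (dP_j - sum_k P_k * dP_k)."""
    dot = sum(pi * di for pi, di in zip(p, d_p))
    return [pi * (di - dot) for pi, di in zip(p, d_p)]


scores = [[0.3, 1.2, -0.7], [0.5, -0.1, 2.0], [1.1, 0.4, 0.9]]
d_probs = [[0.2, -1.0, 0.5], [1.5, 0.3, -0.8], [-0.4, 0.6, 0.1]]

probs = [softmax(row) for row in causal_score_mod(scores)]
d_scores = [softmax_backward(p, d) for p, d in zip(probs, d_probs)]

# Gradients at masked (k > q) positions vanish without any custom bprop.
masked = [d_scores[q][k] for q in range(3) for k in range(3) if k > q]
print(masked)
```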

Comment on lines +681 to +686
use_flash_attention = False
use_flash_attention_2 = False
use_flash_attention_3 = False
use_flash_attention_4 = False
use_fused_attention = False
use_unfused_attention = False
Collaborator

nit: Outside the scope of this PR but would be good to do in this or subsequent PR: having a function or something similar for when performing an action/query on all flash_attention vars
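
As a possible shape for such a helper, the sketch below groups the per-version switches in one object. `FlashAttentionFlags` and its methods are hypothetical and are not part of TE; the field names only loosely follow the `use_flash_attention_2/3/4` variables.

```python
from dataclasses import dataclass, fields


@dataclass
class FlashAttentionFlags:
    """Hypothetical grouping of the per-version FlashAttention switches."""

    v2: bool = False
    v3: bool = False
    v4: bool = False

    def any_enabled(self) -> bool:
        """True if at least one FlashAttention version is still selectable."""
        return any(getattr(self, f.name) for f in fields(self))

    def disable_all(self) -> None:
        """Turn off every version in one call."""
        for f in fields(self):
            setattr(self, f.name, False)


flags = FlashAttentionFlags(v2=True, v4=True)
if flags.any_enabled():
    # Equivalent to the four separate assignments in the filter.
    flags.disable_all()
```

Adding a new FlashAttention version then means adding one field, rather than touching every filter that resets the flags.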

Comment on lines +688 to +693
if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4:
    logger.debug("Disabling FlashAttention for score_mod")
    use_flash_attention = False
    use_flash_attention_2 = False
    use_flash_attention_3 = False
    use_flash_attention_4 = False
Collaborator

Consider this maybe ?:

Suggested change
if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4:
logger.debug("Disabling FlashAttention for score_mod")
use_flash_attention = False
use_flash_attention_2 = False
use_flash_attention_3 = False
use_flash_attention_4 = False
if use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4:
logger.debug("Disabling FlashAttention for score_mod")
use_flash_attention = False
use_flash_attention_2 = False
use_flash_attention_3 = False
use_flash_attention_4 = False

unless there's a good reason to do otherwise ?

Comment on lines +694 to +696
if use_unfused_attention:
    logger.debug("Disabling UnfusedDotProductAttention for score_mod")
    use_unfused_attention = False
Collaborator

Consider this maybe ?

Suggested change
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for score_mod")
use_unfused_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for score_mod")
use_unfused_attention = False

unless there's a good reason to do otherwise ?

)
global _attention_backends
if is_in_onnx_export_mode():
if is_in_onnx_export_mode() and score_mod is None:
Collaborator

Is this necessary here if dpa_utils.get_attention_backend(attention_params) does get called in the else block below?
The flash, fused, and unfused flags would be set in there anyway, right?
Or am I missing something?
cc: @cyanguwa


Comment on lines +105 to +115
def _import_cudnn_frontend():
    """Import the vendored cuDNN frontend if built, otherwise use the installed package."""
    cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH)
    cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn"
    if (
        any(cudnn_frontend_package.glob("_compiled_module*"))
        and cudnn_frontend_path not in sys.path
    ):
        sys.path.insert(0, cudnn_frontend_path)
    return importlib.import_module("cudnn")

Collaborator

How about this?:

def _import_cudnn_frontend():
    cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH)
    cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn"
    if (
        any(cudnn_frontend_package.glob("_compiled_module*"))
        and cudnn_frontend_path not in sys.path
    ):
        sys.path.insert(0, cudnn_frontend_path)
        return importlib.import_module("cudnn")

    # Fall back
    if importlib.util.find_spec("cudnn") is not None:
        return importlib.import_module("cudnn")

    # Fail with a  message
    raise ImportError(
        "cuDNN Frontend Python package not found. "
        "Install it with: pip install nvidia-cudnn-frontend"
    )

return out, max_logit, (None, None, None, d_softmax_offset)


def _score_mod_causal(score_mod_graph, score_tensor, tensors):
Collaborator

I would strongly recommend that similar to the CP tests we have a separate Flex attention test file. Firstly for modularization and secondly because the Flex attention tests do not really end up using the test_dot_product_attention() base test like other DPA tests in the file do so there's no code reuse reasons for it either.

These isolated ~800 lines of code can sit in their own file if they aren't really using any of the functions in here directly and the flex tests are written as "new" tests; otherwise, the flex tests must reuse the DPA setup in here and integrate into it.

I've also shared more details on this in my comment in the backends.py file

cc: @cyanguwa

@KshitijLakhani
Collaborator

Thanks for creating this PR @vcherepanov-nv
This is great !

I was curious about:

  1. Do you have benchmark numbers based on any toy test cases you might have run? It would be good to have those in here for users of the API.
    1. native PyT flex vs TE PyT flex
    2. traditional causal TE via cuDNN vs flex-expressed causal TE via cuDNN
  2. I've linked the GH issue in the PR description. Could you please update / close it appropriately when this PR is merged?
    Thanks!

vcherepanov-nv and others added 2 commits May 19, 2026 21:17
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread tests/pytorch/attention/test_flex_attention.py
@vcherepanov-nv
Collaborator Author

Thanks for the thorough review!

  1. Do you have benchmark numbers based on any toy test cases you might have run? It would be good to have those in here for users of the API.

    1. native PyT flex vs TE PyT flex
    2. traditional causal TE via cuDNN vs flex expressed causal TE via cuDNN

I haven't done any benchmarking. Reportedly (from a Slack thread), score_mod can lead to significant perf gains when it avoids mask materialization. For causal, I think I observed cuDNN choosing exactly the same kernel with score_mod as with the explicit causal flag.

  2. I've linked the GH issue in the PR description. Could you please update / close it appropriately when this PR is merged

Sure, thanks for linking!

Member

@sudhakarsingh27 sudhakarsingh27 left a comment

Thanks for the PR! A few comments:
0. Agree with all the comments from @KshitijLakhani and @cyanguwa, so I just +1ed them.

  1. A user doc specifying the design choices and the building blocks of graph caching would be valuable.
  2. score_mod seems more like an argument than a feature, so the error messaging could use something more substantial like "(TE/cuDNN) Flex Attention".
  3. New arguments of the form has_* in AttentionParams could be avoided. If passing score_mod and score_mod_tensors (which are hefty) is the blocker, could we create an encapsulating dataclass and pass that instead?
  4. user_supplied_seqlens is a bit vague; it seems like just a derived variable. Does it degenerate to mean pad_between_seqs=True?
    Among other nits.
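
For point 3, an encapsulating dataclass might look something like the sketch below. `ScoreModSpec` and its property names are hypothetical; the only rule it enforces is the one already present in the backend filter, namely that a bprop callback requires a forward callback.

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional


@dataclass(frozen=True)
class ScoreModSpec:
    """Hypothetical bundle of the four score_mod arguments."""

    score_mod: Optional[Callable] = None
    score_mod_bprop: Optional[Callable] = None
    score_mod_tensors: Optional[Mapping[str, Any]] = None
    score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None

    def __post_init__(self):
        # Mirrors the existing filter: a backward callback alone is invalid.
        if self.score_mod_bprop is not None and self.score_mod is None:
            raise ValueError("score_mod_bprop requires score_mod")

    @property
    def enabled(self) -> bool:
        """Replaces a separate has_score_mod flag."""
        return self.score_mod is not None

    @property
    def has_bprop(self) -> bool:
        """Replaces a separate has_score_mod_bprop flag."""
        return self.score_mod_bprop is not None


spec = ScoreModSpec(score_mod=lambda graph, score, tensors: score)
print(spec.enabled, spec.has_bprop)
```

Backend selection could then read the derived properties, while the callbacks and tensors travel together as one object.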

Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread tests/pytorch/attention/test_flex_attention.py
Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread tests/pytorch/attention/test_flex_attention.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

Labels

2.16.0 community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[PyTorch/Jax/common] Flex attention via cuDNN

4 participants