Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 262 additions & 40 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,22 @@ class _HubKernelConfig:
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
wrapped_forward_attr="flash_attn_interface._flash_attn_varlen_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_varlen_backward",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
Expand Down Expand Up @@ -1213,7 +1215,7 @@ def _flash_attention_hub_forward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"Flash attention hub kernels must expose `_flash_attn_forward` and `_flash_attn_backward` "
"for context parallel execution."
)

Expand Down Expand Up @@ -1267,7 +1269,7 @@ def _flash_attention_hub_backward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
"Flash attention hub kernels must expose `_flash_attn_backward` for context parallel execution."
)

query, key, value, out, lse, rng_state = ctx.saved_tensors
Expand Down Expand Up @@ -1325,8 +1327,8 @@ def _flash_varlen_attention_hub_forward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
"`_wrapped_flash_attn_varlen_backward` for context parallel execution."
"Flash attention varlen hub kernels must expose `_flash_attn_varlen_forward` and "
"`_flash_attn_varlen_backward` for context parallel execution."
)

if scale is None:
Expand Down Expand Up @@ -1419,7 +1421,7 @@ def _flash_varlen_attention_hub_backward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
"Flash attention varlen hub kernels must expose `_flash_attn_varlen_backward` "
"for context parallel execution."
)

Expand Down Expand Up @@ -1612,6 +1614,194 @@ def _flash_attention_3_hub_backward_op(
return grad_query, grad_key, grad_value


def _flash_attention_3_varlen_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 varlen hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 varlen hub kernels.")

config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_forward` and "
"`flash_attn_interface._flash_attn_backward` for context parallel execution."
)

if scale is None:
scale = query.shape[-1] ** (-0.5)

batch_size, seq_len_q, num_heads, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
query_packed = query.flatten(0, 1)
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
max_seqlen_q = seq_len_q
else:
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
query_packed = query.flatten(0, 1)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)
seqlens_k = None

out_packed, softmax_lse, *_ = wrapped_forward_fn(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The non-varlen _flash_attention_3_hub_forward_op uses keyword arguments for the trailing parameters (causal=is_causal, window_size_left=window_size[0], etc.), but here everything is passed positionally with no inline comments explaining what each None corresponds to. This makes the code harder to audit and fragile if the upstream signature changes.

Consider either:

  1. Using keyword arguments for at least the trailing parameters (like the non-varlen version does), or
  2. Adding inline comments for the positional None values (like the non-varlen version does with # k_new, v_new, # cu_seqlens_q/k/k_new, etc.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice suggestion, updated.

query_packed,
key_packed,
value_packed,
None, # k_new
None, # v_new
None, # qv
None, # out_
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
None, # seqused_q
None, # seqused_k
max_seqlen_q,
max_seqlen_k,
None, # page_table
None, # kv_batch_idx
None, # leftpad_k
None, # rotary_cos
None, # rotary_sin
None, # seqlens_rotary
None, # q_descale
None, # k_descale
None, # v_descale
scale,
causal=is_causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=0,
softcap=softcap,
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)

out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])

if _save_ctx:
ctx.save_for_backward(
query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k
)
ctx.seqlens_k = seqlens_k # None if unmasked
ctx.indices_k = indices_k if attn_mask is not None else None
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.batch_size = batch_size
ctx.seq_len_q = seq_len_q
ctx.seq_len_kv = seq_len_kv
ctx.num_heads = num_heads
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin

# softmax_lse in varlen mode: (num_heads, total_q) -> (batch_size, seq_len_q, num_heads)
lse_sp = softmax_lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()

return (out, lse_sp) if return_lse else out


def _flash_attention_3_varlen_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    """Backward pass for the flash-attn 3 varlen hub kernel.

    Reads the packed tensors and metadata stored on ``ctx``, runs the hub
    kernel's wrapped backward function on preallocated gradient buffers, and
    reshapes the results back to batched form.

    Args:
        ctx: Autograd context carrying ``saved_tensors`` (packed query/key/value,
            packed output, softmax LSE, cumulative sequence lengths) plus the
            scalar attributes read below.
        grad_out: Gradient of the attention output; its first two dimensions
            are flattened into one before being handed to the kernel.
        *args, **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(grad_query, grad_key, grad_value)``.

    Raises:
        RuntimeError: If the registry entry has no wrapped backward function.
    """
    backward_kernel = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].wrapped_backward_fn
    if backward_kernel is None:
        raise RuntimeError(
            "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_backward` "
            "for context parallel execution."
        )

    q_packed, k_packed, v_packed, o_packed, lse, cu_q, cu_k = ctx.saved_tensors

    # The kernel's return value is ignored; these buffers are passed in and
    # returned afterwards, so the kernel is relied upon to fill them in place.
    dq = torch.empty_like(q_packed)
    dk = torch.empty_like(k_packed)
    dv = torch.empty_like(v_packed)

    backward_kernel(
        grad_out.flatten(0, 1),
        q_packed,
        k_packed,
        v_packed,
        o_packed,
        lse,
        cu_q,
        cu_k,
        None,  # seqused_q
        None,  # seqused_k
        ctx.max_seqlen_q,
        ctx.max_seqlen_k,
        dq,
        dk,
        dv,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.deterministic,
        ctx.sm_margin,
    )

    batch = ctx.batch_size
    kv_len = ctx.seq_len_kv
    head_dim = grad_out.shape[-1]

    dq = dq.view(batch, ctx.seq_len_q, *dq.shape[1:])

    if ctx.seqlens_k is None:
        # No mask was stored: key/value were flattened as-is, so a plain view
        # restores the leading (batch, seq_len_kv) dimensions.
        dk = dk.view(batch, kv_len, *dk.shape[1:])
        dv = dv.view(batch, kv_len, *dv.shape[1:])
    else:
        # Masked case: key/value rows were gathered with `indices_k`.
        # NOTE(review): presumably `_unpad_to_padded` scatters them back into a
        # (batch, seq_len_kv, ...) layout; confirm against the helper.
        dk = _unpad_to_padded(dk, ctx.indices_k, batch, kv_len)
        dv = _unpad_to_padded(dv, ctx.indices_k, batch, kv_len)

    # Trim the last dimension to that of `grad_out`.
    # NOTE(review): presumably drops head-dim padding; confirm.
    return dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim]


def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
Expand Down Expand Up @@ -3007,7 +3197,7 @@ def _flash_attention_3_hub(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_VARLEN_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_varlen_hub(
query: torch.Tensor,
Expand All @@ -3019,41 +3209,73 @@ def _flash_attention_3_varlen_hub(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
raise NotImplementedError("`ring_degree > 1` is not yet supported for the _FLASH_3_VARLEN_HUB backend.")

lse = None
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)

key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
Comment on lines -3025 to -3038

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it should come under if _parallel_config is None and attn_mask is not None:?

@zhtmike zhtmike Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a small refactor in this PR; it basically does the same thing as
#13479 (comment)

if _parallel_config is None:
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
else:
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)

query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
query_packed = query.flatten(0, 1)

func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
out, lse, *_ = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
)
out = out.unflatten(0, (batch_size, -1))
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
out = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like an extra argument?

@zhtmike zhtmike Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we are following the other implementations' style.

Also, return_attn_probs is actually False by default (see https://github.com/huggingface/kernels-community/blob/main/flash-attn3/torch-ext/flash_attn3/flash_attn_interface.py#L648),

so it only returns a single tensor (see https://github.com/huggingface/kernels-community/blob/main/flash-attn3/torch-ext/flash_attn3/flash_attn_interface.py#L691),

and the previous code should not have been runnable.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The original code always unpacked out, lse, *_ = func(...). Now with return_attn_probs=return_lse, when return_lse=False the return value may be different (single tensor vs tuple). Make sure flash_attn_varlen_func from flash-attn3 returns a single tensor (not a tuple) when return_attn_probs=False. The flash-attn2 varlen hub uses the same pattern, so this is likely fine, but worth verifying.

@zhtmike zhtmike Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
if return_lse:
out, lse, *_ = out

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to initialize lse = None above?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a guard similar to the one in _native_cudnn_attention, _sage_attention, etc., to keep the code in the same style.

out = out.unflatten(0, (batch_size, -1))
else:
forward_op = functools.partial(
_flash_attention_3_varlen_hub_forward_op,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=forward_op,
backward_op=_flash_attention_3_varlen_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out

return (out, lse) if return_lse else out

Expand Down
Loading
Loading