-
Notifications
You must be signed in to change notification settings - Fork 7k
add SP support for _flash_3_varlen_hub backend
#13809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
40618ff
330eb13
a7a3b9c
83ef0c7
e21870c
363f8fd
2a3c2b5
2a0dfb1
aceb39d
300427f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -340,20 +340,22 @@ class _HubKernelConfig: | |
| AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig( | ||
| repo_id="kernels-community/flash-attn3", | ||
| function_attr="flash_attn_varlen_func", | ||
| wrapped_forward_attr="flash_attn_interface._flash_attn_forward", | ||
| wrapped_backward_attr="flash_attn_interface._flash_attn_backward", | ||
| version=1, | ||
| ), | ||
| AttentionBackendName.FLASH_HUB: _HubKernelConfig( | ||
| repo_id="kernels-community/flash-attn2", | ||
| function_attr="flash_attn_func", | ||
| wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", | ||
| wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", | ||
| wrapped_forward_attr="flash_attn_interface._flash_attn_forward", | ||
| wrapped_backward_attr="flash_attn_interface._flash_attn_backward", | ||
| version=1, | ||
| ), | ||
| AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( | ||
| repo_id="kernels-community/flash-attn2", | ||
| function_attr="flash_attn_varlen_func", | ||
| wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", | ||
| wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", | ||
| wrapped_forward_attr="flash_attn_interface._flash_attn_varlen_forward", | ||
| wrapped_backward_attr="flash_attn_interface._flash_attn_varlen_backward", | ||
| version=1, | ||
| ), | ||
| AttentionBackendName.SAGE_HUB: _HubKernelConfig( | ||
|
|
@@ -1213,7 +1215,7 @@ def _flash_attention_hub_forward_op( | |
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_forward_fn is None or wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` " | ||
| "Flash attention hub kernels must expose `_flash_attn_forward` and `_flash_attn_backward` " | ||
| "for context parallel execution." | ||
| ) | ||
|
|
||
|
|
@@ -1267,7 +1269,7 @@ def _flash_attention_hub_backward_op( | |
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution." | ||
| "Flash attention hub kernels must expose `_flash_attn_backward` for context parallel execution." | ||
| ) | ||
|
|
||
| query, key, value, out, lse, rng_state = ctx.saved_tensors | ||
|
|
@@ -1325,8 +1327,8 @@ def _flash_varlen_attention_hub_forward_op( | |
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_forward_fn is None or wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and " | ||
| "`_wrapped_flash_attn_varlen_backward` for context parallel execution." | ||
| "Flash attention varlen hub kernels must expose `_flash_attn_varlen_forward` and " | ||
| "`_flash_attn_varlen_backward` for context parallel execution." | ||
| ) | ||
|
|
||
| if scale is None: | ||
|
|
@@ -1419,7 +1421,7 @@ def _flash_varlen_attention_hub_backward_op( | |
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` " | ||
| "Flash attention varlen hub kernels must expose `_flash_attn_varlen_backward` " | ||
| "for context parallel execution." | ||
| ) | ||
|
|
||
|
|
@@ -1612,6 +1614,194 @@ def _flash_attention_3_hub_backward_op( | |
| return grad_query, grad_key, grad_value | ||
|
|
||
|
|
||
| def _flash_attention_3_varlen_hub_forward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| query: torch.Tensor, | ||
| key: torch.Tensor, | ||
| value: torch.Tensor, | ||
| attn_mask: torch.Tensor | None = None, | ||
| dropout_p: float = 0.0, | ||
| is_causal: bool = False, | ||
| scale: float | None = None, | ||
| enable_gqa: bool = False, | ||
| return_lse: bool = False, | ||
| _save_ctx: bool = True, | ||
| _parallel_config: "ParallelConfig" | None = None, | ||
| *, | ||
| window_size: tuple[int, int] = (-1, -1), | ||
| softcap: float = 0.0, | ||
| num_splits: int = 1, | ||
| pack_gqa: bool | None = None, | ||
| deterministic: bool = False, | ||
| sm_margin: int = 0, | ||
| ): | ||
| if dropout_p != 0.0: | ||
| raise ValueError("`dropout_p` is not yet supported for flash-attn 3 varlen hub kernels.") | ||
| if enable_gqa: | ||
| raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 varlen hub kernels.") | ||
|
|
||
| config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB] | ||
| wrapped_forward_fn = config.wrapped_forward_fn | ||
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_forward_fn is None or wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_forward` and " | ||
| "`flash_attn_interface._flash_attn_backward` for context parallel execution." | ||
| ) | ||
|
|
||
| if scale is None: | ||
| scale = query.shape[-1] ** (-0.5) | ||
|
|
||
| batch_size, seq_len_q, num_heads, _ = query.shape | ||
| _, seq_len_kv, _, _ = key.shape | ||
|
|
||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device) | ||
| ) | ||
| indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten() | ||
| query_packed = query.flatten(0, 1) | ||
| key_packed = key.reshape(-1, *key.shape[2:])[indices_k] | ||
| value_packed = value.reshape(-1, *value.shape[2:])[indices_k] | ||
| max_seqlen_q = seq_len_q | ||
| else: | ||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) | ||
| ) | ||
| query_packed = query.flatten(0, 1) | ||
| key_packed = key.flatten(0, 1) | ||
| value_packed = value.flatten(0, 1) | ||
| seqlens_k = None | ||
|
|
||
| out_packed, softmax_lse, *_ = wrapped_forward_fn( | ||
| query_packed, | ||
| key_packed, | ||
| value_packed, | ||
| None, # k_new | ||
| None, # v_new | ||
| None, # qv | ||
| None, # out_ | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, # cu_seqlens_k_new | ||
| None, # seqused_q | ||
| None, # seqused_k | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| None, # page_table | ||
| None, # kv_batch_idx | ||
| None, # leftpad_k | ||
| None, # rotary_cos | ||
| None, # rotary_sin | ||
| None, # seqlens_rotary | ||
| None, # q_descale | ||
| None, # k_descale | ||
| None, # v_descale | ||
| scale, | ||
| causal=is_causal, | ||
| window_size_left=window_size[0], | ||
| window_size_right=window_size[1], | ||
| attention_chunk=0, | ||
| softcap=softcap, | ||
| rotary_interleaved=True, | ||
| scheduler_metadata=None, | ||
| num_splits=num_splits, | ||
| pack_gqa=pack_gqa, | ||
| sm_margin=sm_margin, | ||
| ) | ||
|
|
||
| out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:]) | ||
|
|
||
| if _save_ctx: | ||
| ctx.save_for_backward( | ||
| query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k | ||
| ) | ||
| ctx.seqlens_k = seqlens_k # None if unmasked | ||
| ctx.indices_k = indices_k if attn_mask is not None else None | ||
| ctx.max_seqlen_q = max_seqlen_q | ||
| ctx.max_seqlen_k = max_seqlen_k | ||
| ctx.batch_size = batch_size | ||
| ctx.seq_len_q = seq_len_q | ||
| ctx.seq_len_kv = seq_len_kv | ||
| ctx.num_heads = num_heads | ||
| ctx.scale = scale | ||
| ctx.is_causal = is_causal | ||
| ctx.window_size = window_size | ||
| ctx.softcap = softcap | ||
| ctx.deterministic = deterministic | ||
| ctx.sm_margin = sm_margin | ||
|
|
||
| # softmax_lse in varlen mode: (num_heads, total_q) -> (batch_size, seq_len_q, num_heads) | ||
| lse_sp = softmax_lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous() | ||
|
|
||
| return (out, lse_sp) if return_lse else out | ||
|
|
||
|
|
||
| def _flash_attention_3_varlen_hub_backward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| grad_out: torch.Tensor, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB] | ||
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_backward` " | ||
| "for context parallel execution." | ||
| ) | ||
|
|
||
| query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors | ||
|
|
||
| grad_out_packed = grad_out.flatten(0, 1) | ||
| grad_query, grad_key, grad_value = ( | ||
| torch.empty_like(query_packed), | ||
| torch.empty_like(key_packed), | ||
| torch.empty_like(value_packed), | ||
| ) | ||
|
|
||
| wrapped_backward_fn( | ||
| grad_out_packed, | ||
| query_packed, | ||
| key_packed, | ||
| value_packed, | ||
| out_packed, | ||
| softmax_lse, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, | ||
| None, # seqused_q, seqused_k | ||
| ctx.max_seqlen_q, | ||
| ctx.max_seqlen_k, | ||
| grad_query, | ||
| grad_key, | ||
| grad_value, | ||
| ctx.scale, | ||
| ctx.is_causal, | ||
| ctx.window_size[0], | ||
| ctx.window_size[1], | ||
| ctx.softcap, | ||
| ctx.deterministic, | ||
| ctx.sm_margin, | ||
| ) | ||
|
|
||
| grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:]) | ||
|
|
||
| if ctx.seqlens_k is not None: | ||
| grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) | ||
| grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) | ||
| else: | ||
| grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:]) | ||
| grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:]) | ||
|
|
||
| grad_query = grad_query[..., : grad_out.shape[-1]] | ||
| grad_key = grad_key[..., : grad_out.shape[-1]] | ||
| grad_value = grad_value[..., : grad_out.shape[-1]] | ||
|
|
||
| return grad_query, grad_key, grad_value | ||
|
|
||
|
|
||
| def _sage_attention_forward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| query: torch.Tensor, | ||
|
|
@@ -3007,7 +3197,7 @@ def _flash_attention_3_hub( | |
| @_AttentionBackendRegistry.register( | ||
| AttentionBackendName._FLASH_3_VARLEN_HUB, | ||
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], | ||
| supports_context_parallel=False, | ||
| supports_context_parallel=True, | ||
| ) | ||
| def _flash_attention_3_varlen_hub( | ||
| query: torch.Tensor, | ||
|
|
@@ -3019,41 +3209,73 @@ def _flash_attention_3_varlen_hub( | |
| return_lse: bool = False, | ||
| _parallel_config: "ParallelConfig" | None = None, | ||
| ) -> torch.Tensor: | ||
| if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1: | ||
| raise NotImplementedError("`ring_degree > 1` is not yet supported for the _FLASH_3_VARLEN_HUB backend.") | ||
|
|
||
| lse = None | ||
| batch_size, seq_len_q, _, _ = query.shape | ||
| _, seq_len_kv, _, _ = key.shape | ||
|
|
||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
|
|
||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen( | ||
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device | ||
| ) | ||
| ) | ||
|
|
||
| key_valid, value_valid = [], [] | ||
| for b in range(batch_size): | ||
| valid_len = seqlens_k[b] | ||
| key_valid.append(key[b, :valid_len]) | ||
| value_valid.append(value[b, :valid_len]) | ||
|
Comment on lines
-3025
to
-3038
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like should come under
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a little refactor in this PR, and basically it is doing the same thing as |
||
| if _parallel_config is None: | ||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
| (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device) | ||
| ) | ||
| indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten() | ||
| key_packed = key.reshape(-1, *key.shape[2:])[indices_k] | ||
| value_packed = value.reshape(-1, *value.shape[2:])[indices_k] | ||
| else: | ||
| (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) | ||
| ) | ||
| key_packed = key.flatten(0, 1) | ||
| value_packed = value.flatten(0, 1) | ||
|
|
||
| query_packed = query.flatten(0, 1) | ||
| key_packed = torch.cat(key_valid, dim=0) | ||
| value_packed = torch.cat(value_valid, dim=0) | ||
| query_packed = query.flatten(0, 1) | ||
|
|
||
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn | ||
| out, lse, *_ = func( | ||
| q=query_packed, | ||
| k=key_packed, | ||
| v=value_packed, | ||
| cu_seqlens_q=cu_seqlens_q, | ||
| cu_seqlens_k=cu_seqlens_k, | ||
| max_seqlen_q=max_seqlen_q, | ||
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| ) | ||
| out = out.unflatten(0, (batch_size, -1)) | ||
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn | ||
| out = func( | ||
| q=query_packed, | ||
| k=key_packed, | ||
| v=value_packed, | ||
| cu_seqlens_q=cu_seqlens_q, | ||
| cu_seqlens_k=cu_seqlens_k, | ||
| max_seqlen_q=max_seqlen_q, | ||
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| return_attn_probs=return_lse, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like an extra argument?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, we following the other implementation's style. And actually So it will only return single tensor (see https://github.com/huggingface/kernels-community/blob/main/flash-attn3/torch-ext/flash_attn3/flash_attn_interface.py#L691), the previous code should not be runnable.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: The original code always unpacked
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ) | ||
| if return_lse: | ||
| out, lse, *_ = out | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to initialize
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is a guard similar to |
||
| out = out.unflatten(0, (batch_size, -1)) | ||
| else: | ||
| forward_op = functools.partial( | ||
| _flash_attention_3_varlen_hub_forward_op, | ||
| window_size=(-1, -1), | ||
| softcap=0.0, | ||
| num_splits=1, | ||
| pack_gqa=None, | ||
| deterministic=False, | ||
| sm_margin=0, | ||
| ) | ||
| out = _templated_context_parallel_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| 0.0, | ||
| is_causal, | ||
| scale, | ||
| False, | ||
| return_lse, | ||
| forward_op=forward_op, | ||
| backward_op=_flash_attention_3_varlen_hub_backward_op, | ||
| _parallel_config=_parallel_config, | ||
| ) | ||
| if return_lse: | ||
| out, lse = out | ||
|
|
||
| return (out, lse) if return_lse else out | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: The non-varlen
_flash_attention_3_hub_forward_opuses keyword arguments for the trailing parameters (causal=is_causal,window_size_left=window_size[0], etc.), but here everything is passed positionally with no inline comments explaining what eachNonecorresponds to. This makes the code harder to audit and fragile if the upstream signature changes.Consider either:
Nonevalues (like the non-varlen version does with# k_new, v_new,# cu_seqlens_q/k/k_new, etc.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice suggestion, updated.