[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM #3020

Open
vedaanta wants to merge 4 commits into NVIDIA:main from vedaanta:vedaanta/te-fp8-vs-f16-shrink-b1

Conversation


@vedaanta vedaanta commented May 21, 2026

The 9 fp8_9..fp8_17 configs in model_configs_fp8_vs_f16 use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in test_dpa_fp8_vs_f16 materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue).
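As a rough sanity check on these figures, the size of one dense (B, H, S, S) tensor can be computed directly. This is a minimal sketch: the helper name `attn_matrix_bytes` is hypothetical and not part of the test suite, and it assumes 2 bytes per element (bf16).

```python
GIB = 2**30  # bytes per GiB


def attn_matrix_bytes(b: int, h: int, s: int, bytes_per_elem: int = 2) -> int:
    """Bytes held by one dense (B, H, S, S) tensor (2 bytes/element for bf16)."""
    return b * h * s * s * bytes_per_elem


# Shape quoted in the description: B=2, S=8192, H=64.
before = attn_matrix_bytes(2, 64, 8192)
# Same shape with the batch halved.
after = attn_matrix_bytes(1, 64, 8192)

print(f"B=2: {before / GIB:.1f} GiB per tensor")  # B=2: 16.0 GiB per tensor
print(f"B=1: {after / GIB:.1f} GiB per tensor")   # B=1: 8.0 GiB per tensor
```

At 16 GiB per tensor, keeping a handful of them live at once is consistent in order of magnitude with the ~70 GiB peak reported here.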

Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main):

  per-test peak torch.cuda.max_memory_allocated:
     before: 70.0 GiB (fp8_14)
     after : 36.1 GiB (fp8_14)         -48%
  per-test peak nvidia-smi memory.used:
     before: 96.8 GiB
     after : 51.3 GiB                  -47%
  test outcome (B200, develop FE, this TE):
     identical 618F / 2196P / 891S, wall time within ~3%
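A per-test peak like the one reported here could be collected roughly as follows. This is a sketch under assumptions, not the script actually used for the measurements: `peak_allocated_gib` and `percent_change` are hypothetical helpers, and the CUDA part needs a GPU and PyTorch to run.

```python
def percent_change(before: float, after: float) -> float:
    """Relative change from before to after, in percent."""
    return (after - before) / before * 100.0


def peak_allocated_gib(fn) -> float:
    """Run fn() on the current CUDA device and return the peak allocated GiB."""
    import torch  # imported lazily so percent_change works without PyTorch

    torch.cuda.reset_peak_memory_stats()  # clear the peak from earlier tests
    fn()
    torch.cuda.synchronize()  # make sure all queued kernels have finished
    return torch.cuda.max_memory_allocated() / 2**30


# Reproduce the percentages from the reported numbers.
print(f"{percent_change(70.0, 36.1):.0f}%")  # -48%
print(f"{percent_change(96.8, 51.3):.0f}%")  # -47%
```

Note that `torch.cuda.max_memory_allocated` only counts tensors allocated through PyTorch's caching allocator, whereas `nvidia-smi memory.used` also includes cached-but-free blocks, the CUDA context, and library workspaces, which is why the second pair of numbers is larger.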

The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards.

fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (a few GiB) and the larger batch provides useful coverage for padding_causal.

Contributor

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR shrinks nine model_configs_fp8_vs_f16 entries (fp8_9–fp8_17) to reduce peak GPU memory for the FP8-vs-F16 attention test suite, which was hitting ~70 GiB per test and causing OOM on Blackwell (~91 GB hardware). The strategy is primarily halving S (and in some cases B) so that the full (B, H, S, S) attention matrix materialized in bf16 fits on 80 GB cards.

  • Multiple configs reduce S (4096→2048 for fp8_9/10/11, 8192→4096 for fp8_14/17) in addition to or instead of reducing B, and fp8_10's B actually increases from 1→2, so the change is broader than "only B is smaller" as stated in the PR description.
  • fp8_13 silently gains num_gqa_groups=4, shifting it from a pure SWA-only test to a GQA+SWA combined test; this removes the only "SWA without GQA at 32 heads" coverage in the suite.
  • fp8_19/20 (B=2, S=2048) are intentionally left unchanged as per the description.

Confidence Score: 4/5

Safe to merge for the memory-reduction goal; one test config quietly changes its feature combination instead of just shrinking its size.

The memory reduction is well-motivated and measured. The only concern is fp8_13, which adds num_gqa_groups=4 while simultaneously keeping the same B and S — this changes which feature combination is exercised rather than just reducing resource usage, and the SWA-without-GQA at 32 heads scenario is no longer covered. Everything else is a straightforward size reduction.
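One way to catch this kind of silent coverage change is to compare configs while ignoring the size fields. The sketch below is illustrative only: `Cfg` is a hypothetical stand-in rather than the suite's real config class, the window and mask values are placeholders, and "no GQA" is modeled as `gqa_groups == heads`.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Cfg:
    """Hypothetical stand-in for a test config; not the suite's real class."""

    batch: int
    seqlen: int
    heads: int
    head_dim: int
    gqa_groups: int  # assumption: equal to heads means no GQA
    window: tuple    # placeholder sliding-window setting
    mask: str        # placeholder mask type


# Fields that only affect memory footprint, not which feature is exercised.
SIZE_FIELDS = {"batch", "seqlen"}


def feature_diff(old: Cfg, new: Cfg) -> dict:
    """Return the non-size fields whose values differ between old and new."""
    return {
        f.name: (getattr(old, f.name), getattr(new, f.name))
        for f in fields(Cfg)
        if f.name not in SIZE_FIELDS
        and getattr(old, f.name) != getattr(new, f.name)
    }


# fp8_13 as described in the review: same B and S, but GQA groups added.
old = Cfg(2, 8192, 32, 128, gqa_groups=32, window=(128, 0), mask="causal")
new = Cfg(2, 8192, 32, 128, gqa_groups=4, window=(128, 0), mask="causal")

print(feature_diff(old, new))  # {'gqa_groups': (32, 4)}
```

An empty result would mean the change is a pure size reduction; anything else signals that the feature combination under test has shifted.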

tests/pytorch/attention/test_attention.py — specifically the fp8_13 entry, which changed in feature coverage, not just size.

Important Files Changed

Filename: tests/pytorch/attention/test_attention.py
Overview: Reduces batch size (B) and/or sequence length (S) across fp8_9–fp8_17 to cut peak VRAM; fp8_13 silently gains num_gqa_groups=4, changing the test from SWA-only to GQA+SWA and dropping the original SWA-without-GQA coverage at 32 heads

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fp8_vs_f16 test suite] --> B[fp8_9 to fp8_11 H=128 D=192]
    A --> C[fp8_12 to fp8_13 H=32 D=128 S=8192]
    A --> D[fp8_14 to fp8_17 H=64 D=64]
    A --> E[fp8_18 to fp8_20 unchanged]
    B --> B1[fp8_9 B=2 S=2048 no mask]
    B --> B2[fp8_10 B=2 S=2048 causal]
    B --> B3[fp8_11 B=2 S=2048 causal_bottom_right]
    C --> C1[fp8_12 B=1 S=8192 GQA]
    C --> C2[fp8_13 B=2 S=8192 GQA+SWA was SWA-only]
    D --> D1[fp8_14 B=2 S=4096 GQA]
    D --> D2[fp8_15 B=1 S=8192 SWA]
    D --> D3[fp8_16 B=1 S=8192 GQA+learnable]
    D --> D4[fp8_17 B=2 S=4096 SWA+learnable]

Reviews (4): Last reviewed commit: "tests/attention: black format fp8_13 Mod..."

vedaanta and others added 3 commits May 21, 2026 15:35
The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes
(B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference
comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the
full (B, H, S, S) attention matrix in bf16, and keeps a handful of them
live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64
the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB
cards (H100) and pushes the suite into OOM territory on Blackwell (~91
GB measured with the cuDNN caching allocator residue).

Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured
on B200 (SM_100, cuDNN 9.23, TE main):

  per-test peak `torch.cuda.max_memory_allocated`:
     before: 70.0 GiB (fp8_14)
     after : 36.1 GiB (fp8_14)         -48%
  per-test peak `nvidia-smi memory.used`:
     before: 96.8 GiB
     after : 51.3 GiB                  -47%
  test outcome (B200, develop FE, this TE):
     identical 618F / 2196P / 891S, wall time within ~3%

The shrunk configs still exercise every distinct shape/mask/SWA/GQA
combination that the originals did -- only B is smaller. The suite now
fits comfortably on 80 GB cards.

fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small
(~few GiB) and the larger batch is useful coverage for padding_causal.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
@vedaanta vedaanta force-pushed the vedaanta/te-fp8-vs-f16-shrink-b1 branch from 1a59d59 to c3f1e50 Compare May 21, 2026 22:36
Line was 105 chars; black requires <=100 with the project's preview+
string_processing settings.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
Collaborator

@KshitijLakhani KshitijLakhani left a comment

LGTM!
The combinations were discussed and adjusted offline: instead of taking the sledgehammer approach of setting B=1 for every config, the reductions are diversified across B, S, H and D.

Good to merge after CI passes
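To see how the diversified shapes trade off, the per-tensor footprint of a few updated configs can be tabulated. This is a rough sketch: the (B, S, H) values are taken from the review flowchart earlier in this thread, `score_gib` is a hypothetical helper, and bf16 (2 bytes per element) is assumed.

```python
GIB = 2**30  # bytes per GiB

# (B, S, H) as listed in the review flowchart for the updated configs.
configs = {
    "fp8_9": (2, 2048, 128),
    "fp8_12": (1, 8192, 32),
    "fp8_13": (2, 8192, 32),
    "fp8_14": (2, 4096, 64),
    "fp8_15": (1, 8192, 64),
}


def score_gib(b: int, s: int, h: int, bytes_per_elem: int = 2) -> float:
    """GiB held by one dense (B, H, S, S) tensor."""
    return b * h * s * s * bytes_per_elem / GIB


sizes = {name: score_gib(*shape) for name, shape in configs.items()}
for name, gib in sizes.items():
    print(f"{name}: {gib:.0f} GiB per (B, H, S, S) tensor")
```

Because the footprint is quadratic in S but linear in B and H, halving S cuts each tensor to a quarter while halving B only cuts it to a half, which leaves room to keep B=2 in some configs.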

@KshitijLakhani KshitijLakhani changed the title from "tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1" to "[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM" May 21, 2026
@KshitijLakhani
Collaborator

/te-ci pytorch L0
