
Feature/gpu inference #445

Open

maflot wants to merge 8 commits into scverse:main from maflot:feature/gpu-inference

Conversation

maflot commented Apr 20, 2026

Reference Issue or PRs

No existing issue. Opening this to gauge interest before anything else.

What does your PR implement? Be specific.

Hi! I've been using PyDESeq2 on some larger RNA-seq datasets and found myself waiting quite a bit on the CPU pipeline. I put together a GPU-accelerated inference backend using PyTorch that plugs into the existing Inference ABC. It helped me a lot with a recent project, and I wanted to see if there's interest in having this upstream.

I realize this is a fairly large PR, and I'm not sure of the best way to contribute something like this; I'm happy to split it up, adjust the API, or rework things based on your feedback. I figured it's better to open the conversation with a working implementation than a vague feature request.

What it does

Adds TorchInference, a drop-in GPU backend that batches gene-level operations into tensor ops instead of per-gene joblib parallelization. Usage is a one-parameter change:

dds = DeseqDataSet(
    counts=counts_df,
    metadata=metadata,
    design="~condition",
    inference_type="gpu",  # this is the only change
)
dds.deseq2()

PyTorch is an optional dependency — the package works exactly as before without it installed.
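
With the optional dependency group this adds to pyproject.toml, installing GPU support should be a one-liner (assuming the extra keeps the [gpu] name):

pip install "pydeseq2[gpu]"   # pulls in torch>=2.0.0
pip install pydeseq2          # base install, unchanged, no torch needed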

What's actually GPU-vectorized and what isn't

Most of the pipeline runs as batched GPU tensor operations: lin_reg_mu, irls, alpha_mle, wald_test, fit_rough_dispersions, fit_moments_dispersions, dispersion_trend_gamma_glm, and lfc_shrink_nbinom_glm all process every gene simultaneously on the GPU.

The exception is the multi-factor non-convergence fallback. The GPU grid search only supports 2-coefficient designs (intercept + one LFC). When IRLS produces NaN betas in a multi-factor design (n_coeffs > 2), those specific genes fall back to the CPU irls_solver in a serial per-gene loop. In practice this affects a small number of genes, but it's worth noting — it's not a full GPU path for every case.
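
For reference, the dispatch is conceptually like this (a schematic sketch, not the PR's exact code; the irls_solver call signature here is approximate):

import torch
from pydeseq2.utils import irls_solver  # existing per-gene CPU solver

# beta: (n_genes, n_coeffs) result of the batched GPU IRLS pass;
# counts, size_factors, design_matrix, disp are the usual NumPy inputs
nan_genes = torch.isnan(beta).any(dim=1)
if n_coeffs > 2 and bool(nan_genes.any()):
    for j in torch.nonzero(nan_genes).flatten().tolist():
        # serial CPU fallback, one gene at a time
        beta_j, mu_j, hat_j, converged_j = irls_solver(
            counts[:, j], size_factors, design_matrix, disp[j]
        )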

CUDA only — no MPS

All tensor operations use torch.float64 throughout to match the numerical precision of the CPU scipy path. This means MPS (Apple Silicon) is not supported since MPS doesn't support float64. The code explicitly rejects device="mps" with a clear error. Supported devices are "cuda", "cuda:N", or "cpu" (for testing the torch code path without a GPU).
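
The validation is small enough to sketch here (a hedged approximation of what pydeseq2/gpu_utils.py does; details are assumptions):

import torch

def get_device(device=None):
    if device is None:
        # auto-detect: prefer CUDA, fall back to CPU
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device).startswith("mps"):
        # float64 kernels are required to match the CPU scipy precision
        raise ValueError("MPS does not support float64; use 'cuda', 'cuda:N', or 'cpu'.")
    return torch.device(device)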

Performance

Benchmarked on an NVIDIA B200 against CPU DefaultInference (all cores):

Samples   Genes     CPU (s)   GPU (s)   Speedup
10        500       0.72      0.17      4.3x
20        1,000     1.66      0.14      11.9x
50        5,000     5.50      0.23      23.9x
100       10,000    6.88      0.34      20.1x
200       20,000    10.79     0.69      15.6x
500       30,000    9.78      2.43      4.0x

The sweet spot is 1K–20K genes (12–24x speedup). Smaller datasets are dominated by kernel launch overhead, very large ones by memory bandwidth. Peak GPU memory was 1.83 GB for the largest configuration.

Concordance

GPU results have to match CPU — if they don't, the speedup is worthless. Across all dataset sizes tested:

  • LFC Pearson r: 1.000000
  • P-value Spearman r: 1.000000
  • Max LFC relative error: < 0.04%
  • Jaccard index (significant genes at padj < 0.05): 1.0

The test suite includes:

  • 16 R-concordance tests validating GPU output against the same R DESeq2 reference CSVs the existing tests use (2% tolerance single-factor, 4% multi-factor)
  • 6 CPU-vs-GPU exact match tests checking all 5 result columns (LFC, stat, SE, pvalue, padj) across single-factor, multi-factor, and 3 dataset sizes (20x100, 50x500, 20x1000). These enforce a 2% hard ceiling per gene and verify >99% of genes agree within 0.1% (see the sketch after this list).
  • 8 GPU-specific tests for device placement, float64 verification, memory cleanup, and edge cases (all-zero genes, large counts, 1000-gene scaling, multi-factor designs)
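
In spirit, the per-gene tolerance check in the exact match tests reduces to the following (variable names are mine, not the tests'):

import numpy as np

# gpu_lfc / cpu_lfc: per-gene log2FoldChange columns from the two backends
rel_err = np.abs(gpu_lfc - cpu_lfc) / np.maximum(np.abs(cpu_lfc), 1e-8)
assert np.nanmax(rel_err) < 0.02        # hard ceiling: no gene above 2%
assert np.mean(rel_err < 1e-3) > 0.99   # soft floor: >99% within 0.1%
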
What's changed

New files:

  • pydeseq2/torch_inference.py — TorchInference class (all 8 Inference ABC methods)
  • pydeseq2/torch_grid_search.py — GPU grid search fallbacks
  • pydeseq2/gpu_utils.py — device detection (with MPS rejection), GPU trimmed mean/variance
  • tests/test_gpu_concordance.py — 16 tests against R reference data
  • tests/test_gpu_specific.py — 14 tests (exact match, device, memory, edge cases)
  • examples/benchmark_gpu.py, examples/benchmark_concordance.py
  • PERFORMANCE.md

Modified files (minimal):

  • pydeseq2/dds.py — added inference_type and device parameters, lazy import (sketched after this list), fixed .values bug in fit_moments_dispersions
  • pydeseq2/ds.py — DeseqStats inherits inference engine from parent DeseqDataSet
  • pyproject.toml — optional [gpu] dep group, mypy override for torch
  • docs/source/conf.py — added pre-existing anndata.AnnData to nitpick_ignore
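
The lazy import in dds.py amounts to something like this (a sketch; the surrounding attribute names are assumptions):

if inference_type == "gpu":
    # torch is only imported on this branch, so the base install never needs it
    from pydeseq2.torch_inference import TorchInference
    self.inference = TorchInference(device=device)
else:
    from pydeseq2.default_inference import DefaultInference
    self.inference = DefaultInference(n_cpus=n_cpus)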

CI status: All checks pass (lint, format, mypy, 95 tests, docs build)

Design notes
  • TorchInference implements the same Inference ABC — no changes to the abstract interface
  • torch is lazily imported only when inference_type="gpu" is used
  • The IRLS convergence flag is tracked per-gene (deviance-based), not replaced wholesale when any gene has a NaN
  • The lfc_shrink_nbinom_glm Hessian intentionally matches a broadcasting quirk in the CPU implementation for concordance (documented in code)
  • The L-BFGS optimizer runs on the summed loss across genes (sketched below) — this works because the per-gene objectives are separable, but it's different from the CPU's per-gene scipy L-BFGS-B
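
A minimal, self-contained sketch of the summed-loss idea (per_gene_nll is a hypothetical stand-in for the real negative binomial objective):

import torch

n_genes, n_coeffs = 1000, 2      # illustrative sizes
device = torch.device("cpu")     # or "cuda"

def per_gene_nll(beta):
    # stand-in objective: one scalar per gene, separable across rows
    return (beta**2).sum(dim=1)

beta = torch.zeros(n_genes, n_coeffs, dtype=torch.float64,
                   device=device, requires_grad=True)
opt = torch.optim.LBFGS([beta], max_iter=100, line_search_fn="strong_wolfe")

def closure():
    opt.zero_grad()
    # Summing is valid because each gene's objective depends only on its own
    # row of beta, so the summed gradient decomposes gene by gene.
    loss = per_gene_nll(beta).sum()
    loss.backward()
    return loss

opt.step(closure)
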
Things I'm unsure about
  • Whether GPU support fits the project's scope — totally understand if it's not something you want to maintain
  • The inference_type parameter name and API surface — open to suggestions
  • Whether the grid search should be generalized to N coefficients on GPU, or if the CPU fallback is acceptable
  • Whether the test/benchmark files are too heavy for the repo

Thanks for building PyDESeq2!

maflot added 8 commits April 2, 2026 16:06
Implement a fully vectorized GPU inference backend that processes all
genes simultaneously using PyTorch tensor operations, achieving 4-24x
speedup over the CPU joblib baseline on NVIDIA B200 GPUs with perfect
result concordance (LFC Pearson r=1.0, Jaccard index=1.0).

New files:
- pydeseq2/torch_inference.py: TorchInference class implementing all 8
  Inference ABC methods (IRLS, alpha MLE, Wald test, LFC shrinkage,
  dispersion trend, rough/moments dispersions, linear regression mu)
- pydeseq2/torch_grid_search.py: vectorized grid search fallbacks for
  dispersion, beta, and shrinkage estimation (no per-gene Python loops)
- pydeseq2/gpu_utils.py: device auto-detection, GPU trimmed mean/variance
- tests/test_gpu_concordance.py: 16 tests validating GPU output against
  R DESeq2 reference data across single-factor, multi-factor, continuous,
  wide, and alternative hypothesis designs (2-4% tolerance)
- tests/test_gpu_specific.py: 10 tests for device placement, CPU-GPU
  precision, memory release, edge cases, and multi-factor fallback
- examples/benchmark_gpu.py: wall-clock performance benchmark suite
- examples/benchmark_concordance.py: CPU-GPU concordance verification
- PERFORMANCE.md: benchmark results and methodology

Modified files:
- pydeseq2/dds.py: add inference_type ("default"|"gpu") and device
  parameters with lazy TorchInference import; fix .values bug in
  fit_moments_dispersions call
- pydeseq2/ds.py: DeseqStats inherits inference engine from parent
  DeseqDataSet, ensuring GPU carries through to Wald tests
- pyproject.toml: add optional [gpu] dependency group (torch>=2.0.0)

Backward compatible: default behavior unchanged, PyTorch is optional.

- Fix ruff E402: move imports before warnings.filterwarnings() call
  in benchmark scripts
- Fix ruff B007: rename unused loop variable rep -> _rep
- Apply ruff format to all new files (line wrapping, whitespace)
- Add PR.md with detailed pull request description

torch is an optional dependency (only needed for inference_type="gpu").
Add mypy override to skip import-not-found errors for torch.* modules,
matching the pattern used by other projects with optional GPU deps.

Replace the weak 1%-LFC-only concordance test with a comprehensive
exact match suite that checks all 5 result columns (log2FoldChange,
stat, lfcSE, pvalue, padj) across multiple designs:

- test_single_factor_exact_match: standard synthetic dataset
- test_multifactor_exact_match: 3-group multi-factor design
- test_scaled_exact_match[20x100, 50x500, 20x1000]: scaling sweep

Each test enforces:
- Hard ceiling: no gene exceeds 2% relative error (4% multi-factor)
- Soft floor: at most 1 outlier gene exceeds 0.1% relative error
- NaN pattern must be identical between CPU and GPU
- Dispersion concordance checked via np.testing.assert_allclose

Sphinx cannot resolve the internal anndata._core.anndata.AnnData type
reference from DeseqDataSet's class inheritance. This is a pre-existing
issue unrelated to GPU changes. Adding it to nitpick_ignore alongside
the existing torch type suppressions.

- irls(): preserve the per-gene deviance-based convergence flag from
  the IRLS loop instead of overwriting it. Only NaN genes are marked
  non-converged; genes that converged normally keep their flag.
- get_device(): raise ValueError for MPS devices since all tensor ops
  require float64 which MPS does not support.
- Soften docstrings: "fully vectorized" -> "batched tensor operations"
  with explicit note that multi-factor non-convergence falls back to
  a serial CPU loop.
@maflot maflot requested a review from BorisMuzellec as a code owner April 20, 2026 13:30
grst (Collaborator) commented Apr 21, 2026

There's already a similar PR for Jax support: #441

I think torch could be an additional backend.

maflot (Author) commented Apr 21, 2026

Thanks for the pointer to #441; I hadn't seen the JAX PR when I started this. Torch was the natural choice for me since I'm more familiar with that ecosystem, and I wanted a working, validated implementation before opening a discussion.
Happy to coordinate with the JAX work if multiple backends are on the table.

grst (Collaborator) commented Apr 21, 2026

Are you on Zulip? @ilan-gold launched a thread there to continue the discussion from #441 about the right level of abstraction.

https://scverse.zulipchat.com/#narrow/channel/557094-differential-gene-expression/topic/Jax.20DESeq2

maflot (Author) commented Apr 21, 2026

Yup, I am! Thanks for pointing me to the discussion!
