Feature/gpu inference #445
Open
maflot wants to merge 8 commits into scverse:main from
Conversation
Implement a fully vectorized GPU inference backend that processes all
genes simultaneously using PyTorch tensor operations, achieving 4-24x
speedup over the CPU joblib baseline on NVIDIA B200 GPUs with perfect
result concordance (LFC Pearson r=1.0, Jaccard index=1.0).
New files:
- pydeseq2/torch_inference.py: TorchInference class implementing all 8
Inference ABC methods (IRLS, alpha MLE, Wald test, LFC shrinkage,
dispersion trend, rough/moments dispersions, linear regression mu)
- pydeseq2/torch_grid_search.py: vectorized grid search fallbacks for
dispersion, beta, and shrinkage estimation (no per-gene Python loops)
- pydeseq2/gpu_utils.py: device auto-detection, GPU trimmed mean/variance (see the sketch after this list)
- tests/test_gpu_concordance.py: 16 tests validating GPU output against
R DESeq2 reference data across single-factor, multi-factor, continuous,
wide, and alternative hypothesis designs (2-4% tolerance)
- tests/test_gpu_specific.py: 10 tests for device placement, CPU-GPU
precision, memory release, edge cases, and multi-factor fallback
- examples/benchmark_gpu.py: wall-clock performance benchmark suite
- examples/benchmark_concordance.py: CPU-GPU concordance verification
- PERFORMANCE.md: benchmark results and methodology
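The GPU trimmed mean mentioned for gpu_utils.py could look roughly like the following. This is a hypothetical sketch: the real helper's name, signature, and trim fraction in gpu_utils.py may differ.

```python
import torch

# Hypothetical sketch of a GPU trimmed mean; the actual helper in
# pydeseq2/gpu_utils.py may use a different name, signature, and trim fraction.
def trimmed_mean(x: torch.Tensor, trim: float = 0.125, dim: int = 0) -> torch.Tensor:
    """Mean along `dim` after dropping the lowest/highest `trim` fraction of values."""
    n = x.shape[dim]
    k = int(n * trim)                      # number of values trimmed from each end
    sorted_x, _ = torch.sort(x, dim=dim)   # sort once, keep everything on the device
    return sorted_x.narrow(dim, k, n - 2 * k).mean(dim=dim)
```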
Modified files:
- pydeseq2/dds.py: add inference_type ("default"|"gpu") and device parameters with lazy TorchInference import (see the sketch below); fix .values bug in fit_moments_dispersions call
- pydeseq2/ds.py: DeseqStats inherits inference engine from parent
DeseqDataSet, ensuring GPU carries through to Wald tests
- pyproject.toml: add optional [gpu] dependency group (torch>=2.0.0)
Backward compatible: default behavior unchanged, PyTorch is optional.
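For illustration, the inference_type dispatch with the lazy TorchInference import might look roughly like this. It is a hypothetical sketch, not the PR's actual dds.py code; only inference_type, device, and the TorchInference/DefaultInference class names come from the description above.

```python
# Hypothetical sketch, not the PR's actual dds.py code: choose the inference
# backend from the new inference_type/device parameters, importing torch lazily.
def _select_inference(inference_type="default", device=None, n_cpus=None):
    if inference_type == "gpu":
        # imported here so torch stays an optional dependency
        from pydeseq2.torch_inference import TorchInference
        return TorchInference(device=device)
    if inference_type == "default":
        from pydeseq2.default_inference import DefaultInference
        return DefaultInference(n_cpus=n_cpus)
    raise ValueError(f"Unknown inference_type: {inference_type!r}")
```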
- Fix ruff E402: move imports before warnings.filterwarnings() call in benchmark scripts
- Fix ruff B007: rename unused loop variable rep -> _rep
- Apply ruff format to all new files (line wrapping, whitespace)
- Add PR.md with detailed pull request description
torch is an optional dependency (only needed for inference_type="gpu"). Add mypy override to skip import-not-found errors for torch.* modules, matching the pattern used by other projects with optional GPU deps.
Replace the weak 1%-LFC-only concordance test with a comprehensive exact match suite that checks all 5 result columns (log2FoldChange, stat, lfcSE, pvalue, padj) across multiple designs:
- test_single_factor_exact_match: standard synthetic dataset
- test_multifactor_exact_match: 3-group multi-factor design
- test_scaled_exact_match[20x100, 50x500, 20x1000]: scaling sweep

Each test enforces:
- Hard ceiling: no gene exceeds 2% relative error (4% multi-factor)
- Soft floor: at most 1 outlier gene exceeds 0.1% relative error
- NaN pattern must be identical between CPU and GPU
- Dispersion concordance checked via np.testing.assert_allclose
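To illustrate the tolerance policy described above, a helper along these lines could implement the ceiling/floor/NaN checks. This is a hypothetical sketch; the actual tests may be organized differently.

```python
import numpy as np

# Hypothetical helper illustrating the hard-ceiling / soft-floor policy above;
# the actual test code may be organized differently.
def assert_column_concordant(cpu_vals, gpu_vals, hard_ceiling=0.02,
                             soft_floor=0.001, max_outliers=1):
    cpu_vals = np.asarray(cpu_vals, dtype=float)
    gpu_vals = np.asarray(gpu_vals, dtype=float)
    # NaN pattern must be identical between CPU and GPU
    assert np.array_equal(np.isnan(cpu_vals), np.isnan(gpu_vals))
    mask = ~np.isnan(cpu_vals)
    rel_err = np.abs(gpu_vals[mask] - cpu_vals[mask]) / np.maximum(
        np.abs(cpu_vals[mask]), 1e-12
    )
    # Hard ceiling: no gene may exceed the relative-error limit
    assert rel_err.max() <= hard_ceiling
    # Soft floor: at most `max_outliers` genes may exceed the tighter limit
    assert (rel_err > soft_floor).sum() <= max_outliers
```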
Sphinx cannot resolve the internal anndata._core.anndata.AnnData type reference from DeseqDataSet's class inheritance. This is a pre-existing issue unrelated to GPU changes. Adding it to nitpick_ignore alongside the existing torch type suppressions.
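The resulting entry in docs/source/conf.py would be along these lines; the torch entries shown are illustrative placeholders for the existing suppressions.

```python
# Hypothetical excerpt of docs/source/conf.py; the torch entries stand in for
# the existing suppressions, only the AnnData entry is the one added here.
nitpick_ignore = [
    ("py:class", "torch.Tensor"),
    ("py:class", "anndata._core.anndata.AnnData"),  # pre-existing resolution issue
]
```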
- irls(): preserve the per-gene deviance-based convergence flag from the IRLS loop instead of overwriting it. Only NaN genes are marked non-converged; genes that converged normally keep their flag.
- get_device(): raise ValueError for MPS devices since all tensor ops require float64, which MPS does not support.
- Soften docstrings: "fully vectorized" -> "batched tensor operations" with an explicit note that multi-factor non-convergence falls back to a serial CPU loop.
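A minimal sketch of the get_device() behavior described above (hypothetical code; the actual implementation in gpu_utils.py may differ):

```python
from typing import Optional

import torch

def get_device(device: Optional[str] = None) -> torch.device:
    """Resolve the target device, rejecting MPS (no float64 support)."""
    if device is None:
        # auto-detect: prefer CUDA, otherwise fall back to CPU
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if device.startswith("mps"):
        raise ValueError("MPS is not supported: all tensor ops require float64.")
    return torch.device(device)
```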
Collaborator
There's already a similar PR for Jax support: #441. I think torch could be an additional backend.
Author
Thanks for the pointer to #441; I hadn't seen the JAX PR when I started this. Torch was the natural choice for me since I'm more familiar with that ecosystem, and I wanted a working, validated implementation before opening a discussion.
Collaborator
Are you on Zulip? @ilan-gold launched a thread there to continue the discussion from #441 about the right level of abstraction. https://scverse.zulipchat.com/#narrow/channel/557094-differential-gene-expression/topic/Jax.20DESeq2
Author
Yup, I am! Thanks for pointing me to the discussion! |
Reference Issue or PRs
No existing issue. Opening this to gauge interest before anything else.
What does your PR implement? Be specific.
Hi! I've been using PyDESeq2 on some larger RNA-seq datasets and found myself waiting quite a bit on the CPU pipeline. I put together a GPU-accelerated inference backend using PyTorch that plugs into the existing Inference ABC. It helped me a lot with a recent project, and I wanted to see if there's interest in having this upstream. I realize this is a fairly large PR and I'm not sure what the best way to approach contributing something like this is; I'm happy to split it up, adjust the API, or rework things based on your feedback. I figured it's better to open the conversation with a working implementation than a vague feature request.
What it does
Adds TorchInference, a drop-in GPU backend that batches gene-level operations into tensor ops instead of per-gene joblib parallelization. Usage is a one-parameter change; a sketch is shown below. PyTorch is an optional dependency: the package works exactly as before without it installed.
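A minimal usage sketch: inference_type and device come from this PR, while the other constructor arguments follow the existing PyDESeq2 API and may differ slightly by version.

```python
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

# counts_df (samples x genes) and metadata_df are placeholder DataFrames.
dds = DeseqDataSet(
    counts=counts_df,
    metadata=metadata_df,
    design_factors="condition",
    inference_type="gpu",   # new in this PR; omit or use "default" for the CPU path
    device="cuda",          # optional; auto-detected when not given
)
dds.deseq2()

stats = DeseqStats(dds)     # inherits the GPU inference engine from the parent dds
stats.summary()
```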
What's actually GPU-vectorized and what isn't
Most of the pipeline runs as batched GPU tensor operations:
lin_reg_mu, irls, alpha_mle, wald_test, fit_rough_dispersions, fit_moments_dispersions, dispersion_trend_gamma_glm, and lfc_shrink_nbinom_glm all process every gene simultaneously on the GPU.

The exception is the multi-factor non-convergence fallback. The GPU grid search only supports 2-coefficient designs (intercept + one LFC). When IRLS produces NaN betas in a multi-factor design (n_coeffs > 2), those specific genes fall back to the CPU irls_solver in a serial per-gene loop. In practice this affects a small number of genes, but it's worth noting: it's not a full GPU path for every case.
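As a rough illustration of that fallback (hypothetical code: cpu_irls_solver stands in for PyDESeq2's per-gene IRLS routine, and the real names and signatures may differ):

```python
import numpy as np

# Hypothetical sketch of the fallback described above: genes whose GPU IRLS
# betas came back NaN are re-fit one by one with the CPU solver.
# `cpu_irls_solver` is a stand-in for pydeseq2's per-gene IRLS routine.
def refit_nonconverged(counts, design_matrix, size_factors, dispersions, beta,
                       cpu_irls_solver):
    nan_genes = np.where(np.isnan(beta).any(axis=1))[0]
    for g in nan_genes:  # serial per-gene CPU loop, only for the failed genes
        beta[g], *_ = cpu_irls_solver(
            counts[:, g], design_matrix, size_factors, dispersions[g]
        )
    return beta
```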
CUDA only — no MPS

All tensor operations use torch.float64 throughout to match the numerical precision of the CPU scipy path. This means MPS (Apple Silicon) is not supported, since MPS doesn't support float64. The code explicitly rejects device="mps" with a clear error. Supported devices are "cuda", "cuda:N", or "cpu" (for testing the torch code path without a GPU).

Performance
Benchmarked on an NVIDIA B200 against CPU DefaultInference (all cores).
The sweet spot is 1K–20K genes (12–24x speedup). Smaller datasets are dominated by kernel launch overhead, very large ones by memory bandwidth. Peak GPU memory was 1.83 GB for the largest configuration.
Concordance
GPU results have to match CPU: if they don't, the speedup is worthless. Across all dataset sizes tested, LFC Pearson r = 1.0 and Jaccard index = 1.0.
The test suite includes concordance tests against R DESeq2 reference data across single-factor, multi-factor, continuous, wide, and alternative hypothesis designs (tests/test_gpu_concordance.py), plus GPU-specific tests covering exact CPU-GPU matching, device placement, memory release, edge cases, and the multi-factor fallback (tests/test_gpu_specific.py).
What's changed
New files:
- pydeseq2/torch_inference.py: TorchInference class (all 8 Inference ABC methods)
- pydeseq2/torch_grid_search.py: GPU grid search fallbacks
- pydeseq2/gpu_utils.py: device detection (with MPS rejection), GPU trimmed mean/variance
- tests/test_gpu_concordance.py: 16 tests against R reference data
- tests/test_gpu_specific.py: 14 tests (exact match, device, memory, edge cases)
- examples/benchmark_gpu.py, examples/benchmark_concordance.py: benchmark scripts
- PERFORMANCE.md: benchmark results and methodology

Modified files (minimal):
- pydeseq2/dds.py: added inference_type and device parameters, lazy import, fixed .values bug in fit_moments_dispersions
- pydeseq2/ds.py: DeseqStats inherits inference engine from parent DeseqDataSet
- pyproject.toml: optional [gpu] dep group, mypy override for torch
- docs/source/conf.py: added pre-existing anndata.AnnData to nitpick_ignore

CI status: All checks pass (lint, format, mypy, 95 tests, docs build)
Design notes
- TorchInference implements the same Inference ABC; no changes to the abstract interface
- torch is lazily imported only when inference_type="gpu" is used
- The lfc_shrink_nbinom_glm Hessian intentionally matches a broadcasting quirk in the CPU implementation for concordance (documented in code)

Things I'm unsure about
- The inference_type parameter name and API surface: open to suggestions

Thanks for building PyDESeq2!