Use qr_multiply to allow better reuse of factorisation#2244
Conversation
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | 5.57 % | 3.908e+03 | 4.125e+03 | 217.70 | 31.85 | 29.45 |
test_proximal_jac_w7x_with_eq_update | -0.18 % | 6.584e+03 | 6.573e+03 | -11.56 | 155.72 | 154.02 |
test_proximal_freeb_jac | -0.11 % | 1.338e+04 | 1.336e+04 | -14.71 | 80.93 | 79.03 |
test_proximal_freeb_jac_blocked | 0.11 % | 7.694e+03 | 7.703e+03 | 8.65 | 70.20 | 70.11 |
test_proximal_freeb_jac_batched | 0.17 % | 7.651e+03 | 7.664e+03 | 12.66 | 69.70 | 70.00 |
test_proximal_jac_ripple | 0.26 % | 3.574e+03 | 3.583e+03 | 9.44 | 53.97 | 54.76 |
test_proximal_jac_ripple_bounce1d | 1.17 % | 3.749e+03 | 3.793e+03 | 43.68 | 68.74 | 69.97 |
test_eq_solve | -6.37 % | 2.077e+03 | 1.944e+03 | -132.31 | 58.69 | 87.74 |For the memory plots, go to the summary of |
|
Sorry this will require a minimum python version bump to 3.11 for jax>=0.10.0 compatibility (unless we write an alternative qr_multiply), as EOL is Oct 2026 hopefully that's acceptable, can push the changes if you're comfortable with that. |
|
Hi @jpbrodrick89, thank you for taking the time to explain and implement it! I will try to locally benchmark this for a couple of different problem sizes on the GPU and CPU. I think this can solve the issue we faced in #1755. Differentiability is not currently a problem; we are not differentiating over optimizations yet. Although this new function is more efficient, I don't think we would like to bump our min Having said that, I think the improvement in |
|
Great feedback thanks @YigitElma , should be fully backwards compatible now but the full Q materialisation free speedup will only be unlocked when you upgrade to Jax 0.10.0. Let me know if anything else is needed. |
|
@jpbrodrick89 Thanks! This looks good to me. #2245 will solve the issue with CI. We generally require 2 approvals from active developers. I will discuss this with others during our dev-meeting tomorrow. |
|
Hi @jpbrodrick89, there was a minor bug with using the variable name I also reordered the changelog because we had a recent release. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2244 +/- ##
==========================================
- Coverage 94.33% 94.32% -0.01%
==========================================
Files 101 101
Lines 28855 28866 +11
==========================================
+ Hits 27220 27228 +8
- Misses 1635 1638 +3
🚀 New features to boost your workflow:
|
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | -2.83 +/- 4.01 | -2.45e-02 +/- 3.48e-02 | 8.44e-01 +/- 3.3e-02 | 8.68e-01 +/- 1.2e-02 |
test_equilibrium_init_medres | -0.64 +/- 3.61 | -4.42e-02 +/- 2.51e-01 | 6.91e+00 +/- 2.0e-01 | 6.95e+00 +/- 1.5e-01 |
test_equilibrium_init_highres | -0.48 +/- 5.78 | -3.77e-02 +/- 4.53e-01 | 7.80e+00 +/- 1.8e-01 | 7.84e+00 +/- 4.1e-01 |
test_objective_compile_dshape_current | +2.27 +/- 2.72 | +8.95e-02 +/- 1.07e-01 | 4.03e+00 +/- 7.4e-02 | 3.95e+00 +/- 7.8e-02 |
test_objective_compute_dshape_current | -11.58 +/- 5.45 | -7.89e-05 +/- 3.71e-05 | 6.02e-04 +/- 2.6e-05 | 6.81e-04 +/- 2.7e-05 |
test_objective_jac_dshape_current | +8.05 +/- 26.00 | +1.84e-03 +/- 5.95e-03 | 2.47e-02 +/- 4.9e-03 | 2.29e-02 +/- 3.3e-03 |
test_perturb_2 | +4.24 +/- 2.02 | +7.93e-01 +/- 3.78e-01 | 1.95e+01 +/- 3.6e-01 | 1.87e+01 +/- 1.3e-01 |
test_proximal_jac_atf_with_eq_update | +1.52 +/- 1.21 | +1.87e-01 +/- 1.49e-01 | 1.25e+01 +/- 1.4e-01 | 1.23e+01 +/- 5.9e-02 |
test_proximal_freeb_jac | +0.20 +/- 2.74 | +9.33e-03 +/- 1.27e-01 | 4.63e+00 +/- 1.1e-01 | 4.62e+00 +/- 7.0e-02 |
+test_solve_fixed_iter_compiled | -22.83 +/- 2.08 | -2.00e+00 +/- 1.82e-01 | 6.75e+00 +/- 1.4e-01 | 8.75e+00 +/- 1.2e-01 |
test_LinearConstraintProjection_build | +5.14 +/- 5.53 | +3.39e-01 +/- 3.65e-01 | 6.94e+00 +/- 2.9e-01 | 6.61e+00 +/- 2.3e-01 |
test_objective_compute_ripple_bounce1d | +4.01 +/- 4.91 | +1.11e-02 +/- 1.36e-02 | 2.87e-01 +/- 1.3e-02 | 2.76e-01 +/- 5.0e-03 |
test_objective_grad_ripple_bounce1d | +1.21 +/- 2.36 | +1.16e-02 +/- 2.28e-02 | 9.78e-01 +/- 8.9e-03 | 9.66e-01 +/- 2.1e-02 |
test_build_transform_fft_midres | +0.10 +/- 2.79 | +8.94e-04 +/- 2.45e-02 | 8.77e-01 +/- 2.1e-02 | 8.76e-01 +/- 1.3e-02 |
test_build_transform_fft_highres | -1.33 +/- 2.51 | -1.56e-02 +/- 2.96e-02 | 1.16e+00 +/- 1.8e-02 | 1.18e+00 +/- 2.4e-02 |
test_equilibrium_init_lowres | +0.15 +/- 3.30 | +9.37e-03 +/- 2.10e-01 | 6.37e+00 +/- 1.5e-01 | 6.36e+00 +/- 1.5e-01 |
test_objective_compile_atf | -0.40 +/- 3.74 | -2.47e-02 +/- 2.34e-01 | 6.23e+00 +/- 2.0e-01 | 6.25e+00 +/- 1.2e-01 |
test_objective_compute_atf | +5.93 +/- 16.37 | +1.25e-04 +/- 3.44e-04 | 2.23e-03 +/- 9.9e-05 | 2.10e-03 +/- 3.3e-04 |
test_objective_jac_atf | -1.94 +/- 3.65 | -3.16e-02 +/- 5.95e-02 | 1.60e+00 +/- 4.3e-02 | 1.63e+00 +/- 4.1e-02 |
test_perturb_1 | +2.19 +/- 2.22 | +3.48e-01 +/- 3.52e-01 | 1.62e+01 +/- 3.1e-01 | 1.59e+01 +/- 1.7e-01 |
test_proximal_jac_atf | -0.49 +/- 1.87 | -2.61e-02 +/- 9.95e-02 | 5.29e+00 +/- 7.1e-02 | 5.31e+00 +/- 6.9e-02 |
test_proximal_freeb_compute | +0.22 +/- 2.13 | +3.65e-04 +/- 3.52e-03 | 1.65e-01 +/- 2.2e-03 | 1.65e-01 +/- 2.7e-03 |
test_solve_fixed_iter | -5.79 +/- 2.75 | -1.50e+00 +/- 7.13e-01 | 2.45e+01 +/- 5.5e-01 | 2.60e+01 +/- 4.5e-01 |
test_objective_compute_ripple | -1.03 +/- 3.41 | -2.44e-03 +/- 8.02e-03 | 2.33e-01 +/- 5.0e-03 | 2.35e-01 +/- 6.3e-03 |
test_objective_grad_ripple | -0.52 +/- 3.31 | -4.64e-03 +/- 2.96e-02 | 8.91e-01 +/- 1.7e-02 | 8.96e-01 +/- 2.4e-02 |Github CI performance can be noisy. When evaluating the benchmarks, developers should take this into account. |
|
I just realised that I could use my pure jax implementation of ormqr in jax-ml/jax#36575 to provide a more efficient "fallback" (that even beats cusolver) as it doesn't depend on a rebuild of jaxlib so can probably run on any version of jax. Its not too complex and fairly self contained would you like to me to give it a whirl here or leave it as a follow up? |
I think the current PR is a good improvement over the existing one. For the sake of keeping this one simple, I would vote to have your pure JAX implementation in a new PR if it is okay for you. I am not very familiar with registering a primitive to JAX, but if I understand your PR correctly, instead of calling the cusolver ffi (which JAX does currently), you register your own fori_loop-based Householder implementation, am I correct? DESC allows very different-sized optimizations. So, I am not sure what the net gain would be for different cases. If you would like to open a PR, I would be happy to test it. |
|
Agreed, makes sense, this is enough for one PR. yes I use a blas 3 blocked python for loop which better utilises tensor cores rather than cusolver's apparently sequential. It works well on cpu but lapack is still better. I'll give it a whirl to see if theres a clean approach after this PR lands. Thanks! |
|
@jpbrodrick89 do you have any final changes in mind? If not, I think we can merge this. |
|
Hi @YigitElma, no this looks nice and clean and well tested thanks for doing a production run so quickly, looks like its doing exactly what we hoped with the big difference on the first step. I can try the pure jax ormqr as a follow-up of you're keen to try it. By the way, are any of your team at EPS in a couple weeks? It would be great to connect! |
Thanks! If you create a PR, I would be happy to test it.
Not that I am aware of. @dpanici @f0uriest @ddudt @rahulgaur104, are any of you joining? |




@YigitElma re your jax discussion this is not a perfectly optimal solution to your problem but should be a nice universal 2-4x improvement (potentially more if your Jacobians are very tall). There are two flavours of improvement and they stack:
qr(mode="economic") + Q.T @ rwithqr_multiplywhich avoids materialising Q, this provides ~1.5x speedup in each QR solve based on rough benchmarks.So all one needs to do is a QR of the 2n x n matrix formed by vertically stacking$R_0$ and $\sqrt{alpha}I$ . Alternatively, you could view this as Cholesky update and either scan over jax.lax.linalg.cholesky_update (stable but likely as slow as your homebrewed version) or form the Gram matrix $R_0^T R_0 + \alpha I$ which is probably the fastest option but as always will struggle with ill-conditioning. What you ideally want is a wrapper for something like Lapack's
dtpqrtwhich is specifically designed for factorising a stacked triangular plus trapezoidal matrix but I'm not sure if there's a GPU equivalent.Note I implemented the⚠️
qr_multiplyin jax-ml/jax#35104 so do shout if there are any issues.qr_multiplyis not currently differentiable but it will be when jax-ml/jax#36357 is eventually approved. If you are interested in running on GPU also note jax-ml/jax#36575 which can provide a significant performance improvement when usingqr_multiply(the default cusolverormqrloses againstorgqr+gemmfor smaller Jacobians, this pure jax optimisation is faster pretty much across the board).I did not adapt the wide cases (despite them also benefitting by a similar amount) because the API for the jax ormqr does not currently match the scipy one (because jax doesn't typically wrap
scipy.linalg.lapack). If you decide you want to write an adapter it shouldn't be too painful.