Use qr_multiply to allow better reuse of factorisation #2244

Merged: YigitElma merged 5 commits into PlasmaControl:master from jpbrodrick89:jpb/qr_multiply on Jun 19, 2026

@jpbrodrick89

Contributor

@YigitElma, re your JAX discussion: this is not a perfectly optimal solution to your problem, but it should be a nice universal 2-4x improvement (potentially more if your Jacobians are very tall). There are two flavours of improvement, and they stack:

  1. Replace qr(mode="economic") + Q.T @ r with qr_multiply, which avoids materialising Q. This gives a ~1.5x speedup in each QR solve based on rough benchmarks.
  2. Reuse the QR factorisation of the undamped Jacobian. This reduces the computational cost of each alpha-solve by a factor of (m+n)/(2n) (where the height m > n), at the cost of one extra initial solve.
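
To make point 1 concrete, here is a minimal NumPy sketch of the underlying idea: each Householder reflector is applied to the residual vector as soon as it is generated, so $Q^T f$ comes out alongside $R$ and $Q$ itself is never assembled. The function name `householder_qr_apply` and the problem sizes are illustrative assumptions. This is not the JAX or LAPACK code behind qr_multiply, only the same principle in unblocked form, and it assumes a tall Jacobian with full column rank.

```python
import numpy as np


def householder_qr_apply(A, f):
    """Return (R, Q^T f) for a tall matrix A without ever forming Q."""
    R = np.array(A, dtype=float)  # working copy, reduced to triangular form
    y = np.array(f, dtype=float)  # working copy, becomes Q^T f
    m, n = R.shape
    for k in range(n):
        x = R[k:, k]
        norm_x = np.linalg.norm(x)
        if norm_x == 0.0:
            continue  # column is already zero below the diagonal
        v = x.copy()
        v[0] += np.copysign(norm_x, x[0])  # sign choice avoids cancellation
        v /= np.linalg.norm(v)
        # Apply the reflector H = I - 2 v v^T to the trailing block and to y.
        R[k:, k:] -= 2.0 * np.outer(v, v @ R[k:, k:])
        y[k:] -= 2.0 * v * (v @ y[k:])
    # Only the first n entries of Q^T f are needed for the least-squares solve.
    return np.triu(R[:n, :]), y[:n]


rng = np.random.default_rng(0)
J = rng.standard_normal((200, 10))
f = rng.standard_normal(200)

R, Qt_f = householder_qr_apply(J, f)
dx = np.linalg.solve(R, -Qt_f)  # step minimising ||J dx + f||
```

The memory saving is the point: the economic route stores an m x n `Q` just to multiply it by one vector, whereas here only `R` (n x n) and a length-n vector survive.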

$$ \begin{pmatrix} J \\ \sqrt{\alpha} I \end{pmatrix} \Delta x = \begin{pmatrix} Q_0 \begin{pmatrix} R_0 \\ 0 \end{pmatrix} \\ \sqrt{\alpha} I \end{pmatrix} \Delta x $$

$$ \implies \begin{pmatrix} Q_0^{(\text{thin})T} & 0 \\ 0 & I \end{pmatrix} \begin{pmatrix} J \\ \sqrt{\alpha} I \end{pmatrix} = \begin{pmatrix} R_0 \\ \sqrt{\alpha} I \end{pmatrix} = Q_1 \begin{pmatrix} R_1 \\ 0 \end{pmatrix} $$

So all one needs to do is a QR of the $2n \times n$ matrix formed by vertically stacking $R_0$ and $\sqrt{\alpha} I$. Alternatively, you could view this as a Cholesky update and either scan over jax.lax.linalg.cholesky_update (stable, but likely as slow as your homebrewed version) or form the Gram matrix $R_0^T R_0 + \alpha I$, which is probably the fastest option but, as always, will struggle with ill-conditioning. What you ideally want is a wrapper for something like LAPACK's dtpqrt, which is specifically designed for factorising a stacked triangular plus trapezoidal matrix, but I'm not sure if there's a GPU equivalent.
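
The reuse can be sketched in NumPy as follows. The sketch assumes the damped step minimises $\|J \Delta x + f\|^2 + \alpha \|\Delta x\|^2$ (the sign convention is an assumption, not stated above), and the function names `damped_step_stacked_qr` and `damped_step_gram` are hypothetical. It contrasts the stacked $2n \times n$ QR with the Gram-matrix alternative; both only need $R_0$ and $Q_0^T f$ from the single undamped factorisation.

```python
import numpy as np


def damped_step_stacked_qr(R0, Qt_f, alpha):
    """Damped step from a QR of the 2n x n stack [R0; sqrt(alpha) I]."""
    n = R0.shape[1]
    stacked = np.vstack([R0, np.sqrt(alpha) * np.eye(n)])
    rhs = np.concatenate([-Qt_f, np.zeros(n)])
    Q1, R1 = np.linalg.qr(stacked)  # reduced QR: Q1 is 2n x n, R1 is n x n
    # R1 is triangular; a dedicated triangular solver would be used in practice.
    return np.linalg.solve(R1, Q1.T @ rhs)


def damped_step_gram(R0, Qt_f, alpha):
    """Same step via Cholesky of R0^T R0 + alpha I (squares the conditioning)."""
    n = R0.shape[1]
    L = np.linalg.cholesky(R0.T @ R0 + alpha * np.eye(n))  # lower triangular
    y = np.linalg.solve(L, -R0.T @ Qt_f)
    return np.linalg.solve(L.T, y)


rng = np.random.default_rng(0)
m, n = 300, 12
J = rng.standard_normal((m, n))
f = rng.standard_normal(m)

Q0, R0 = np.linalg.qr(J)  # paid once per Jacobian
Qt_f = Q0.T @ f

for alpha in (1e-3, 1e-1, 10.0):
    dx_qr = damped_step_stacked_qr(R0, Qt_f, alpha)
    dx_gram = damped_step_gram(R0, Qt_f, alpha)
    print(alpha, np.max(np.abs(dx_qr - dx_gram)))
```

On a well-conditioned random matrix the two routes agree closely; the difference the paragraph above warns about only shows up when $J$ is ill-conditioned and $\alpha$ is small.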

Note I implemented the qr_multiply in jax-ml/jax#35104 so do shout if there are any issues. ⚠️ qr_multiply is not currently differentiable but it will be when jax-ml/jax#36357 is eventually approved. If you are interested in running on GPU also note jax-ml/jax#36575 which can provide a significant performance improvement when using qr_multiply (the default cusolver ormqr loses against orgqr + gemm for smaller Jacobians, this pure jax optimisation is faster pretty much across the board).

I did not adapt the wide cases (despite them also benefitting by a similar amount) because the API for the jax ormqr does not currently match the scipy one (because jax doesn't typically wrap scipy.linalg.lapack). If you decide you want to write an adapter it shouldn't be too painful.

@github-actions

github-actions Bot commented Jun 17, 2026

Contributor

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |    5.57 %    |     3.908e+03      |     4.125e+03      |    217.70    |       31.85        |       29.45        |
| test_proximal_jac_w7x_with_eq_update   |   -0.18 %    |     6.584e+03      |     6.573e+03      |    -11.56    |       155.72       |       154.02       |
| test_proximal_freeb_jac                |   -0.11 %    |     1.338e+04      |     1.336e+04      |    -14.71    |       80.93        |       79.03        |
| test_proximal_freeb_jac_blocked        |    0.11 %    |     7.694e+03      |     7.703e+03      |     8.65     |       70.20        |       70.11        |
| test_proximal_freeb_jac_batched        |    0.17 %    |     7.651e+03      |     7.664e+03      |    12.66     |       69.70        |       70.00        |
| test_proximal_jac_ripple               |    0.26 %    |     3.574e+03      |     3.583e+03      |     9.44     |       53.97        |       54.76        |
| test_proximal_jac_ripple_bounce1d      |    1.17 %    |     3.749e+03      |     3.793e+03      |    43.68     |       68.74        |       69.97        |
| test_eq_solve                          |   -6.37 %    |     2.077e+03      |     1.944e+03      |   -132.31    |       58.69        |       87.74        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@jpbrodrick89

Contributor Author

Sorry, this will require bumping the minimum Python version to 3.11 for jax>=0.10.0 compatibility (unless we write an alternative qr_multiply). As Python 3.10 reaches EOL in Oct 2026, hopefully that's acceptable. I can push the changes if you're comfortable with that.

@YigitElma

Collaborator

Hi @jpbrodrick89, thank you for taking the time to explain and implement it! I will try to locally benchmark this for a couple of different problem sizes on the GPU and CPU. I think this can solve the issue we faced in #1755. Differentiability is not currently a problem; we are not differentiating over optimizations yet.

Although this new function is more efficient, I don't think we would like to bump our min jax version to 0.10.0 yet. We will first have to solve a couple of issues with JAX 0.10+ versions in #2235. After that, I think we can have a wrapper in desc.backend for this, such that if the user has a new version of jax, it uses this method; otherwise, it falls back to the older manual QR. We were planning to bump the Python version soon, so there shouldn't be a problem with that.
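
A wrapper of that kind could look roughly like the sketch below. It is only an illustration of the feature-detection pattern: SciPy's `scipy.linalg.qr_multiply` stands in for the newer JAX function (whose import path is not given in this thread), and `qr_and_qt_product` is a hypothetical name, not the function that ended up in desc.backend.

```python
import numpy as np

try:
    # Stand-in for "the installed backend provides a Q-free routine".
    from scipy.linalg import qr_multiply as _qr_multiply
except ImportError:
    _qr_multiply = None


def qr_and_qt_product(J, f):
    """Return (R, Q^T f) for a tall J, preferring the Q-free routine."""
    if _qr_multiply is not None:
        # mode="right" computes f @ Q, which equals Q^T f for a 1-D f.
        Qt_f, R = _qr_multiply(J, f, mode="right")
        return R, Qt_f
    # Fallback: materialise the thin Q explicitly, as the older code path does.
    Q, R = np.linalg.qr(J, mode="reduced")
    return R, Q.T @ f


rng = np.random.default_rng(0)
J = rng.standard_normal((150, 8))
f = rng.standard_normal(150)

R, Qt_f = qr_and_qt_product(J, f)
dx = np.linalg.solve(R, -Qt_f)  # identical step on either code path
```

Callers see the same `(R, Qt_f)` pair either way, so the optimizer code does not need to know which path was taken.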

Having said that, I think the improvement in trust_region_step_exact_qr besides qr_update can still give a nice performance improvement.

Comment thread desc/optimize/least_squares.py Outdated
Comment thread desc/backend.py Outdated
@jpbrodrick89

Contributor Author

Great feedback, thanks @YigitElma. This should be fully backwards compatible now, but the full speedup from avoiding the materialisation of Q will only be unlocked when you upgrade to JAX 0.10.0. Let me know if anything else is needed.

@YigitElma YigitElma requested review from a team, YigitElma, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team June 17, 2026 14:21
@YigitElma

YigitElma commented Jun 17, 2026

Collaborator

@jpbrodrick89 Thanks! This looks good to me. #2245 will solve the issue with CI.

We generally require 2 approvals from active developers. I will discuss this with others during our dev-meeting tomorrow.

@YigitElma

YigitElma commented Jun 18, 2026

Collaborator

Hi @jpbrodrick89, there was a minor bug caused by reusing the variable name z in the Augmented Lagrangian least-squares optimizer, where z was already used for the augmented state vector. I renamed the variable z -> Qt_f and applied the same naming convention in both optimizers.

I also reordered the changelog because we had a recent release.

@YigitElma YigitElma added the run_benchmarks Run timing benchmarks on this PR against current master branch label Jun 18, 2026
@codecov

codecov Bot commented Jun 18, 2026


Codecov Report

❌ Patch coverage is 79.31034% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.32%. Comparing base (c617960) to head (3354362).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| desc/optimize/aug_lagrangian_ls.py | 42.85% | 4 Missing ⚠️ |
| desc/backend.py | 75.00% | 2 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2244      +/-   ##
==========================================
- Coverage   94.33%   94.32%   -0.01%     
==========================================
  Files         101      101              
  Lines       28855    28866      +11     
==========================================
+ Hits        27220    27228       +8     
- Misses       1635     1638       +3

| Files with missing lines | Coverage Δ |
| --- | --- |
| desc/optimize/least_squares.py | 99.43% <100.00%> (+<0.01%) ⬆️ |
| desc/optimize/tr_subproblems.py | 99.43% <100.00%> (+<0.01%) ⬆️ |
| desc/backend.py | 89.16% <75.00%> (-0.59%) ⬇️ |
| desc/optimize/aug_lagrangian_ls.py | 95.06% <42.85%> (-0.81%) ⬇️ |

... and 2 files with indirect coverage changes


@github-actions

github-actions Bot commented Jun 18, 2026

Contributor
|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     -2.83 +/- 4.01     | -2.45e-02 +/- 3.48e-02 |  8.44e-01 +/- 3.3e-02  |  8.68e-01 +/- 1.2e-02  |
 test_equilibrium_init_medres            |     -0.64 +/- 3.61     | -4.42e-02 +/- 2.51e-01 |  6.91e+00 +/- 2.0e-01  |  6.95e+00 +/- 1.5e-01  |
 test_equilibrium_init_highres           |     -0.48 +/- 5.78     | -3.77e-02 +/- 4.53e-01 |  7.80e+00 +/- 1.8e-01  |  7.84e+00 +/- 4.1e-01  |
 test_objective_compile_dshape_current   |     +2.27 +/- 2.72     | +8.95e-02 +/- 1.07e-01 |  4.03e+00 +/- 7.4e-02  |  3.95e+00 +/- 7.8e-02  |
 test_objective_compute_dshape_current   |    -11.58 +/- 5.45     | -7.89e-05 +/- 3.71e-05 |  6.02e-04 +/- 2.6e-05  |  6.81e-04 +/- 2.7e-05  |
 test_objective_jac_dshape_current       |     +8.05 +/- 26.00    | +1.84e-03 +/- 5.95e-03 |  2.47e-02 +/- 4.9e-03  |  2.29e-02 +/- 3.3e-03  |
 test_perturb_2                          |     +4.24 +/- 2.02     | +7.93e-01 +/- 3.78e-01 |  1.95e+01 +/- 3.6e-01  |  1.87e+01 +/- 1.3e-01  |
 test_proximal_jac_atf_with_eq_update    |     +1.52 +/- 1.21     | +1.87e-01 +/- 1.49e-01 |  1.25e+01 +/- 1.4e-01  |  1.23e+01 +/- 5.9e-02  |
 test_proximal_freeb_jac                 |     +0.20 +/- 2.74     | +9.33e-03 +/- 1.27e-01 |  4.63e+00 +/- 1.1e-01  |  4.62e+00 +/- 7.0e-02  |
+test_solve_fixed_iter_compiled          |    -22.83 +/- 2.08     | -2.00e+00 +/- 1.82e-01 |  6.75e+00 +/- 1.4e-01  |  8.75e+00 +/- 1.2e-01  |
 test_LinearConstraintProjection_build   |     +5.14 +/- 5.53     | +3.39e-01 +/- 3.65e-01 |  6.94e+00 +/- 2.9e-01  |  6.61e+00 +/- 2.3e-01  |
 test_objective_compute_ripple_bounce1d  |     +4.01 +/- 4.91     | +1.11e-02 +/- 1.36e-02 |  2.87e-01 +/- 1.3e-02  |  2.76e-01 +/- 5.0e-03  |
 test_objective_grad_ripple_bounce1d     |     +1.21 +/- 2.36     | +1.16e-02 +/- 2.28e-02 |  9.78e-01 +/- 8.9e-03  |  9.66e-01 +/- 2.1e-02  |
 test_build_transform_fft_midres         |     +0.10 +/- 2.79     | +8.94e-04 +/- 2.45e-02 |  8.77e-01 +/- 2.1e-02  |  8.76e-01 +/- 1.3e-02  |
 test_build_transform_fft_highres        |     -1.33 +/- 2.51     | -1.56e-02 +/- 2.96e-02 |  1.16e+00 +/- 1.8e-02  |  1.18e+00 +/- 2.4e-02  |
 test_equilibrium_init_lowres            |     +0.15 +/- 3.30     | +9.37e-03 +/- 2.10e-01 |  6.37e+00 +/- 1.5e-01  |  6.36e+00 +/- 1.5e-01  |
 test_objective_compile_atf              |     -0.40 +/- 3.74     | -2.47e-02 +/- 2.34e-01 |  6.23e+00 +/- 2.0e-01  |  6.25e+00 +/- 1.2e-01  |
 test_objective_compute_atf              |     +5.93 +/- 16.37    | +1.25e-04 +/- 3.44e-04 |  2.23e-03 +/- 9.9e-05  |  2.10e-03 +/- 3.3e-04  |
 test_objective_jac_atf                  |     -1.94 +/- 3.65     | -3.16e-02 +/- 5.95e-02 |  1.60e+00 +/- 4.3e-02  |  1.63e+00 +/- 4.1e-02  |
 test_perturb_1                          |     +2.19 +/- 2.22     | +3.48e-01 +/- 3.52e-01 |  1.62e+01 +/- 3.1e-01  |  1.59e+01 +/- 1.7e-01  |
 test_proximal_jac_atf                   |     -0.49 +/- 1.87     | -2.61e-02 +/- 9.95e-02 |  5.29e+00 +/- 7.1e-02  |  5.31e+00 +/- 6.9e-02  |
 test_proximal_freeb_compute             |     +0.22 +/- 2.13     | +3.65e-04 +/- 3.52e-03 |  1.65e-01 +/- 2.2e-03  |  1.65e-01 +/- 2.7e-03  |
 test_solve_fixed_iter                   |     -5.79 +/- 2.75     | -1.50e+00 +/- 7.13e-01 |  2.45e+01 +/- 5.5e-01  |  2.60e+01 +/- 4.5e-01  |
 test_objective_compute_ripple           |     -1.03 +/- 3.41     | -2.44e-03 +/- 8.02e-03 |  2.33e-01 +/- 5.0e-03  |  2.35e-01 +/- 6.3e-03  |
 test_objective_grad_ripple              |     -0.52 +/- 3.31     | -4.64e-03 +/- 2.96e-02 |  8.91e-01 +/- 1.7e-02  |  8.96e-01 +/- 2.4e-02  |

Github CI performance can be noisy. When evaluating the benchmarks, developers should take this into account.

@YigitElma

Collaborator

I can run the benchmark on GPU later, but the CPU memory benchmark on CI is as follows:
[image: CPU memory benchmark plot with labelled peaks]
I tried to label which peak corresponds to what.

@jpbrodrick89

jpbrodrick89 commented Jun 18, 2026

Contributor Author

I just realised that I could use my pure JAX implementation of ormqr in jax-ml/jax#36575 to provide a more efficient "fallback" (that even beats cusolver), as it doesn't depend on a rebuild of jaxlib and so can probably run on any version of jax. It's not too complex and fairly self-contained. Would you like me to give it a whirl here, or leave it as a follow-up?

@YigitElma

Collaborator

> I just realised that I could use my pure JAX implementation of ormqr in jax-ml/jax#36575 to provide a more efficient "fallback" (that even beats cusolver), as it doesn't depend on a rebuild of jaxlib and so can probably run on any version of jax. It's not too complex and fairly self-contained. Would you like me to give it a whirl here, or leave it as a follow-up?

I think the current PR is a good improvement over the existing one. For the sake of keeping this one simple, I would vote to have your pure JAX implementation in a new PR if it is okay for you.

I am not very familiar with registering a primitive to JAX, but if I understand your PR correctly, instead of calling the cusolver ffi (which JAX does currently), you register your own fori_loop-based Householder implementation, am I correct?

DESC allows very different-sized optimizations. So, I am not sure what the net gain would be for different cases. If you would like to open a PR, I would be happy to test it.

@jpbrodrick89

Contributor Author

Agreed, makes sense, this is enough for one PR. Yes, I use a BLAS-3 blocked Python for loop, which utilises tensor cores better than cusolver's apparently sequential implementation. It works well on CPU, but LAPACK is still better there. I'll give it a whirl to see if there's a clean approach after this PR lands. Thanks!
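
For readers wondering what "BLAS-3" buys here: one standard way to turn a sequence of Householder reflections into matrix-matrix products is the compact WY form $Q = I - V T V^T$. The NumPy sketch below builds $V$ and $T$ for a single panel and applies $Q^T$ with three matrix products. It illustrates the general technique only; whether jax-ml/jax#36575 uses exactly this form is an assumption, and all function names are hypothetical. It also assumes a tall input with full column rank.

```python
import numpy as np


def householder_panel(A):
    """Factor a tall, full-rank A; return unit Householder vectors V and R."""
    R = np.array(A, dtype=float)
    m, n = R.shape
    V = np.zeros((m, n))
    for k in range(n):
        v = R[k:, k].copy()
        v[0] += np.copysign(np.linalg.norm(v), v[0])
        v /= np.linalg.norm(v)
        R[k:, k:] -= 2.0 * np.outer(v, v @ R[k:, k:])  # H_k = I - 2 v v^T
        V[k:, k] = v
    return V, np.triu(R[:n])


def compact_wy_T(V):
    """Upper-triangular T such that H_0 H_1 ... H_{n-1} = I - V T V^T."""
    n = V.shape[1]
    T = np.zeros((n, n))
    for j in range(n):
        # Appending H_j adds the column -tau_j T (V^T v_j), with tau_j = 2.
        T[:j, j] = -2.0 * T[:j, :j] @ (V[:, :j].T @ V[:, j])
        T[j, j] = 2.0
    return T


def apply_Qt(V, T, C):
    """Q^T C = C - V T^T (V^T C), expressed purely as matrix products."""
    return C - V @ (T.T @ (V.T @ C))


rng = np.random.default_rng(0)
m, n = 120, 6
J = rng.standard_normal((m, n))
C = rng.standard_normal((m, 3))  # several right-hand sides at once

V, R = householder_panel(J)
T = compact_wy_T(V)
Qt_C = apply_Qt(V, T, C)[:n]
X = np.linalg.solve(R, Qt_C)  # least-squares solutions of J X = C
```

A blocked implementation would repeat this panel by panel, so that the bulk of the work lands in the large matrix products inside `apply_Qt` rather than in vector-at-a-time reflector updates.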

@YigitElma YigitElma left a comment

Collaborator

Thanks!

@YigitElma YigitElma added the override codecov Override codecov label Jun 18, 2026
@YigitElma

YigitElma commented Jun 18, 2026

Collaborator

Here are some GPU benchmarks. I solved precise_QA at different LMN resolutions. I set jac_chunk_size to 500 to focus on the QR region. Sharp peaks are the Jacobian computation, and everything in the middle is QR and trust-region related. The JAX version is 0.7.2 (CUDA 12), running on my laptop RTX 4080 with 12 GB VRAM.

Code
```python
import os
import sys

# Arguments: sys.argv[1] = resolution N, sys.argv[2] = device ("gpu" selects
# the GPU, anything else leaves the default), sys.argv[3] = maxiter.
if sys.argv[2] in ["GPU", "gpu"]:
    # Allocate GPU memory on demand rather than preallocating a large pool
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    from desc import set_device

    set_device("gpu")

from desc.backend import print_backend_info
from desc.examples import get
from desc.objectives import ObjectiveFunction, ForceBalance

print_backend_info()

N = int(sys.argv[1])

# Load the example equilibrium and change its resolution
eq = get("precise_QA")
eq.change_resolution(L=N, M=N, L_grid=2 * N, M_grid=2 * N)
eq.resolution_summary()
eq.set_initial_guess()
obj = ObjectiveFunction(ForceBalance(eq), jac_chunk_size=500, deriv_mode="batched")
obj.build()
print(f"Objective function deriv mode: {obj._deriv_mode}")
print(f"Objective function chunk size: {obj._jac_chunk_size}")

# Zero tolerances so the solve always runs for exactly maxiter iterations
eq.solve(
    objective=obj,
    constraints=None,
    optimizer="lsq-exact",
    ftol=0,
    xtol=0,
    gtol=0,
    maxiter=int(sys.argv[3]),
    verbose=3,
    copy=False,
)
```

LM=16 N=8 (J is 19074 x 3336)

[image: GPU benchmark plot for LM=16 N=8]

LMN=8 (J is 5346 x 856)

This one runs too fast on the GPU to draw any conclusions.
[image: GPU benchmark plot for LMN=8]

LM=20 N=8 (J is 29106 x 5188)

[image: GPU benchmark plot for LM=20 N=8]
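
As a rough cross-check, the (m+n)/(2n) estimate from the PR description can be evaluated for the three Jacobian shapes quoted in this comment. The helper name `alpha_solve_cost_ratio` is hypothetical, and the ratio is only a row-count model of each alpha-solve (the $(m+n) \times n$ stack versus the $2n \times n$ stack); it says nothing about measured wall time.

```python
def alpha_solve_cost_ratio(m, n):
    """Row-count ratio of the (m + n) x n stack to the 2n x n stack."""
    return (m + n) / (2 * n)


# Jacobian shapes (m, n) as reported for the three resolutions above.
jacobian_shapes = {
    "LMN=8": (5346, 856),
    "LM=16 N=8": (19074, 3336),
    "LM=20 N=8": (29106, 5188),
}

ratios = {
    label: alpha_solve_cost_ratio(m, n) for label, (m, n) in jacobian_shapes.items()
}
for label, ratio in ratios.items():
    print(f"{label}: (m+n)/(2n) = {ratio:.2f}")
```

All three shapes have m between five and seven times n, so the model predicts a per-alpha-solve saving of a little over 3x in each case.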

@YigitElma

Collaborator

Btw, I cannot get reliable benchmarks for newer jax versions (I tested with 0.10.1). But I think we can check this in #2235 after solving #2246.

@YigitElma

Collaborator

@jpbrodrick89 do you have any final changes in mind? If not, I think we can merge this.

@jpbrodrick89

jpbrodrick89 commented Jun 19, 2026

Contributor Author

Hi @YigitElma, no, this looks nice, clean and well tested. Thanks for doing a production run so quickly; it looks like it's doing exactly what we hoped, with the big difference on the first step.

I can try the pure jax ormqr as a follow-up if you're keen to try it.

By the way, are any of your team at EPS in a couple weeks? It would be great to connect!

@YigitElma

Collaborator

> I can try the pure jax ormqr as a follow-up if you're keen to try it.

Thanks! If you create a PR, I would be happy to test it.

> By the way, are any of your team at EPS in a couple weeks? It would be great to connect!

Not that I am aware of. @dpanici @f0uriest @ddudt @rahulgaur104, are any of you joining?
If that doesn't work, we would be happy to have you in our dev-meeting sometime. I can share the details with you over email if you are interested.

@YigitElma YigitElma merged commit 4a6e718 into PlasmaControl:master Jun 19, 2026
32 of 33 checks passed

Labels

override codecov Override codecov run_benchmarks Run timing benchmarks on this PR against current master branch


3 participants