Skip to content
Open
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
8fb2b38
Add fft grid and raz grid to test against master
unalmis Apr 10, 2026
25164d0
remove noise by tighten tolerance
unalmis Apr 10, 2026
0d23c66
final attempt
unalmis Apr 10, 2026
5e70567
rory comment
unalmis Apr 10, 2026
945f1af
fix last commit
unalmis Apr 10, 2026
c2ecc4b
Increase correlation in discretization error for optimization
unalmis Apr 12, 2026
bb8ac6a
Merge branch 'master' into ku/test
unalmis Apr 12, 2026
7489317
.
unalmis Apr 12, 2026
e240249
increase tol for test
unalmis Apr 12, 2026
be41c58
remove not implemented todo
unalmis Apr 12, 2026
a55b170
.
unalmis Apr 12, 2026
2cff860
add back short-circuit
unalmis Apr 12, 2026
1727fba
collect redundant docs
unalmis Apr 13, 2026
339643b
Fix if statements
unalmis Apr 13, 2026
bda562a
Merge branch 'master' into ku/test
f0uriest Apr 13, 2026
ff53f80
Resolves #2162
unalmis Apr 14, 2026
83ffee6
loosen tol on test
unalmis Apr 14, 2026
ccf228f
flake8
unalmis Apr 14, 2026
f8a3515
flake8 blank line space
unalmis Apr 14, 2026
d05bda1
future proof
unalmis Apr 15, 2026
c06687d
daniel comments
unalmis Apr 15, 2026
d5a682f
fix render
unalmis Apr 16, 2026
f325bbf
Apply suggestions from code review
unalmis Apr 16, 2026
1792e91
Apply suggestions from code review
unalmis Apr 16, 2026
89479dc
Apply suggestions from code review
unalmis Apr 16, 2026
947641b
dan comment v2
unalmis Apr 16, 2026
13b6870
dan v2
unalmis Apr 16, 2026
a58c075
more dan
unalmis Apr 16, 2026
ad86912
last dan
unalmis Apr 16, 2026
f4faed4
last commit to desc
unalmis Apr 16, 2026
c42a92b
flake
unalmis Apr 16, 2026
2a5d6c2
Merge branch 'master' into ku/test
unalmis Apr 17, 2026
585d59a
Merge branch 'master' into ku/test
unalmis Apr 18, 2026
fcea971
Merge branch 'master' into ku/test
dpanici Apr 19, 2026
86f21f7
Resolves #2168
unalmis Apr 20, 2026
2c93334
remove comment
unalmis Apr 20, 2026
1b79b3e
.
unalmis Apr 21, 2026
1decd5e
clean up internal api
unalmis Apr 21, 2026
6df2ca9
clean
unalmis Apr 21, 2026
eef7938
use none
unalmis Apr 21, 2026
7fc978a
remove kwargs over closure conversion
unalmis Apr 21, 2026
8b33ca1
reduce duplicate code
unalmis Apr 22, 2026
5de9a9e
add missing todo
unalmis Apr 22, 2026
ae984ca
Remove bounce1d
unalmis Apr 22, 2026
ee71551
ad note
unalmis Apr 22, 2026
4da7446
.
unalmis Apr 22, 2026
f34abe2
missing exception
unalmis Apr 22, 2026
55155ed
missing label
unalmis Apr 22, 2026
b32bc40
Remove kwargs that are not needed anymore
unalmis Apr 22, 2026
1953444
clarify boolean
unalmis Apr 22, 2026
906b26a
.
unalmis Apr 22, 2026
f00647c
.
unalmis Apr 22, 2026
cd42371
.
unalmis Apr 22, 2026
088d5a2
clarify documentation
unalmis Apr 22, 2026
3fc45c2
fix closure conversion
unalmis Apr 22, 2026
98c9f1b
safer condition for compelx objs
unalmis Apr 22, 2026
566d464
fix pitch_batch_size subtlety
unalmis Apr 23, 2026
cdb9bf1
.
unalmis Apr 23, 2026
d8ec4c5
.
unalmis Apr 23, 2026
571c7d5
Merge branch 'master' into ku/test
unalmis Apr 23, 2026
c5fe484
.
unalmis Apr 23, 2026
7c721b8
Merge branch 'ku/test' into ku/sparse_pullback
unalmis Apr 23, 2026
adf73b5
fix comment
unalmis Apr 23, 2026
346938d
Switch resolution to per field period to simplify use and analysis (#…
unalmis Apr 25, 2026
e18dfe8
add missing default value
unalmis Apr 26, 2026
8748442
Resolves the fixme comment so that gradients are consistent (#2185)
unalmis Apr 26, 2026
d8868fe
push file into zip
unalmis Apr 27, 2026
8676dbd
Resolve remaining comments in #2147
unalmis Apr 28, 2026
ac79068
fix param
unalmis Apr 28, 2026
fdf80fe
Merge branch 'master' into ku/test
f0uriest Apr 29, 2026
3fe1120
Merge branch 'master' into ku/test
f0uriest Apr 29, 2026
bf95c79
Merge branch 'ku/test' into ku/sparse_pullback
unalmis Apr 29, 2026
18234f8
rory stuff
unalmis Apr 29, 2026
b1f5da7
rory stuff 2
unalmis Apr 29, 2026
06e22a6
.
unalmis Apr 29, 2026
f921c72
rory stuff 3
unalmis Apr 29, 2026
6af9f0f
reuse yb in comment to avoid confusion with nufft eps
unalmis Apr 29, 2026
2de27b0
Merge branch 'master' into ku/test
unalmis Apr 30, 2026
9a365e0
Merge branch 'ku/test' into ku/sparse_pullback
unalmis Apr 30, 2026
9b9d666
Merge branch 'master' into ku/test
unalmis May 1, 2026
1002c49
Merge branch 'ku/test' into ku/sparse_pullback
unalmis May 1, 2026
f729192
Merge branch 'master' into ku/sparse_pullback
unalmis May 4, 2026
b39ab31
Merge branch 'master' into ku/sparse_pullback
unalmis May 4, 2026
1fd2258
Merge branch 'master' into ku/sparse_pullback
unalmis May 11, 2026
81cb1a9
address rory
unalmis May 18, 2026
eaaaff7
Merge branch 'master' into ku/sparse_pullback
unalmis May 19, 2026
4955144
Merge branch 'master' into ku/sparse_pullback
unalmis May 19, 2026
930569d
@f0uriest
unalmis Jun 8, 2026
2808447
update files
unalmis Jun 16, 2026
95c0e1d
Merge remote-tracking branch 'upstream/master' into ku/sparse_pullback
unalmis Jun 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ New Features
- Sub-objectives of an `ObjectiveFunction` can now have different `use_jit` values than the `ObjectiveFunction`. These objectives have to be built before building the `ObjectiveFunction`.
- Adds ``num_neighbors`` parameter to ``CoilSetMinDistance`` that limits the pairwise distance computation to the nearest neighbors per coil, reducing memory useage for large coilsets.
- Method to plot frequency spectrum of inverse stream map in field line coordinates ``Bounce2D.plot_angle_spectrum``.
- Method to compute bounce integrals in batches is now added to the public API ``Bounce2D.batch``.
- Initiated deprecation of ``Bounce2D.compute_fieldline_length`` in favor of ``eq.compute("V_psi")``.
- The quadrature resolution in ``Bounce2D.compute_fieldline_length`` now corresponds to the resolution over a single field period instead of the resolution over a toroidal transit.
- Adds an optional attribute `ion_density` to the `Equilibrium` class, to allow the ion density profile to be set independently of the electron density and effective atomic number.
Expand All @@ -17,20 +16,29 @@ New Features
Bug Fixes

- Fixes SyntaxError thrown when loading hdf5 data from file-like objects.
- Fixes ``pitch_batch_size`` argument getting ignored in compute functions.
- Fixes a bug in `OmnigenousField.change_resolution` when changing `L_B`.
- Scaling a `ScaledProfile` or taking power of a `PowerProfile` now only updates the `scale`/`power` attributes instead of nesting the `ScaledProfile`/`PowerProfile`s.
- `jax.Array`s in `_static_attrs` will be automatically converted to `np.ndarray` to prevent stalling code. In general, jax arrays should be omitted in `_static_attrs`.

Performance Improvements

- Sparse reverse-mode differentiation was introduced to DESC
to yield significant performance improvements [#2170](https://github.com/PlasmaControl/DESC/pull/2170).
Plumbing to use this method was added to DESC that
will be progressively taken advantage of in the future.
- Reduces import time of `desc` modules.
- Now, `desc.compute._build_data_index` uses depth-first search algorithm to construct the dependency tree.
- Some of the default value computations at import time are removed (i.e. `desc.integrals.bounce_integral.default_quad`)
- [Significantly improves convergence of inverse stream maps](https://github.com/PlasmaControl/DESC/pull/1919).
- Check-pointing to bounce integrals to improve speed and reduce memory of reverse mode differentiation.
- Resolves a JAX memory regression in bounce integrals by avoiding materialization of a large tensor in memory. Previously, we had closed the issue by adding nuffts as a workaround. This update actually solves the issue for the case when a user specifies to not use nuffts as well.
- ``ObjectiveFunction.print_value`` can now use the previously computed ``compute_scaled_error`` values to print. For bounded objectives, we fall back to computing ``compute_unscaled``. Additionally, ``compute_scaled_error`` and array splitting are used in other parts of the code to prevent recompilation for one-time tasks, which makes initialization faster.

Breaking Changes

- The parameter ``num_transit`` in ``EffectiveRipple``, ``Gamma_c``, ``Bounce2D`` and related functions has been changed to ``num_field_periods``. This should make using a consistent resolution across different equilibria easier.
Comment thread
unalmis marked this conversation as resolved.
Outdated
- The parameter ``Y_B`` in ``EffectiveRipple``, ``Gamma_c``, ``Bounce2D`` is now the resolution over a single field period rather than a full toroidal transit. This should make using a consistent resolution across different equilibria easier.

Deprecations

- `constants` argument of `compute`, `jvp`, `jac`, `grad` and `hess` methods (including all of their variants) to all objective classes (including `ObjectiveFunction` and wrappers) is deprecated and will be removed in a future release. This argument was not necessary, and the code will still work if user doesn't pass it. Users should update their custom objectives for this change. In addition, `constants` property of the `ObjectiveFunction` and all sub-classes of `_Objective` is deprecated.
Expand Down
77 changes: 64 additions & 13 deletions desc/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,21 @@ def vmap_chunked(
- https://github.com/jax-ml/jax/issues/26689
- https://github.com/jax-ml/jax/issues/27591
- https://github.com/jax-ml/jax/issues/31919
- Due to an actively worked on issue in JAX,
https://docs.jax.dev/en/latest/jep/
Comment thread
unalmis marked this conversation as resolved.
2026-custom-derivatives.html#main-problem-descriptions,
this function can simply ignore custom derivative rules
of the function in wraps if ``chunk_size`` is not ``None``,
and therefore can damp the effeciency gains of ``sparse_pullback``.
Use ``batch_map`` instead to avoid this,
or try to make a hack with jax.custom_transforms to bypass this.
- Only out axes = 0 is supported.

See Also
--------
batch_map
If the function supports native vectorization, use ``batch_map`` instead
for the reasons discussed the docstring.

Parameters
----------
Expand Down Expand Up @@ -231,20 +246,45 @@ def vmap_chunked(


def batch_map(
fun, fun_input, /, batch_size=None, *, reduction=None, chunk_reduction=identity
fun,
fun_input,
/,
batch_size=None,
*,
reduction=None,
chunk_reduction=identity,
strip_dim0=False,
):
"""Compute ``chunk_reduction(fun(fun_input))`` in batches.

This utility is like ``vmap_chunked`` except that ``fun`` is assumed to be
vectorized natively. No JAX vectorization such as ``vmap`` is applied to the
supplied function. This makes compilation faster and avoids the weaknesses of
applying JAX vectorization, such as executing all branches of code conditioned on
dynamic values. For example, this function would be useful for GitHub issue #1303.
Notes
-----
This method does not automatically wrap ``fun`` with ``vmap``.
Unless ``fun`` is already wrapped with ``vmap``, the leading dimension
of ``fun_input`` will not be stripped before it is passed into ``fun``.
This can be inconvenient for nesting calls to ``batch_map``,
since only batching along the first axis is supported.
However, the ``strip_dim0`` flag should cover the most common case
of nesting calls where ``batch_size`` is one on the outermost call.

If ``fun`` is natively vectorized, this can be preferable to ``vmap_chunked``
to reduce compilation time, avoid issues such as executing all branches of
code conditioned on dynamic values, or avoid messing up the behavior of
jvp's and vjp's under vmap, e.g.
https://docs.jax.dev/en/latest/jep/
2026-custom-derivatives.html#main-problem-descriptions.

Only out axes = 0 is supported.

See Also
--------
vmap_chunked
If the function does not support native vectorization.

Parameters
----------
fun : callable
Natively vectorized function.
Vectorized function.
fun_input : pytree
Data to split into batches to feed to ``fun``.
batch_size : int or None
Expand All @@ -257,19 +297,30 @@ def batch_map(
Chunk-wise reduction operation.
Should typically apply ``reduction`` along the mapped axis,
e.g. ``jnp.add.reduce``.
strip_dim0 : bool
Whether to strip the leading dim of ``fun_input`` before passing it
to ``fun``; see notes. This flag only works if ``batch_size`` is one.
It should be set to ``False`` if ``fun`` is wrapped in ``vmap``.
Default is ``False``.

Returns
-------
fun_output
Returns ``chunk_reduction(fun(fun_input))``.

"""
return (
chunk_reduction(fun(fun_input))
if batch_size is None
else _evaluate_in_chunks(
fun, batch_size, (0,), reduction, chunk_reduction, fun_input
)
if batch_size is None:
return chunk_reduction(fun(fun_input))
if strip_dim0 and batch_size == 1:
return _scanmap(fun, 0, reduction, identity)(fun_input)

return _evaluate_in_chunks(
fun,
batch_size,
(0,),
reduction,
chunk_reduction,
fun_input,
)


Expand Down
Loading
Loading