Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
788d621
Use backward stable algorithm with correction for infinite condition …
unalmis Apr 30, 2026
2078288
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis Apr 30, 2026
892699a
Add back analytical formulae because much faster
unalmis Apr 30, 2026
158203f
Fix typo in comment for root multiplicity check
unalmis May 1, 2026
10c78ed
.
unalmis May 1, 2026
342856e
simplify formulae
unalmis May 1, 2026
26463f9
make sure types remain real
unalmis May 1, 2026
e3bd373
better documentation
unalmis May 1, 2026
bf939bb
.
unalmis May 1, 2026
c9a2469
.
unalmis May 1, 2026
61a41f9
.
unalmis May 1, 2026
28ec657
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 1, 2026
c9cbff1
better filtering
unalmis May 2, 2026
95f11fa
.
unalmis May 2, 2026
205b1b4
.
unalmis May 2, 2026
8d301ff
.
unalmis May 2, 2026
97bf192
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 4, 2026
74895a1
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 4, 2026
b2426ad
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 11, 2026
593a8a5
.
unalmis May 15, 2026
68b0f26
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 18, 2026
54a204d
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 19, 2026
7fa245e
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis May 19, 2026
5e17736
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis Jun 8, 2026
21704b1
Merge branch 'ku/sparse_pullback' into ku/condition_number
unalmis Jun 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions desc/integrals/_bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from desc.backend import dct, ifft, jax, jnp
from desc.integrals._interp_utils import (
_JF_BUG,
_eps,
_root_eps,
chebder,
nufft1d2r,
nufft2d2r,
Expand Down Expand Up @@ -370,10 +370,11 @@ def bounce_points_jvp(num_well, primals, tangents):
dB_dz += dB_dt * dt_dz
dB_do += dB_dt * dt_do

regularization = _root_eps()
dB_dz = jnp.where(
jnp.abs(dB_dz) > _eps,
jnp.abs(dB_dz) > regularization,
dB_dz,
dB_dz + jnp.copysign(_eps, dB_dz.real),
dB_dz + jnp.copysign(regularization, dB_dz.real),
)
dz = jnp.where(mask, (dp[..., None] - dB_do) / dB_dz, 0.0)

Expand Down Expand Up @@ -702,6 +703,7 @@ def get_mins(knots, B, num_mins=-1, fill_value=0.0):
a_min=jnp.array([0.0]),
a_max=jnp.diff(knots),
sentinel=0.0,
distinct=False,
)
b = flatten_mat((poly_val(x=mins, c=b[..., None, :], der=True) > 0) & (mins > 0))
mins = flatten_mat(
Expand Down
215 changes: 118 additions & 97 deletions desc/integrals/_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,15 @@
)


def _filter_distinct(r, sentinel, eps):
"""Set all but one of matching adjacent elements in ``r`` to ``sentinel``."""
# eps needs to be low enough that close distinct roots do not get removed.
# Otherwise, algorithms relying on continuity will fail.
mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps)
return jnp.where(mask, sentinel, r)


# ``jnp.roots`` applied to 1D coefficient vectors (core signature "(m)->(n)")
# and vectorized over any leading axes. ``strip_zeros=False`` is forwarded to
# ``jnp.roots``; see the JAX documentation for its effect on the output size.
_root_companion = jnp.vectorize(
partial(jnp.roots, strip_zeros=False), signature="(m)->(n)"
)
# Larger of the machine epsilon of the default float dtype and 2.5e-12.
# The value is fixed at the moment this statement executes.
_eps = max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12)


def _root_eps():
# Safer to make this a callable since output depends on whether
# double precision is enabled before it is called.
return max(jnp.finfo(jnp.array(1.0).dtype).eps, 1e-11)


@partial(jax.custom_jvp, nondiff_argnums=(4, 5, 6, 7))
Expand All @@ -279,7 +276,7 @@
a_max=None,
sort=False,
sentinel=jnp.nan,
eps=_eps,
eps=-1.0,
distinct=False,
):
"""Roots of polynomial with given coefficients.
Expand Down Expand Up @@ -320,44 +317,89 @@
The roots of the polynomial, iterated over the last axis.

"""
if eps < 0:
eps = _root_eps()
get_only_real_roots = not (a_min is None and a_max is None)
num_coef = c.shape[-1]
distinct = distinct and num_coef > 2
func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic}

if (
num_coef in func
and get_only_real_roots
and jnp.isrealobj(c)
and jnp.isrealobj(k)
):
# Compute from analytic formula to avoid the issue of complex roots with small
# imaginary parts. Also consumes less memory.
degree = c.shape[-1] - 1
distinct = distinct and (degree > 1)

if degree <= 3 and get_only_real_roots and jnp.isrealobj(c) and jnp.isrealobj(k):
backward_stable = degree < 3

c = jnp.moveaxis(c, -1, 0)
r = func[num_coef](*c[:-1], c[-1] - k, sentinel, eps, distinct)
if num_coef == 2:
r = r[jnp.newaxis]
r = {1: _root_linear, 2: _root_quadratic, 3: _root_cubic}[degree](
*c[:-1], c[-1] - k
)
r = jnp.moveaxis(r, 0, -1)

# We already filtered distinct roots for quadratics.
distinct = distinct and num_coef > 3
c = jnp.moveaxis(c, 0, -1)
else:
backward_stable = True

r = _root_companion(_subtract_last(c, k))
# If the imaginary part is too large, then these would not be real roots of
# a nearby perturbed problem, so we set them to nan so that they are not
# classified as candidates after the correction step.
if get_only_real_roots:
r = jnp.where(

Check warning on line 343 in desc/integrals/_interp_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/_interp_utils.py#L343

Added line #L343 was not covered by tests
jnp.abs(r.imag) <= eps**0.5,
r.real,
jnp.nan,
)

# Schröder first kind correction to push the roots of the perturbed problem
# toward roots of the original problem.
if degree > 1:
k = jnp.expand_dims(k, -1)
c = c[..., None, :]
p0 = poly_val(x=r, c=c) - k
p1 = poly_val(x=r, c=c, der=True)
p2 = poly_val(x=r, c=c[..., :-1] * jnp.arange(degree, 0, -1), der=True)
candidate = r - (p0 * p1) / (p1**2 - p0 * p2)
p0 = jnp.abs(p0) # residual
p0_new = jnp.abs(poly_val(x=candidate, c=c) - k)
r = jnp.where(p0_new < p0, candidate, r)
if not backward_stable:
r = jnp.where(
jnp.minimum(p0_new, p0) <= eps**0.5,
r,
jnp.nan,
)

if distinct:
# Then we need to ensure the returned roots have a consistent multiplicity,
# so we discard roots within a few machine precisions of an extremum. This is
# mathematically justified as they could just as well have gone undetected
# in the nearby perturbed problem. The practical reason is this: if we
# return only one root at a given extremum (e.g. because the root of the
# perturbed problem had multiplicity one) and return multiple roots at
# another extremum (e.g. because we found 2 or 3 in the nearby perturbed
# problem which were merged by the correction step), then downstream
# algorithms which assume continuity (intermediate value theorem) of the
# underlying function will be misled.
r = jnp.where(
jnp.abs(poly_val(x=r, c=c, der=True)) > eps,
r,
sentinel,
)

if get_only_real_roots:
a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis]
a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis]
r = jnp.where(
(jnp.abs(r.imag) <= eps) & (a_min <= r.real) & (r.real <= a_max),
r.real,
(a_min <= r) & (r <= a_max),
r,
sentinel,
)

if sort or distinct:
r = jnp.sort(r, axis=-1)
r = jnp.sort(r, stable=False)
if distinct:
r = _filter_distinct(r, sentinel, eps)
assert r.shape[-1] == num_coef - 1
r = jnp.where(
jnp.isclose(jnp.diff(r, prepend=sentinel), 0.0, atol=eps),
sentinel,
r,
)
assert r.shape[-1] == degree
return r


Expand All @@ -377,6 +419,8 @@
c, k, a_min, a_max = primals
dc, dk, _, _ = tangents

if eps < 0:
eps = _root_eps()
r = polyroot_vec(c, k, a_min, a_max, sort, sentinel, eps, distinct)

dc_dr = poly_val(x=r, c=c[..., None, :], der=True)
Expand All @@ -393,83 +437,60 @@
return r, dr


def _root_cubic(a, b, c, d, sentinel, eps, distinct):
"""Return real cubic root assuming real coefficients."""
# numerical.recipes/book.html, page 228

def irreducible(Q, R, b):
# Three irrational real roots.
theta = jnp.arccos(R / jnp.sqrt(Q**3))
return (
-2
* jnp.sqrt(Q)
* jnp.stack(
[
jnp.cos(theta / 3),
jnp.cos((theta + 2 * jnp.pi) / 3),
jnp.cos((theta - 2 * jnp.pi) / 3),
]
)
- b / 3
def _irreducible(Q, R, b):
# Three irrational real roots.
theta = jnp.arccos(R / jnp.sqrt(Q**3))
return (
-2
* jnp.sqrt(Q)
* jnp.stack(
[
jnp.cos(theta / 3),
jnp.cos((theta + 2 * jnp.pi) / 3),
jnp.cos((theta - 2 * jnp.pi) / 3),
]
)
- b / 3
)

def reducible(Q, R, b):
# One real and two complex roots.
A = -jnp.sign(R) * jnp.cbrt(jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3)))
B = Q / A
r1 = (A + B) - b / 3
return _concat_sentinel(r1[jnp.newaxis], sentinel, num=2)

def root(b, c, d):
b = b / a
c = c / a
Q = (b**2 - 3 * c) / 9
R = (2 * b**3 - 9 * b * c) / 54 + d / (2 * a)
return jnp.where(
R**2 < Q**3,
irreducible(jnp.abs(Q), R, b),
reducible(Q, R, b),
)

return jnp.where(
# Tests catch failure here if eps < 1e-12 for double precision.
jnp.abs(a) <= eps,
_concat_sentinel(
_root_quadratic(b, c, d, sentinel, eps, distinct),
sentinel,
),
root(b, c, d),
)
def _reducible(Q, R, b):
# One real and two complex roots.
A = -jnp.sign(R) * jnp.cbrt(jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3)))
B = Q / A
r = ((A + B) - b / 3)[None]
r = jnp.concatenate((r, jnp.broadcast_to(jnp.nan, (2, *r.shape[1:]))))
return r


def _root_cubic(a, b, c, d):
    """Return the real roots of ``a x³ + b x² + c x + d`` for real coefficients.

    Uses the closed form from numerical.recipes/book.html, page 228, which is
    not backward stable. It can produce a spurious root with O(1) residual,
    so the output needs post-processing. The advantage is that it is much
    faster than an eigenvalue solve, especially when ``d`` has more
    dimensions than ``a``, ``b``, ``c``.

    The coefficients are divided by ``a`` without any guard for ``a == 0``.

    Returns
    -------
    r : jnp.ndarray
        Roots along a leading axis of size 3. Where only one real root
        exists the remaining two entries are NaN.

    """
    # Normalize to a monic cubic.
    b_monic = b / a
    c_monic = c / a
    Q = (b_monic**2 - 3 * c_monic) / 9
    R = (2 * b_monic**3 - 9 * b_monic * c_monic) / 54 + d / (2 * a)
    has_three_real_roots = R**2 < Q**3
    return jnp.where(
        has_three_real_roots,
        _irreducible(jnp.abs(Q), R, b_monic),
        _reducible(Q, R, b_monic),
    )


def _root_quadratic(a, b, c):
"""Return real quadratic root assuming real coefficients."""
# numerical.recipes/book.html, page 227

discriminant = b**2 - 4 * a * c
q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant)))
r1 = jnp.where(
discriminant < 0,
sentinel,
jnp.where(a == 0, _root_linear(b, c, sentinel, eps), q / a),
)
r2 = jnp.where(
# more robust to remove repeated roots with discriminant
(discriminant < 0) | (distinct & (discriminant <= eps)),
sentinel,
c / q,
)
r1 = jnp.where(discriminant >= 0, q / a, jnp.nan)
r2 = jnp.where(discriminant >= 0, c / q, jnp.nan)
return jnp.stack([r1, r2])


def _root_linear(a, b, sentinel, eps, distinct=False):
def _root_linear(a, b, sentinel):
"""Return real linear root assuming real coefficients."""
return jnp.where((a == 0) & (jnp.abs(b) <= eps), 0.0, -b / a)


def _concat_sentinel(r, sentinel, num=1):
"""Concatenate ``sentinel`` ``num`` times to ``r`` on first axis."""
return jnp.concatenate((r, jnp.broadcast_to(sentinel, (num, *r.shape[1:]))))
return (-b / a)[None]

Check warning on line 493 in desc/integrals/_interp_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/_interp_utils.py#L493

Added line #L493 was not covered by tests


# TODO: replace the inner loop in orthax with this
Expand Down
16 changes: 6 additions & 10 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)
from desc.integrals._interp_utils import (
_JF_BUG,
_eps,
_root_eps,
interp1d_Hermite_vec,
interp1d_vec,
nufft2d2r,
Expand Down Expand Up @@ -106,11 +106,6 @@ def pitch_quad(min_B, max_B, num_pitch, **kwargs):

"""
if isinstance(num_pitch, int):
errorif(
num_pitch > 1e5,
msg="Floating point error impedes detection of bounce points "
f"near global extrema. Choose {num_pitch} < 1e5.",
)
simp = kwargs.get("simp", True)
num_pitch = simpson2(num_pitch) if simp else uniform(num_pitch)

Expand Down Expand Up @@ -674,7 +669,7 @@ def points(self, pitch_inv, num_well=None):
# size of 10⁻⁸ is reduced from 10³ to 10⁻⁶ for Γ_c
# (which has the singular weight 1/v_∥).
z1, z2 = self._B.intersect1d(
self._swap_axes(pitch_inv), num_intersect=num_well, eps=_eps
self._swap_axes(pitch_inv), num_intersect=num_well, eps=_root_eps()
)
z1 = move(z1)
z2 = move(z2)
Expand Down Expand Up @@ -983,7 +978,7 @@ def interp_to_argmin(self, f, points, *, nufft_eps=-1.0, **kwargs):
# such that all bounce points are at ζ >= 0; and therefore,
# junk values in B_mins cannot be selected in argmin.
mins, B_mins = (
self._B.extrema1d(1, num_mins, fill_value=0.0, eps=_eps)
self._B.extrema1d(1, num_mins, fill_value=0.0, eps=_root_eps())
if isinstance(self._B, PiecewiseChebyshevSeries)
else get_mins(self._c["knots"], self._B, num_mins, fill_value=0.0)
)
Expand Down Expand Up @@ -1111,7 +1106,7 @@ def plot(self, l, m, pitch_inv=None, **kwargs):
B = B[m]
B = PiecewiseChebyshevSeries(B, domain)
if pitch_inv is not None:
kwargs["z1"], kwargs["z2"] = B.intersect1d(pitch_inv, eps=_eps)
kwargs["z1"], kwargs["z2"] = B.intersect1d(pitch_inv, eps=_root_eps())
kwargs["k"] = pitch_inv
return B.plot1d(B.cheb, **kwargs)

Expand Down Expand Up @@ -1907,7 +1902,8 @@ def guess(

"""
errorif(
(surf_batch_size > 1) and (pitch_batch_size is not None),
(surf_batch_size is None or surf_batch_size > 1)
and (pitch_batch_size is not None),
msg=f"Expected pitch_batch_size to be None, got {pitch_batch_size}.",
)

Expand Down
Binary file modified tests/inputs/master_compute_data_rpz.pkl
Comment thread
unalmis marked this conversation as resolved.
Binary file not shown.
5 changes: 2 additions & 3 deletions tests/test_compute_everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,7 @@ def fft_grid_data(p):
fft_names = ["effective ripple", "Gamma_c", "Gamma_c Velasco"]

eq = get("W7-X")
# ci and my laptop differ a bunch at rho = 0, so skip that
rho = np.linspace(1e-2, 1, 10)
rho = np.linspace(0, 1, 10)
grid = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False)

nufft_eps = 1e-10
Expand Down Expand Up @@ -367,7 +366,7 @@ def raz_grid_data(p):
eq = get("W7-X")
num_transit = 2
Y_B = eq.N_grid * 2
rho = np.linspace(1e-2, 1, 10)
rho = np.linspace(0, 1, 10)
alpha = np.array([0])
zeta = np.linspace(0, num_transit * 2 * np.pi, num_transit * Y_B * eq.NFP)
grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz")
Expand Down
Loading