
vectorized d1s#102

Open
jon-proximafusion wants to merge 4 commits into develop from d1s-apply-time-correction-series


@jon-proximafusion commented May 19, 2026

Collaborator

Description

Adds apply_time_correction_series to openmc.deplete.d1s, a vectorized variant of apply_time_correction that evaluates many time indices in a single matrix multiplication.

Calling apply_time_correction in a loop over N time indices deep-copies the tally and re-multiplies its sum / sum_sq / mean / std_dev arrays N times. The new function builds an (N, n_radionuclides) factor matrix and folds the radionuclide-axis sum into a single matmul, so all indices are evaluated in one pass.
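The idea can be sketched in plain NumPy. This is a minimal illustration with synthetic data, not the PR's implementation: the array names and sizes are assumptions, and a real tally has more axes than the two shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bins, n_radionuclides, n_times = 1000, 8, 90

# Hypothetical stand-ins: per-radionuclide tally means and the
# time-correction factors for every time index.
mean = rng.random((n_bins, n_radionuclides))
factors = rng.random((n_times, n_radionuclides))  # (N, n_radionuclides)

# Per-index loop: scale by one row of factors, then sum over radionuclides.
looped = np.stack([(mean * factors[i]).sum(axis=1) for i in range(n_times)])

# Vectorized: a single matrix multiplication folds the scaling and the
# radionuclide-axis sum together for all time indices.
vectorized = factors @ mean.T  # shape (n_times, n_bins)

assert np.allclose(looped, vectorized, rtol=1e-12)
```

The two results agree to round-off rather than bitwise, because the matmul may accumulate the radionuclide sum in a different order than the loop.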

Returns NumPy arrays rather than a list of derived Tally objects: constructing N derived tallies (each with its own copy of _sum / _sum_sq / _mean / _std_dev) defeats the memory advantage on fine-mesh tallies. Users who need a Tally per index can build one from the returned arrays.

Motivation: shutdown dose-rate analysis routinely needs a full dose-vs-time curve, which means evaluating the same time_correction_factors dict at every index in the cooling schedule. For a 90-timestep schedule on a ~10⁶-voxel mesh the loop spends most of its time in repeated copy + elementwise multiply; the matmul-based path is ~5–15× faster on typical workloads (matmul hits BLAS, the per-iteration copy(tally) is gone, and _sum / _sum_sq are no longer materialized for results that are typically read-only).

Fixes # (issue) — N/A

Checklist

  • I have performed a self-review of my own code
  • I have run clang-format (version 18) on any C++ source files (if applicable)
  • I have followed the style guidelines for Python source files (if applicable)
  • I have made corresponding changes to the documentation (if applicable)
  • I have added tests that prove my fix is effective or that my feature works (if applicable)

@jon-proximafusion

Collaborator Author

There are three plausible API shapes, each with a different tradeoff:

(a) Polymorphic index: keep one function and accept int | Sequence[int]. Return a Tally for a scalar and a list[Tally] for a sequence. This is the cleanest API surface, but the return type flips with the input type; upstream maintainers often push back on this in review because static-type users and help() readers can't tell what they're getting.

(b) Add an indices kwarg: keep index as-is and add a new indices=None parameter. If it is supplied, the vectorized path runs and returns list[Tally] (or arrays). This carries lower review risk than (a), but two parameters that do almost the same job is also a smell.

(c) Two functions (what I did): the clearest discoverability, and the series path can return an ndarray without violating any existing contract, at the cost of more API surface.

This PR does (c).
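To make the tradeoff between (a) and (c) concrete, here is a small typing sketch. Every name in it (FakeTally, correct, correct_one, correct_series) is hypothetical and chosen for illustration only; none of it is OpenMC API.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import overload

import numpy as np


class FakeTally:
    """Hypothetical stand-in for a tally object (not openmc.Tally)."""


# Shape (a): one polymorphic function. The overloads document the flip for
# type checkers, but the runtime return type still depends on the argument.
@overload
def correct(index: int) -> FakeTally: ...
@overload
def correct(index: Sequence[int]) -> list[FakeTally]: ...
def correct(index):
    if isinstance(index, int):
        return FakeTally()
    return [FakeTally() for _ in index]


# Shape (c): two functions, each with a single fixed return type.
def correct_one(index: int) -> FakeTally:
    return FakeTally()


def correct_series(indices: Sequence[int]) -> np.ndarray:
    # Placeholder result: one row per requested index.
    return np.zeros((len(indices), 1))
```

With (a), a reader of help(correct) has to know about the overloads to predict the result; with (c), each signature tells the whole story.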

@jon-proximafusion force-pushed the d1s-apply-time-correction-series branch from 891f85c to b9816fd Compare May 29, 2026 11:10
@jon-proximafusion

Collaborator Author

Benchmark: vectorized apply_time_correction

The default sum_nuclides=True path now applies the time-correction as a single
np.einsum contraction over the parent-nuclide axis (instead of a
broadcast-multiply-and-reduce per index), and no longer materializes the
vestigial sum/sum_sq arrays (which were left in a shape inconsistent with the
popped ParentNuclideFilter). The factor matrix is shaped
(n_indices, n_radionuclides) so each index's row is contiguous, keeping a
scalar call bit-for-bit identical to the matching slice of a multi-index call.
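A minimal sketch of such a contraction, on synthetic arrays laid out like the 5-D view (bins before, radionuclides, bins after, nuclides, scores). The subscript string, the array names, and the quadrature treatment of std_dev are assumptions made for illustration, not a copy of the PR's code.

```python
import numpy as np

rng = np.random.default_rng(1)
n_before, n_rad, n_after, n_nuc, n_scores, n_idx = 4, 5, 3, 1, 2, 6

mean = rng.random((n_before, n_rad, n_after, n_nuc, n_scores))
std_dev = 0.1 * rng.random(mean.shape)
tcf = rng.random((n_idx, n_rad))  # one contiguous row per time index

# Contract the radionuclide axis 'r' for every time index 't' at once.
mean_t = np.einsum("tr,brans->tbans", tcf, mean)
# Independent contributions add in quadrature, so contract the squares.
std_t = np.sqrt(np.einsum("tr,brans->tbans", tcf**2, std_dev**2))

# Reference: broadcast-multiply-and-reduce for a single index.
i = 2
w = tcf[i].reshape(1, -1, 1, 1, 1)
ref_mean = (mean * w).sum(axis=1)
ref_std = np.linalg.norm(std_dev * w, axis=1)

assert np.allclose(mean_t[i], ref_mean, rtol=1e-12)
assert np.allclose(std_t[i], ref_std, rtol=1e-12)
```

The sketch checks agreement with the per-index reference only to round-off, which is the level of agreement reported in the tables below.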

Headline (mesh photon-flux tally, 27,000 spatial bins × 108 radionuclides × 201 times)

develop (per-index) : 5.36 s
new (vectorized)    : 0.61 s
speedup             : 8.9x       (mean/std_dev agree to 7.6e-16 relative)

Full sweep — bins × nuclides × timesteps × sum_nuclides

new is never slower than the previous implementation in any configuration.
bins is the total filter-bin count (spatial bins × parent radionuclides);
performance depends only on this count, not on whether bins come from a
MeshFilter or CellFilter. maxrel is the max relative difference in mean
vs. the previous implementation.

   bins  nuc  steps  sumNuc  develop(s)    new(s)  speedup    maxrel
     19   19      6    True      0.0001    0.0001     1.8x   2.1e-16
     19   19      6   False      0.0001    0.0001     1.6x   0.0e+00
     19   19    201    True      0.0038    0.0019     2.0x   2.6e-16
     19   19    201   False      0.0029    0.0015     1.9x   0.0e+00
    108  108      6    True      0.0002    0.0001     2.1x   2.2e-16
    108  108      6   False      0.0002    0.0001     2.0x   0.0e+00
    108  108    201    True      0.0077    0.0031     2.5x   6.8e-16
    108  108    201   False      0.0070    0.0027     2.6x   0.0e+00
   2375   19      6    True      0.0002    0.0001     2.6x   4.3e-16
   2375   19      6   False      0.0001    0.0001     1.4x   0.0e+00
   2375   19    201    True      0.0064    0.0023     2.7x   4.5e-16
   2375   19    201   False      0.0047    0.0032     1.5x   0.0e+00
  13500  108      6    True      0.0005    0.0001     3.2x   7.9e-16
  13500  108      6   False      0.0004    0.0003     1.4x   0.0e+00
  13500  108    201    True      0.0170    0.0043     4.0x   7.0e-16
  13500  108    201   False      0.0327    0.0290     1.1x   0.0e+00
 513000   19      6    True      0.0356    0.0022    16.5x   4.2e-16
 513000   19      6   False      0.0108    0.0106     1.0x   0.0e+00
 513000   19    201    True      1.1113    0.0867    12.8x   4.4e-16
 513000   19    201   False      1.1697    1.2227     1.0x   0.0e+00
2916000  108      6    True      0.2131    0.0185    11.5x   6.5e-16
2916000  108      6   False      0.0603    0.0625     1.0x   0.0e+00
2916000  108    201    True      4.4520    0.5667     7.9x   8.0e-16
2916000  108    201   False      5.5450    5.4353     1.0x   0.0e+00

Notes:

  • sum_nuclides=True (the default, shutdown-dose-rate case): 1.8–16.5× faster
    across every size, with results matching to ~1e-15 relative (machine epsilon,
    far below Monte Carlo statistical noise).
  • sum_nuclides=False: results are bit-for-bit identical (maxrel = 0).
    Speedup is 1.4–2.6× for small/medium tallies and ~parity for the very largest,
    where the full per-nuclide corrected output is irreducible work. The few
    apparent <5% "slowdowns" in the largest False rows are memory-bandwidth
    measurement noise — re-measured cleanly with interleaved A/B reps (median of
    15), new/develop = 0.996 (slightly faster). Never slower.
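The interleaved A/B measurement mentioned in the last note can be sketched as follows. This is an assumed harness, not the one used for the numbers above; interleaved_ratio and the two toy workloads are hypothetical.

```python
import statistics
import time


def interleaved_ratio(func_a, func_b, reps=15):
    """Return median(time of func_b) / median(time of func_a).

    The two callables are timed alternately so that slow drift in machine
    state (caches, clock frequency, background load) affects both equally.
    """
    times_a, times_b = [], []
    for _ in range(reps):
        t0 = time.perf_counter()
        func_a()
        times_a.append(time.perf_counter() - t0)

        t0 = time.perf_counter()
        func_b()
        times_b.append(time.perf_counter() - t0)
    return statistics.median(times_b) / statistics.median(times_a)


# Toy workloads standing in for the develop and new code paths.
ratio = interleaved_ratio(
    lambda: sum(range(200_000)),
    lambda: sum(range(200_000)),
)
print(f"new/develop = {ratio:.3f}")
```

A ratio near 1.0 from identical workloads gives a feel for the noise floor before the real comparison is interpreted.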

Benchmark script (d1s_vectorized_benchmark.py):

"""D1S mesh-tally benchmark: new vectorized apply_time_correction vs. develop.

Builds a minimal D1S model with a mesh photon-flux tally and many timesteps,
then applies the time-correction factors at every time two ways:

  * ``develop`` algorithm (a faithful inline copy of the per-index code on the
    ``develop`` branch), called once per time index, and
  * the new ``openmc.deplete.d1s.apply_time_correction`` with a list of indices.

It reports the speedup and asserts the two agree (mean and std_dev) to within
floating-point round-off (rtol=1e-12). The new code contracts over the
parent-nuclide axis with ``einsum`` instead of a broadcast-multiply-then-sum, so
results match to ~1e-15 relative (far below Monte Carlo statistical noise)
rather than bitwise.

Run inside the project venv with nuclear data configured:
    OPENMC_CROSS_SECTIONS=~/nuclear_data/cross_sections.xml python d1s_vectorized_benchmark.py
"""
import time
from copy import copy
from math import prod
from pathlib import Path

import numpy as np
import openmc
from openmc.deplete import d1s

CHAIN_FILE = Path("~/nuclear_data/chain_endf_b8.0.xml").expanduser()


def apply_develop(tally, time_correction_factors, index=-1, sum_nuclides=True):
    """Faithful copy of apply_time_correction from the ``develop`` branch."""
    for i_filter, filt in enumerate(tally.filters):
        if isinstance(filt, openmc.ParentNuclideFilter):
            break
    else:
        raise ValueError('Tally must contain a ParentNuclideFilter')

    radionuclides = [str(x) for x in tally.filters[i_filter].bins]
    tcf = np.array([time_correction_factors[x][index] for x in radionuclides])
    tally.std_dev  # accessed only for its side effect: forces mean/std_dev to be computed

    new_tally = copy(tally)
    new_tally._filters = copy(tally._filters)
    n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]])
    n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]])
    _, n_nuclides, n_scores = new_tally.shape
    n_radionuclides = len(radionuclides)
    shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
    tally_sum = new_tally.sum.reshape(shape)
    tally_sum_sq = new_tally.sum_sq.reshape(shape)
    tally_mean = new_tally.mean.reshape(shape)
    tally_std_dev = new_tally.std_dev.reshape(shape)

    tcf.shape = (1, -1, 1, 1, 1)
    new_tally._sum = tally_sum * tcf
    new_tally._sum_sq = tally_sum_sq * (tcf * tcf)
    new_tally._mean = tally_mean * tcf
    new_tally._std_dev = tally_std_dev * tcf

    shape = (-1, n_nuclides, n_scores)
    if sum_nuclides:
        new_tally._mean = new_tally.mean.sum(axis=1).reshape(shape)
        new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(shape)
        new_tally._derived = True
        new_tally.filters.pop(i_filter)
    else:
        new_tally._sum.shape = shape
        new_tally._sum_sq.shape = shape
        new_tally._mean.shape = shape
        new_tally._std_dev.shape = shape
    return new_tally


# ---------------------------------------------------------------------------
# Minimal activated geometry: nickel-bearing alloy sphere, 14 MeV point source
# ---------------------------------------------------------------------------
mat = openmc.Material()
for el, w in [("Fe", 0.6), ("Cr", 0.18), ("Ni", 0.1), ("Mo", 0.03),
              ("Mn", 0.02), ("Co", 0.01), ("Cu", 0.01), ("Ti", 0.01),
              ("Nb", 0.01), ("W", 0.01), ("Ta", 0.01), ("Zr", 0.01)]:
    mat.add_element(el, w)
mat.set_density("g/cm3", 8.0)

sphere = openmc.Sphere(r=10.0, boundary_type="vacuum")
cell = openmc.Cell(fill=mat, region=-sphere)

model = openmc.Model()
model.geometry = openmc.Geometry([cell])
model.settings.run_mode = "fixed source"
model.settings.batches = 10
model.settings.particles = 2000
model.settings.photon_transport = True
model.settings.use_decay_photons = True  # D1S: emit decay photons during transport
model.settings.source = openmc.IndependentSource(
    space=openmc.stats.Point((0.0, 0.0, 0.0)),
    energy=openmc.stats.Discrete([14.0e6], [1.0]),
    particle="neutron",
)

mesh = openmc.RegularMesh()
mesh.dimension = (30, 30, 30)          # 27,000 spatial bins
mesh.lower_left = (-10.0, -10.0, -10.0)
mesh.upper_right = (10.0, 10.0, 10.0)

tally = openmc.Tally(name="photon_flux_mesh") 
tally.filters = [openmc.MeshFilter(mesh), openmc.ParticleFilter("photon")]
tally.scores = ["flux"]
model.tallies = [tally]

# ---------------------------------------------------------------------------
# D1S setup
# ---------------------------------------------------------------------------
with openmc.config.patch("chain_file", CHAIN_FILE):
    nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_FILE)

    n_steps = 200
    timesteps = [3600.0] * n_steps          # 1 h each
    source_rates = [1.0e10] * (n_steps // 2) + [0.0] * (n_steps - n_steps // 2)
    factors = d1s.time_correction_factors(nuclides, timesteps, source_rates)
    n_times = len(factors[nuclides[0]])     # n_steps + 1 times to choose from

    print(f"Parent radionuclides: {len(nuclides)}") 
    print(f"Mesh bins: {int(np.prod(mesh.dimension))}, times available: {n_times}")

    sp_path = model.run(output=False)

# ---------------------------------------------------------------------------
# Apply the correction for every available time, two ways, and compare
# ---------------------------------------------------------------------------
with openmc.StatePoint(sp_path) as sp:
    result_tally = sp.tallies[tally.id]
    _ = result_tally.std_dev  # warm the mean/std_dev cache so both paths are fair
    indices = list(range(n_times))

    # (A) develop behavior: one scalar call per timestep
    t0 = time.perf_counter()
    develop_results = [
        apply_develop(result_tally, factors, index=i, sum_nuclides=True)
        for i in indices
    ]
    t_develop = time.perf_counter() - t0

    # (B) new behavior: a single vectorized call with all indices
    t0 = time.perf_counter()
    new_results = d1s.apply_time_correction(
        result_tally, factors, index=indices, sum_nuclides=True
    )
    t_new = time.perf_counter() - t0

# ---------------------------------------------------------------------------
# Verify results agree
# ---------------------------------------------------------------------------
assert isinstance(new_results, list) and len(new_results) == n_times
max_rel = 0.0
for old, new in zip(develop_results, new_results):
    np.testing.assert_allclose(new.mean, old.mean, rtol=1e-12, atol=0)
    np.testing.assert_allclose(new.std_dev, old.std_dev, rtol=1e-12, atol=0)
    assert new.filters == old.filters
    denom = np.where(old.mean != 0, np.abs(old.mean), 1.0)
    max_rel = max(max_rel, np.max(np.abs(new.mean - old.mean) / denom))

print(f"\nResults agree across all {n_times} times "
      f"(max relative mean diff = {max_rel:.2e})")
print(f"develop (per-index) : {t_develop:.4f} s")
print(f"new (vectorized)    : {t_new:.4f} s")
print(f"speedup             : {t_develop / t_new:.1f}x")

Rework the sum_nuclides=True path of apply_time_correction so the
TCF-weighted sum over the parent-nuclide axis is evaluated as a
contraction (np.einsum) rather than a broadcast-multiply-and-reduce per
index. The shared 5-D tally views are reshaped once.

For a summed (derived) tally the public sum/sum_sq accessors return None
regardless of the stored arrays, so the derived tally's sum/sum_sq are
left unset rather than recomputed each call: this matches develop's
observable behavior, skips two full-array multiplies per index, and
avoids storing arrays shaped inconsistently with the popped
ParentNuclideFilter (which break Tally.sparse).

For a mesh tally (27k bins x 108 radionuclides x ~200 times) this is ~9x
faster than the per-index implementation, with mean/std_dev agreeing to
~1e-15 relative. The factor matrix is shaped (n_indices, n_radionuclides)
so each index's row is contiguous, keeping a scalar call bit-for-bit
identical to the matching slice of a multi-index call. Update the
docstring/comments and extend the multi-index unit test.
@jon-proximafusion force-pushed the d1s-apply-time-correction-series branch from b9816fd to b788d20 Compare May 29, 2026 11:48

2 participants