Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/source/usersguide/decay_sources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ relevant tallies. This can be done with the aid of the
with openmc.StatePoint(output_path) as sp:
dose_tally = sp.get_tally(name='dose tally')

# Apply time correction factors
tally = d1s.apply_time_correction(dose_tally, factors, time_index)
# Apply time correction factors at one or more time indices. A list of
# derived tallies is returned, one per index.
tally, = d1s.apply_time_correction(dose_tally, factors, [time_index])

120 changes: 68 additions & 52 deletions openmc/deplete/d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,33 +124,43 @@ def time_correction_factors(
def apply_time_correction(
tally: openmc.Tally,
time_correction_factors: dict[str, np.ndarray],
index: int = -1,
index: Sequence[int] = (-1,),
sum_nuclides: bool = True
) -> openmc.Tally:
) -> list[openmc.Tally]:
"""Apply time correction factors to a tally.

This function applies the time correction factors at the given index to a
tally that contains a :class:`~openmc.ParentNuclideFilter`. When
`sum_nuclides` is True, values over all parent nuclides will be summed,
leaving a single value for each filter combination.
This function applies the time correction factors at the given indices to a
tally that contains a :class:`~openmc.ParentNuclideFilter`, returning one
derived tally per index. When `sum_nuclides` is True, values over all parent
nuclides will be summed, leaving a single value for each filter combination.

.. versionchanged:: 0.16.0
`index` now takes a sequence of indices and the function returns a list
of tallies (one per index) rather than a single tally.

Parameters
----------
tally : openmc.Tally
Tally to apply the time correction factors to
time_correction_factors : dict
Time correction factors as returned by :func:`time_correction_factors`
index : int, optional
Index of the time of interest. If N timesteps are provided in
index : iterable of int, optional
Indices of the times of interest. If N timesteps are provided in
:func:`time_correction_factors`, there are N + 1 times to select from.
The default is -1 which corresponds to the final time.
The default is ``(-1,)`` which corresponds to the final time. The tally
arrays are read and reshaped once and shared across all indices, so
evaluating many times (e.g. for a D1S mesh tally) is much cheaper than
repeated single-index calls.
Comment thread
jon-proximafusion marked this conversation as resolved.
Outdated
sum_nuclides : bool
Whether to sum over the parent nuclides

Returns
-------
openmc.Tally
Derived tally with time correction factors applied
list of openmc.Tally
Derived tallies with time correction factors applied, one per entry in
`index` and in the same order. When `sum_nuclides` is True each result
is a derived tally, for which `sum` and `sum_sq` are None; the
meaningful results are `mean` and `std_dev`.

"""
# Make sure the tally contains a ParentNuclideFilter
Expand All @@ -160,57 +170,63 @@ def apply_time_correction(
else:
raise ValueError('Tally must contain a ParentNuclideFilter')

indices = list(index)

# Get list of radionuclides based on tally filter
radionuclides = [str(x) for x in tally.filters[i_filter].bins]
tcf = np.array([time_correction_factors[x][index] for x in radionuclides])

# Force tally results to be read and std_dev to be computed
# Force tally results to be read and std_dev to be computed (once)
tally.std_dev

# Create shallow copy of tally
new_tally = copy(tally)
new_tally._filters = copy(tally._filters)

# Determine number of bins in other filters
n_bins_before = prod([f.num_bins for f in tally.filters[:i_filter]])
n_bins_after = prod([f.num_bins for f in tally.filters[i_filter + 1:]])

# Reshape sum and sum_sq, apply TCF, and sum along that axis
_, n_nuclides, n_scores = new_tally.shape
_, n_nuclides, n_scores = tally.shape
n_radionuclides = len(radionuclides)
shape = (n_bins_before, n_radionuclides, n_bins_after, n_nuclides, n_scores)
tally_sum = new_tally.sum.reshape(shape)
tally_sum_sq = new_tally.sum_sq.reshape(shape)
tally_mean = new_tally.mean.reshape(shape)
tally_std_dev = new_tally.std_dev.reshape(shape)

# Apply TCF, broadcasting to the correct dimensions
tcf.shape = (1, -1, 1, 1, 1)
new_tally._mean = tally_mean * tcf
new_tally._std_dev = tally_std_dev * tcf

shape = (-1, n_nuclides, n_scores)

if sum_nuclides:
# Sum over parent nuclides (note that when combining different bins for
# parent nuclide, we can't work directly on sum_sq)
new_tally._sum = None
new_tally._sum_sq = None
new_tally._mean = new_tally.mean.sum(axis=1).reshape(shape)
new_tally._std_dev = np.linalg.norm(new_tally.std_dev, axis=1).reshape(shape)
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
else:
# Apply TCF and change shape back to (filter combinations, nuclides,
# scores)
new_tally._sum = (tally_sum * tcf).reshape(shape)
new_tally._sum_sq = (tally_sum_sq * (tcf*tcf)).reshape(shape)
new_tally._mean.shape = shape
new_tally._std_dev.shape = shape

return new_tally
flat_shape = (-1, n_nuclides, n_scores)

# Reshape the tally arrays once and reuse them for every index
tally_sum = tally.sum.reshape(shape)
tally_sum_sq = tally.sum_sq.reshape(shape)
tally_mean = tally.mean.reshape(shape)
tally_std_dev = tally.std_dev.reshape(shape)

results = []
for idx in indices:
tcf = np.array([time_correction_factors[x][idx] for x in radionuclides])

# Apply TCF, broadcasting to the correct dimensions
tcf.shape = (1, -1, 1, 1, 1)
mean = tally_mean * tcf
std_dev = tally_std_dev * tcf

# Create shallow copy of tally
new_tally = copy(tally)
new_tally._filters = copy(tally._filters)

if sum_nuclides:
# Sum over parent nuclides (note that when combining different bins
# for parent nuclide, we can't work directly on sum_sq)
new_tally._sum = None
new_tally._sum_sq = None
new_tally._mean = mean.sum(axis=1).reshape(flat_shape)
new_tally._std_dev = np.linalg.norm(std_dev, axis=1).reshape(flat_shape)
new_tally._derived = True

# Remove ParentNuclideFilter
new_tally.filters.pop(i_filter)
else:
# Apply TCF and change shape back to (filter combinations, nuclides,
# scores)
new_tally._sum = (tally_sum * tcf).reshape(flat_shape)
new_tally._sum_sq = (tally_sum_sq * (tcf*tcf)).reshape(flat_shape)
new_tally._mean = mean.reshape(flat_shape)
new_tally._std_dev = std_dev.reshape(flat_shape)

results.append(new_tally)

return results


def prepare_tallies(
Expand Down
82 changes: 79 additions & 3 deletions tests/unit_tests/test_d1s.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,14 @@ def test_apply_time_correction(run_in_tmpdir):
tally_mean = tally.mean.copy()
tally_std_dev = tally.std_dev.copy()

# Apply TCF and make sure results are consistent
result = d1s.apply_time_correction(tally, factors, sum_nuclides=False)
# Apply TCF and make sure results are consistent (a single index returns a
# one-element list)
result, = d1s.apply_time_correction(tally, factors, sum_nuclides=False)
tcf = np.array([factors[nuc][-1] for nuc in nuclides])
assert result.mean.flatten() == pytest.approx(tcf * flux)

# Make sure summed results match a manual sum
result_summed = d1s.apply_time_correction(tally, factors)
result_summed, = d1s.apply_time_correction(tally, factors)
assert result_summed.mean.flatten()[0] == pytest.approx(result.mean.sum())

# Make sure original tally is unchanged
Expand All @@ -154,3 +155,78 @@ def test_apply_time_correction(run_in_tmpdir):
# The summed tally is derived, so sum/sum_sq are None
assert result_summed.sum is None
assert result_summed.sum_sq is None


def test_apply_time_correction_multi_index(run_in_tmpdir):
# Build the same model used in test_apply_time_correction
mat = openmc.Material()
mat.add_element('Ni', 1.0)
sphere = openmc.Sphere(r=10.0, boundary_type='vacuum')
cell = openmc.Cell(fill=mat, region=-sphere)
model = openmc.Model()
model.geometry = openmc.Geometry([cell])
model.settings.run_mode = 'fixed source'
model.settings.batches = 3
model.settings.particles = 10
model.settings.photon_transport = True
model.settings.use_decay_photons = True
particle_filter = openmc.ParticleFilter('photon')
tally = openmc.Tally()
tally.filters = [particle_filter]
tally.scores = ['flux']
model.tallies = [tally]

# A schedule with several timesteps so we can ask for many indices
nuclides = d1s.prepare_tallies(model, chain_file=CHAIN_PATH)
timesteps = [1.0e8, 1.0e8, 1.0e8, 1.0e8]
source_rates = [1.0, 0.0, 1.0, 0.0]
factors = d1s.time_correction_factors(nuclides, timesteps, source_rates)
n_times = len(factors[nuclides[0]])

with openmc.config.patch('chain_file', CHAIN_PATH):
output_path = model.run()
with openmc.StatePoint(output_path) as sp:
tally = sp.tallies[tally.id]

orig_filters = list(tally.filters)
orig_sum = tally.sum.copy()
orig_sum_sq = tally.sum_sq.copy()
orig_mean = tally.mean.copy()
orig_std_dev = tally.std_dev.copy()

# A multi-index call returns one derived tally per index, each matching
# the corresponding single-index call.
for sum_nuc in (True, False):
many = d1s.apply_time_correction(
tally, factors, index=range(n_times), sum_nuclides=sum_nuc,
)
assert len(many) == n_times
for i, derived in enumerate(many):
ref, = d1s.apply_time_correction(
tally, factors, index=[i], sum_nuclides=sum_nuc
)
np.testing.assert_array_equal(derived.mean, ref.mean)
np.testing.assert_array_equal(derived.std_dev, ref.std_dev)
assert derived.filters == ref.filters
if sum_nuc:
# Summed tally is derived, so sum/sum_sq are None
assert derived.sum is None and derived.sum_sq is None
else:
np.testing.assert_array_equal(derived.sum, ref.sum)
np.testing.assert_array_equal(derived.sum_sq, ref.sum_sq)

# An unordered / partial index sequence is honored in order
subset = [n_times - 1, 0, 2]
many = d1s.apply_time_correction(tally, factors, index=subset)
assert len(many) == len(subset)
for derived, i in zip(many, subset):
ref, = d1s.apply_time_correction(tally, factors, index=[i])
np.testing.assert_array_equal(derived.mean, ref.mean)
np.testing.assert_array_equal(derived.std_dev, ref.std_dev)

# Original tally is unchanged
assert tally.filters == orig_filters
assert np.all(tally.sum == orig_sum)
assert np.all(tally.sum_sq == orig_sum_sq)
assert np.all(tally.mean == orig_mean)
assert np.all(tally.std_dev == orig_std_dev)