diff --git a/docs/docs/tutorials/fitting-bayesian.ipynb b/docs/docs/tutorials/fitting-bayesian.ipynb index 77e8b884..7f1019a9 100644 --- a/docs/docs/tutorials/fitting-bayesian.ipynb +++ b/docs/docs/tutorials/fitting-bayesian.ipynb @@ -262,8 +262,11 @@ "- `samples` (4000): the total number of posterior draws to generate across all chains;\n", "- `burn` (500): the number of initial *burn-in* iterations to discard — the sampler needs time to find the typical set of the posterior, and early samples are not representative;\n", "- `thin` (2): the *thinning* interval — only every second sample is kept, which reduces autocorrelation between consecutive draws;\n", + "- `population`: DREAM population count (number of parallel chains).\n", "\n", - "First, we switch to the BUMPS minimizer:\n" + "For expensive models, you can parallelise the population evaluation across multiple CPU processes by passing `n_workers`. This is covered in the next section.\n", + "\n", + "First, we switch to the BUMPS minimizer:" ] }, { @@ -298,6 +301,103 @@ "print('parameters:', result['param_names'])" ] }, + { + "cell_type": "markdown", + "id": "01c7cf9d", + "metadata": {}, + "source": [ + "## Faster sampling with multiprocessing\n", + "\n", + "Each DREAM generation evaluates the model for an entire *population* of\n", + "candidate parameter sets. By default these evaluations run sequentially in a\n", + "single process. 
For **expensive models** (where each evaluation takes tens of\n", + "milliseconds or more) you can speed up sampling by distributing the population\n", + "across multiple worker processes.\n", + "\n", + "Set `n_workers` to the number of parallel workers:\n", + "\n", + "```python\n", + "result = fitter.mcmc_sample(\n", + " ...,\n", + " n_workers=4, # evaluate 4 population members in parallel\n", + ")\n", + "```\n", + "\n", + "**How it works:** `mcmc_sample` serialises the BUMPS `FitProblem` (including the\n", + "model and all parameters) and spawns a `multiprocessing.Pool` using the\n", + "`spawn` start method. Each worker process independently evaluates a single\n", + "population member and returns the negative log-likelihood. The pool is\n", + "automatically closed after sampling finishes or if an error occurs.\n", + "\n", + "**When to use it:**\n", + "\n", + "| `n_workers` | Behaviour |\n", + "|---|---|\n", + "| `None` (default) | Sequential evaluation — no pool is created. |\n", + "| `1` | Also sequential, but explicitly requested. |\n", + "| `2` or more | Parallel evaluation using that many worker processes. |\n", + "\n", + "**Requirements:**\n", + "\n", + "- The `cloudpickle` package must be installed (it is a dependency of\n", + " `easyscience`).\n", + "- The BUMPS `FitProblem` and your model's fit function must be\n", + " *serialisable*. Models that capture non-picklable objects (e.g. open file\n", + " handles, bare `lambda` closures over module-level state) will raise\n", + " `FitError`. If this happens, fall back to `n_workers=None`.\n", + "- The `spawn` start method is always used, on every platform (Linux included).\n", + " Make sure the sampling call is inside an `if __name__ == '__main__':`\n", + " guard when running from a script.\n", + "\n", + "**Choosing `n_workers`:** as a rule of thumb, set `n_workers` to the number\n", + "of physical CPU cores. 
If your model is very cheap (< 1 ms per evaluation),\n", + "the overhead of serialisation and inter-process communication may outweigh\n", + "the parallelism gain — try it and compare. The\n", + "`tools/benchmarks/sampling_mpi.py` script in the repository provides a\n", + "ready-made benchmark for your own model.\n", + "\n", + "```{note}\n", + "`n_workers` is capped at the DREAM `population` value (`chains` or `population`,\n", + "10 when neither is given); the cap is not multiplied by the number of free\n", + "parameters, so requesting more workers than that starts no extra processes.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47c4fb9b", + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing\n", + "import time\n", + "\n", + "# For this simple model the overhead of multiprocessing is not worth it,\n", + "# but the pattern below shows exactly how to enable it for expensive models.\n", + "\n", + "n_cores = multiprocessing.cpu_count()\n", + "print(f'Detected {n_cores} CPU cores.')\n", + "\n", + "# Example (commented out because the simple Lorentzian model is too cheap\n", + "# to benefit from parallelism; uncomment and adjust n_workers for your own\n", + "# expensive model):\n", + "#\n", + "# t0 = time.perf_counter()\n", + "# result_parallel = mle_fitter.mcmc_sample(\n", + "# x=omega,\n", + "# y=intensity_obs,\n", + "# weights=1 / intensity_error,\n", + "# samples=10000,\n", + "# burn=500,\n", + "# thin=2,\n", + "# n_workers=min(4, n_cores), # use up to 4 workers\n", + "# )\n", + "# elapsed = time.perf_counter() - t0\n", + "# print(f'Parallel sampling took {elapsed:.1f} s')\n", + "# print(f'Drew {result_parallel[\"draws\"].shape[0]} samples')" + ] + }, { "cell_type": "markdown", "id": "8766b170", diff --git a/pixi.lock b/pixi.lock index 144dca8c..d6ab5106 100644 --- a/pixi.lock +++ b/pixi.lock @@ -6493,6 +6493,7 @@ packages: requires_dist: - asteval - bumps + - cloudpickle - dfo-ls - lmfit - numpy diff --git a/pyproject.toml b/pyproject.toml index 
e555f419..36ed3fb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,13 @@ classifiers = [ ] requires-python = '>=3.11' dependencies = [ - 'asteval', # Safely evaluate Python expressions from strings - 'lmfit', # Non-linear least squares fitting - 'bumps', # Bayesian uncertainty estimation - 'dfo-ls', # Derivative-free optimization - 'numpy', # Numerical computing - 'scipp', # Handling and analysis of scientific data + 'asteval', # Safely evaluate Python expressions from strings + 'lmfit', # Non-linear least squares fitting + 'bumps', # Bayesian uncertainty estimation + 'cloudpickle', # Serialize fitting closures for multiprocessing + 'dfo-ls', # Derivative-free optimization + 'numpy', # Numerical computing + 'scipp', # Handling and analysis of scientific data ] [project.optional-dependencies] @@ -211,7 +212,7 @@ select = [ # Ignore specific rules globally ignore = [ 'COM812', # https://docs.astral.sh/ruff/rules/missing-trailing-comma/ - # The following is replaced by 'D'/[tool.ruff.lint.pydocstyle] and [tool.pydoclint] + # The following is replaced by 'D' plus pydocstyle and pydoclint 'DOC', # https://docs.astral.sh/ruff/rules/#pydoclint-doc # Disable, as [tool.format_docstring] split one-line docstrings into the canonical multi-line layout 'D200', # https://docs.astral.sh/ruff/rules/unnecessary-multiline-docstring/ diff --git a/src/easyscience/fitting/fitter.py b/src/easyscience/fitting/fitter.py index ef6db4ab..f52f117d 100644 --- a/src/easyscience/fitting/fitter.py +++ b/src/easyscience/fitting/fitter.py @@ -421,9 +421,12 @@ def mcmc_sample( samples: int = 10000, burn: int = 2000, thin: int = 10, + chains: Optional[int] = None, population: Optional[int] = None, + seed: Optional[int] = None, vectorized: bool = False, sampler_kwargs: Optional[dict] = None, + n_workers: Optional[int] = None, progress_callback: Optional[Callable[[dict], Optional[bool]]] = None, abort_test: Optional[Callable[[], bool]] = None, ) -> dict: @@ -450,14 +453,26 @@ def 
mcmc_sample( thin : int, default=10 Thinning interval — only every ``thin``-th sample is kept, which reduces autocorrelation between consecutive draws. + chains : Optional[int], default=None + User-friendly alias for the BUMPS DREAM population count (number + of parallel chains). Mutually exclusive with ``population``. population : Optional[int], default=None - BUMPS DREAM population count (number of parallel chains). + BUMPS DREAM population count for advanced users. + seed : Optional[int], default=None + Best-effort random seed. BUMPS DREAM may use additional internal + RNG state that is not controlled by this seed, so exact + reproducibility is not guaranteed. vectorized : bool, default=False When ``True``, each x array may be multi-dimensional (e.g. an ``(N, M, 2)`` grid for a 2D model) and is left as-is. When ``False`` (default), each x array is expected to be 1-D. sampler_kwargs : Optional[dict], default=None Additional keyword arguments forwarded to the BUMPS DREAM sampler. + n_workers : Optional[int], default=None + Number of worker processes used to evaluate the DREAM population. + Values of ``None`` and ``1`` use BUMPS' sequential mapper. Values + greater than ``1`` require the BUMPS problem and fit function to be + pickleable. progress_callback : Optional[Callable[[dict], Optional[bool]]], default=None Optional callback invoked at each DREAM generation. The payload dict includes ``iteration`` and ``sampling: True``. @@ -473,7 +488,8 @@ def mcmc_sample( Raises ------ ValueError - If ``samples``, ``burn``, or ``thin`` are invalid. + If ``samples``, ``burn``, or ``thin`` are invalid, or if both + ``chains`` and ``population`` are provided with different values. RuntimeError If the active minimizer is not a BUMPS instance. 
""" @@ -505,8 +521,11 @@ def mcmc_sample( samples=samples, burn=burn, thin=thin, + chains=chains, population=population, + seed=seed, sampler_kwargs=sampler_kwargs, + n_workers=n_workers, progress_callback=progress_callback, abort_test=abort_test, ) diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index bd0a9b9a..6e49d055 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -4,7 +4,11 @@ from __future__ import annotations import copy +import io +import multiprocessing as mp +import pickle import warnings +import weakref from typing import Any from typing import Callable @@ -33,6 +37,207 @@ FIT_AVAILABLE_IDS_FILTERED.remove('pt') +_WORKER_PROBLEM = None + +_SCIPP_VARIABLE_KEY = '__easyscience_scipp_variable__' + + +def _serialize_worker_value(value: Any) -> Any: + try: + import scipp as sc + except ImportError: + sc = None + + if sc is not None and isinstance(value, sc.Variable): + return { + _SCIPP_VARIABLE_KEY: True, + 'value': value.value, + 'variance': value.variance, + 'unit': str(value.unit), + } + if isinstance(value, (weakref.ReferenceType, weakref.KeyedRef)): + return None + if isinstance(value, dict): + return {key: _serialize_worker_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_serialize_worker_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_serialize_worker_value(item) for item in value) + if isinstance(value, set): + return {_serialize_worker_value(item) for item in value} + return value + + +def _deserialize_worker_value(value: Any) -> Any: + if isinstance(value, dict): + if value.get(_SCIPP_VARIABLE_KEY): + import scipp as sc + + return sc.scalar( + value['value'], + unit=value['unit'], + variance=value['variance'], + ) + return {key: _deserialize_worker_value(item) for key, item in value.items()} + if isinstance(value, list): + return 
[_deserialize_worker_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_deserialize_worker_value(item) for item in value) + if isinstance(value, set): + return {_deserialize_worker_value(item) for item in value} + return value + + +def _collect_object_state(obj: object) -> dict: + state = {} + if hasattr(obj, '__dict__'): + state['__dict__'] = { + key: _serialize_worker_value(value) + for key, value in obj.__dict__.items() + if key != '__old_class__' + } + + slots = {} + for cls in type(obj).mro(): + cls_slots = getattr(cls, '__slots__', ()) + if isinstance(cls_slots, str): + cls_slots = (cls_slots,) + for slot in cls_slots: + if slot in ('__dict__', '__weakref__', '_global_object'): + continue + if hasattr(obj, slot): + slots[slot] = _serialize_worker_value(getattr(obj, slot)) + state['__slots__'] = slots + return state + + +def _restore_object_state(cls: type, state: dict) -> object: + obj = cls.__new__(cls) + if hasattr(obj, '__dict__'): + obj.__dict__.update(_deserialize_worker_value(state.get('__dict__', {}))) + for slot, value in _deserialize_worker_value(state.get('__slots__', {})).items(): + object.__setattr__(obj, slot, value) + + for key, value in getattr(obj, '_kwargs', {}).items(): + object.__setattr__(obj, key, value) + + from easyscience import global_object + + if hasattr(obj, '_global_object') or any( + '_global_object' in getattr(base, '__slots__', ()) for base in cls.mro() + ): + object.__setattr__(obj, '_global_object', global_object) + return obj + + +def _reduce_object_state(obj: object) -> tuple: + cls = getattr(obj, '__old_class__', obj.__class__) + return _restore_object_state, (cls, _collect_object_state(obj)) + + +def _restore_none() -> None: + return None + + +def _reduce_weakref(obj: weakref.ReferenceType) -> tuple: + return _restore_none, () + + +def _problem_pickler_class(): + """Build a Pickler subclass that handles BUMPS problem reduction locally. 
+ + Uses ``reducer_override`` (instance-scoped) instead of mutating + ``__reduce__`` on shared classes or ``copyreg.dispatch_table`` — those + globals would race with any concurrent pickle on another thread. + """ + from cloudpickle import CloudPickler + + from easyscience.base_classes.based_base import BasedBase + from easyscience.variable.descriptor_base import DescriptorBase + + _parent_reducer = CloudPickler.reducer_override + + class _BumpsProblemPickler(CloudPickler): + def reducer_override(self, obj): + if isinstance(obj, (weakref.ReferenceType, weakref.KeyedRef)): + return _restore_none, () + if isinstance(obj, (BasedBase, DescriptorBase)): + return _reduce_object_state(obj) + return _parent_reducer(self, obj) + + return _BumpsProblemPickler + + +def _init_bumps_worker(problem_bytes: bytes) -> None: + global _WORKER_PROBLEM + _WORKER_PROBLEM = pickle.loads(problem_bytes) + + from easyscience import global_object + + global_object.stack.enabled = False + + +def _evaluate_bumps_point(point: np.ndarray) -> float: + if _WORKER_PROBLEM is None: + raise RuntimeError('BUMPS worker problem was not initialized') + return float(_WORKER_PROBLEM.nllf(point)) + + +class BumpsPoolMapper: + """Multiprocessing mapper for BUMPS DREAM population evaluation.""" + + def __init__(self, problem: FitProblem, n_workers: int): + self._pool = None + self.n_workers = n_workers + try: + pickler_cls = _problem_pickler_class() + buffer = io.BytesIO() + pickler_cls(buffer).dump(problem) + problem_bytes = buffer.getvalue() + except Exception as exc: + raise FitError( + 'BUMPS multiprocessing requires the FitProblem and fit function to be ' + 'serializable. Use n_workers=1 for sequential sampling.' 
+ ) from exc + + context = mp.get_context('spawn') + self._pool = context.Pool( + processes=n_workers, + initializer=_init_bumps_worker, + initargs=(problem_bytes,), + ) + + def __call__(self, population: np.ndarray) -> list[float]: + # BUMPS may pass either a single point (1D) or a population (2D). + # Always reshape to 2D so list() produces one element per chain member. + pop = np.atleast_2d(np.asarray(population)) + n_points = pop.shape[0] + # Distribute the population across workers in as few tasks as possible. + # DREAM evaluations are individually cheap, so per-task IPC overhead + # (pickling + queue round-trip) dominates when chunksize=1. Sending one + # chunk per worker amortizes that overhead across the whole generation. + chunksize = max(1, (n_points + self.n_workers - 1) // self.n_workers) + results = self._pool.map(_evaluate_bumps_point, list(pop), chunksize=chunksize) + + # Safety check: BUMPS DREAM state corruption can occur if the + # mapper returns a different number of values than expected. + if len(results) != n_points: + raise RuntimeError( + f'Mapper returned {len(results)} results for {n_points} population points' + ) + return results + + def close(self) -> None: + self.terminate() + + def terminate(self) -> None: + if self._pool is None: + return + self._pool.terminate() + self._pool.join() + self._pool = None + + class Bumps(MinimizerBase): """ This is a wrapper to Bumps: https://bumps.readthedocs.io/ It allows @@ -263,6 +468,43 @@ def _resolve_fitclass(method: str): return fitclass raise FitError(f'Unknown BUMPS fitting method: {method}') + @staticmethod + def _resolve_population_alias(chains: int | None, population: int | None) -> int | None: + """Resolve the DREAM population count from the ``chains`` alias. + + Both ``chains`` (user-friendly name) and ``population`` (BUMPS + native name) refer to the same DREAM ``pop`` parameter. This + helper enforces that at most one is provided and returns the + resolved value. 
+ + Parameters + ---------- + chains : int | None + User-friendly alias for the DREAM population count. + population : int | None + BUMPS-native DREAM population count. + + Returns + ------- + int | None + The resolved population count, or ``None`` if neither was + provided. + + Raises + ------ + ValueError + If both ``chains`` and ``population`` are provided with + different values. + """ + if chains is not None and population is not None: + if chains != population: + raise ValueError( + f'Conflicting population arguments: chains={chains}, ' + f'population={population}. Only provide one.' + ) + return chains + return chains if chains is not None else population + def _build_progress_payload( self, problem, iteration: int, point: np.ndarray, nllf: float ) -> dict: @@ -389,10 +631,13 @@ def mcmc_sample( samples: int = 10000, burn: int = 2000, thin: int = 10, + chains: int | None = None, population: int | None = None, + seed: int | None = None, sampler_kwargs: dict | None = None, progress_callback: Callable[[dict], bool | None] | None = None, abort_test: Callable[[], bool] | None = None, + n_workers: int | None = None, ) -> dict: """Run Bayesian MCMC sampling using the BUMPS DREAM sampler. @@ -415,8 +660,18 @@ def mcmc_sample( Burn-in steps. thin : int, default=10 Thinning interval. + chains : int | None, default=None + User-friendly alias for BUMPS DREAM population count (number of + parallel chains). population : int | None, default=None - BUMPS DREAM population count (number of parallel chains). + BUMPS DREAM population count for advanced users. + seed : int | None, default=None + Best-effort random seed. Calls ``numpy.random.seed(seed)`` + before DREAM starts, which affects the *global* NumPy RNG + state and may interact with other code in the process. + BUMPS DREAM uses additional internal RNG state that is + **not** controlled by this seed, so exact reproducibility + across runs is **not** guaranteed. 
sampler_kwargs : dict | None, default=None Additional keyword arguments forwarded to `bumps.fitters.fit`. progress_callback : Callable[[dict], bool | None] | None, default=None @@ -427,6 +682,11 @@ def mcmc_sample( Optional callback that returns ``True`` to signal that sampling should be aborted. Called periodically during the DREAM sampling loop. + n_workers : int | None, default=None + Number of worker processes used to evaluate the DREAM population. + Values of ``None`` and ``1`` use BUMPS' sequential mapper. Values + greater than ``1`` require the BUMPS problem and fit function to be + pickleable. Returns ------- @@ -437,10 +697,13 @@ def mcmc_sample( Raises ------ ValueError - If the input shapes or weights are invalid, or if + If the input shapes or weights are invalid, if both ``chains`` + and ``population`` are provided with different values, or if ``progress_callback`` is not callable. FitError - If DREAM sampling was aborted by the user (via ``abort_test``). + If DREAM sampling was aborted by the user (via ``abort_test``), or + if multiprocessing was requested for a problem that cannot be + serialized for worker processes. Exception Re-raised from DREAM fitting if any unexpected error occurs (parameter values are restored beforehand). @@ -479,13 +742,31 @@ def mcmc_sample( curve = model_func(x, y, weights) problem = FitProblem(curve) + # Best-effort seed: sets numpy's global RNG state just before DREAM starts. 
+ if seed is not None: + np.random.seed(seed) + + # Resolve population parameter + pop = self._resolve_population_alias(chains, population) + # Build DREAM kwargs dream_kwargs: dict = {'samples': samples, 'burn': burn, 'thin': thin} - if population is not None: - dream_kwargs['pop'] = population + if pop is not None: + dream_kwargs['pop'] = pop if sampler_kwargs: dream_kwargs.update(sampler_kwargs) + resolved_pop = int(dream_kwargs.get('pop', 10)) + if resolved_pop <= 0: + raise ValueError('DREAM population must be a positive integer.') + + mapper = None + if n_workers is not None: + if n_workers < 1: + raise ValueError('n_workers must be at least 1.') + if n_workers > 1: + mapper = BumpsPoolMapper(problem, n_workers=min(n_workers, resolved_pop)) + # Build monitors (same pattern as classical Bumps.fit()) monitors = [] if progress_callback is not None: @@ -494,7 +775,7 @@ def mcmc_sample( # Compute total DREAM steps for progress display (burn + sampling generations). # BUMPS DREAM default population count is 10 when not specified by the user. 
_dream_default_pop = 10 - pop_val = population if population is not None else _dream_default_pop + pop_val = pop if pop is not None else _dream_default_pop _total_steps = burn + (samples + pop_val - 1) // pop_val monitors.append( BumpsProgressMonitor( @@ -512,6 +793,7 @@ def mcmc_sample( problem=problem, monitors=monitors, abort_test=abort_test if abort_test is not None else (lambda: False), + mapper=mapper, **dream_kwargs, ) driver.clip() @@ -527,9 +809,14 @@ def mcmc_sample( if result_state is None: raise FitError('Sampling aborted by user') except Exception: + if mapper is not None: + mapper.terminate() + mapper = None self._restore_parameter_values() raise finally: + if mapper is not None: + mapper.close() global_object.stack.enabled = stack_status draws = result_state.draw().points diff --git a/tests/integration/fitting/test_multi_fitter.py b/tests/integration/fitting/test_multi_fitter.py index e4b87ea1..e3045607 100644 --- a/tests/integration/fitting/test_multi_fitter.py +++ b/tests/integration/fitting/test_multi_fitter.py @@ -1,6 +1,11 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +import subprocess +import sys +import textwrap +from pathlib import Path + import numpy as np import pytest @@ -498,3 +503,236 @@ def test_sampler_kwargs_forwarded(self): assert result['draws'].ndim == 2 assert result['draws'].shape[0] > 0 + + +class TestSampleAliasResolution: + def test_conflicting_chains_and_population_raises(self): + """Passing both chains and population with different values must raise.""" + sp = AbsSin(0.354, 3.05) + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except AttributeError: + pytest.skip('BUMPS is not installed') + + x = np.linspace(0, 5, 50) + y = np.sin(x) + weights = np.ones_like(x) + + with pytest.raises(ValueError, match='Conflicting population arguments'): + f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=10, burn=5, thin=1, chains=3, population=5 + ) + + def 
test_chains_and_population_equal_is_ok(self): + """Passing chains == population should succeed (no conflict).""" + sp = AbsSin(0.354, 3.05) + sp.offset.fixed = False + sp.phase.fixed = False + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except AttributeError: + pytest.skip('BUMPS is not installed') + + x = np.linspace(0, 5, 50) + y = np.sin(x) + weights = np.ones_like(x) + + # Should not raise ValueError — chains and population are equal. + # DREAM needs a sufficient population; 5 is a safe minimum. + result = f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=100, burn=20, thin=2, chains=5, population=5 + ) + assert 'draws' in result + + +class TestSampleSeedReproducibility: + @pytest.mark.filterwarnings('ignore::UserWarning') + def test_seed_produces_valid_draws(self): + """Running mcmc_sample() with a seed must produce valid draws.""" + ref_sin = AbsSin(0.2, np.pi) + sp = AbsSin(0.354, 3.05) + + x = np.linspace(0, 5, 50) + y = ref_sin(x) + weights = np.ones_like(x) + + sp.offset.fixed = False + sp.phase.fixed = False + + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except AttributeError: + pytest.skip('BUMPS is not installed') + + result = f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=100, burn=20, thin=2, seed=42 + ) + + assert result['draws'].ndim == 2 + assert result['draws'].shape[0] > 0 + assert result['draws'].shape[1] == len(result['param_names']) + # logp should be present (may be None if not computed) + assert 'logp' in result + + @pytest.mark.filterwarnings('ignore::UserWarning') + def test_different_seeds_both_produce_valid_draws(self): + """Running mcmc_sample() with different seeds should each produce valid draws.""" + ref_sin = AbsSin(0.2, np.pi) + sp = AbsSin(0.354, 3.05) + + x = np.linspace(0, 5, 50) + y = ref_sin(x) + weights = np.ones_like(x) + + sp.offset.fixed = False + sp.phase.fixed = False + + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except 
AttributeError: + pytest.skip('BUMPS is not installed') + + result1 = f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=100, burn=20, thin=2, seed=42 + ) + result2 = f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=100, burn=20, thin=2, seed=12345 + ) + + # Both must produce valid draws + assert result1['draws'].shape[0] > 0 + assert result2['draws'].shape[0] > 0 + assert result1['draws'].ndim == 2 + assert result2['draws'].ndim == 2 + + +class TestSampleMultiprocessing: + @pytest.mark.filterwarnings('ignore::UserWarning') + def test_n_workers_one_uses_sequential_mapper(self): + """n_workers=1 should behave like the default sequential DREAM mapper.""" + ref_sin = AbsSin(0.2, np.pi) + sp = AbsSin(0.354, 3.05) + + x = np.linspace(0, 5, 50) + y = ref_sin(x) + weights = np.ones_like(x) + + sp.offset.fixed = False + sp.phase.fixed = False + + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except AttributeError: + pytest.skip('BUMPS is not installed') + + result = f.mcmc_sample( + x=[x], y=[y], weights=[weights], samples=50, burn=10, thin=2, n_workers=1 + ) + + assert result['draws'].ndim == 2 + assert result['draws'].shape[0] > 0 + assert result['draws'].shape[1] == len(result['param_names']) + + def test_n_workers_must_be_positive(self): + """n_workers must be positive when explicitly provided.""" + sp = AbsSin(0.354, 3.05) + sp.offset.fixed = False + sp.phase.fixed = False + + f = MultiFitter([sp], [sp]) + try: + f.switch_minimizer('Bumps') + except AttributeError: + pytest.skip('BUMPS is not installed') + + x = np.linspace(0, 5, 50) + y = np.sin(x) + weights = np.ones_like(x) + + with pytest.raises(ValueError, match='n_workers must be at least 1'): + f.mcmc_sample(x=[x], y=[y], weights=[weights], samples=10, burn=5, thin=1, n_workers=0) + + @pytest.mark.filterwarnings('ignore::UserWarning') + def test_n_workers_two_produces_valid_draws(self, tmp_path): + """n_workers>1 should evaluate DREAM populations through process workers.""" 
+ repo_root = str(Path(__file__).resolve().parents[3]) + test_file = str(Path(__file__).resolve()) + script = tmp_path / 'run_bumps_multiprocessing_sample.py' + script.write_text( + textwrap.dedent( + f""" + import importlib.util + import multiprocessing as mp + import sys + + sys.path.insert(0, {repo_root!r}) + sys.path.insert(0, {repo_root + '/src'!r}) + + import numpy as np + + from easyscience.fitting.multi_fitter import MultiFitter + + # Load the AbsSin model straight from this test file to avoid + # depending on ``tests`` being importable as a package (it is a + # plain namespace dir and can be shadowed by sibling projects). + _spec = importlib.util.spec_from_file_location('_mp_test_models', {test_file!r}) + _models = importlib.util.module_from_spec(_spec) + _spec.loader.exec_module(_models) + AbsSin = _models.AbsSin + + + def main(): + ref_sin = AbsSin(0.2, np.pi) + sp = AbsSin(0.354, 3.05) + + x = np.linspace(0, 5, 40) + y = ref_sin(x) + weights = np.ones_like(x) + + sp.offset.fixed = False + sp.phase.fixed = False + + f = MultiFitter([sp], [sp]) + f.switch_minimizer('Bumps') + result = f.mcmc_sample( + x=[x], + y=[y], + weights=[weights], + samples=50, + burn=10, + thin=2, + population=5, + n_workers=2, + ) + + assert result['draws'].ndim == 2 + assert result['draws'].shape[0] > 0 + assert result['draws'].shape[1] == len(result['param_names']) + + + if __name__ == '__main__': + mp.freeze_support() + main() + """ + ), + encoding='utf-8', + ) + + try: + completed = subprocess.run( + [sys.executable, str(script)], + cwd=repo_root, + capture_output=True, + text=True, + timeout=60, + check=False, + ) + except subprocess.TimeoutExpired: + pytest.fail('n_workers=2 sampling subprocess timed out after 60 seconds') + + assert completed.returncode == 0, completed.stdout + completed.stderr diff --git a/tests/unit/fitting/minimizers/test_minimizer_bumps.py b/tests/unit/fitting/minimizers/test_minimizer_bumps.py index 7a223f2c..1a1e4dd6 100644 --- 
a/tests/unit/fitting/minimizers/test_minimizer_bumps.py +++ b/tests/unit/fitting/minimizers/test_minimizer_bumps.py @@ -1,6 +1,9 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +import io +import pickle +import weakref from unittest.mock import ANY from unittest.mock import MagicMock from unittest.mock import patch @@ -9,8 +12,14 @@ import pytest import easyscience.fitting.minimizers.minimizer_bumps +import easyscience.fitting.minimizers.minimizer_bumps as _bumps_mod from easyscience.fitting.minimizers.bumps_utils import BumpsProgressMonitor from easyscience.fitting.minimizers.minimizer_bumps import Bumps +from easyscience.fitting.minimizers.minimizer_bumps import BumpsPoolMapper +from easyscience.fitting.minimizers.minimizer_bumps import _evaluate_bumps_point +from easyscience.fitting.minimizers.minimizer_bumps import _init_bumps_worker +from easyscience.fitting.minimizers.minimizer_bumps import _problem_pickler_class +from easyscience.fitting.minimizers.minimizer_bumps import _restore_none from easyscience.fitting.minimizers.utils import FitError @@ -1059,3 +1068,396 @@ def test_abort_test_passed_to_fit_driver(self, minimizer: Bumps, monkeypatch) -> call_kwargs = mock_FitDriver.call_args.kwargs assert callable(call_kwargs['abort_test']) assert call_kwargs['abort_test'] is not (lambda: False) + + +# =================================================================== +# _resolve_population_alias (static helper) +# =================================================================== + + +class TestResolvePopulationAlias: + """Tests for ``Bumps._resolve_population_alias``.""" + + def test_both_none_returns_none(self) -> None: + assert Bumps._resolve_population_alias(None, None) is None + + def test_chains_only_returns_chains(self) -> None: + assert Bumps._resolve_population_alias(5, None) == 5 + + def test_population_only_returns_population(self) -> None: + assert Bumps._resolve_population_alias(None, 7) == 7 + + def 
test_both_equal_returns_value(self) -> None: + assert Bumps._resolve_population_alias(5, 5) == 5 + + def test_both_different_raises(self) -> None: + with pytest.raises(ValueError, match='Conflicting population'): + Bumps._resolve_population_alias(3, 10) + + def test_chains_zero_is_valid(self) -> None: + """Zero is a valid (though unusual) population value.""" + assert Bumps._resolve_population_alias(0, None) == 0 + + def test_population_zero_is_valid(self) -> None: + assert Bumps._resolve_population_alias(None, 0) == 0 + + +# =================================================================== +# Worker helper functions: _evaluate_bumps_point, _init_bumps_worker +# =================================================================== + + +class TestWorkerFunctions: + def test_evaluate_raises_when_problem_not_initialized(self, monkeypatch): + monkeypatch.setattr(_bumps_mod, '_WORKER_PROBLEM', None) + with pytest.raises(RuntimeError, match='not initialized'): + _evaluate_bumps_point(np.array([1.0])) + + def test_evaluate_calls_nllf_and_returns_python_float(self, monkeypatch): + mock_problem = MagicMock() + mock_problem.nllf.return_value = np.float64(3.5) + monkeypatch.setattr(_bumps_mod, '_WORKER_PROBLEM', mock_problem) + + result = _evaluate_bumps_point(np.array([1.0, 2.0])) + + assert isinstance(result, float) + assert result == 3.5 + mock_problem.nllf.assert_called_once() + np.testing.assert_array_equal(mock_problem.nllf.call_args[0][0], np.array([1.0, 2.0])) + + def test_init_worker_populates_global_problem(self, monkeypatch): + monkeypatch.setattr(_bumps_mod, '_WORKER_PROBLEM', None) + _init_bumps_worker(pickle.dumps({'sentinel': True})) + assert _bumps_mod._WORKER_PROBLEM == {'sentinel': True} + + def test_init_worker_disables_global_stack(self, monkeypatch): + from easyscience import global_object + + monkeypatch.setattr(_bumps_mod, '_WORKER_PROBLEM', None) + global_object.stack.enabled = True + _init_bumps_worker(pickle.dumps(42)) + assert 
global_object.stack.enabled is False + + +# =================================================================== +# _problem_pickler_class +# =================================================================== + + +class TestProblemPicklerClass: + def test_returns_cloudpickler_subclass(self): + from cloudpickle import CloudPickler + + assert issubclass(_problem_pickler_class(), CloudPickler) + + def test_reducer_override_replaces_weakref_with_none_restorer(self): + cls = _problem_pickler_class() + + class _Dummy: + pass + + obj = _Dummy() + ref = weakref.ref(obj) + + result = cls(io.BytesIO()).reducer_override(ref) + assert result == (_restore_none, ()) + + def test_reducer_override_falls_through_for_plain_dict(self): + cls = _problem_pickler_class() + result = cls(io.BytesIO()).reducer_override({'key': 'val'}) + assert result is NotImplemented + + def test_weakref_survives_round_trip_as_none(self): + cls = _problem_pickler_class() + + class _Dummy: + pass + + obj = _Dummy() + + class _Container: + ref = weakref.ref(obj) + + buf = io.BytesIO() + cls(buf).dump(_Container()) + buf.seek(0) + restored = pickle.load(buf) + assert restored.ref is None + + +# =================================================================== +# BumpsPoolMapper — lifecycle (terminate / close) +# =================================================================== + + +class TestBumpsPoolMapperLifecycle: + def _mapper(self): + m = BumpsPoolMapper.__new__(BumpsPoolMapper) + m._pool = MagicMock() + m.n_workers = 2 + return m + + def test_terminate_shuts_down_pool(self): + mapper = self._mapper() + pool = mapper._pool + mapper.terminate() + pool.terminate.assert_called_once() + pool.join.assert_called_once() + assert mapper._pool is None + + def test_terminate_is_idempotent_when_pool_is_none(self): + mapper = BumpsPoolMapper.__new__(BumpsPoolMapper) + mapper._pool = None + mapper.terminate() # must not raise + + def test_close_delegates_to_terminate(self): + mapper = self._mapper() + pool = 
mapper._pool + mapper.close() + pool.terminate.assert_called_once() + assert mapper._pool is None + + +# =================================================================== +# BumpsPoolMapper — __call__ +# =================================================================== + + +class TestBumpsPoolMapperCall: + def _mapper(self, map_return): + m = BumpsPoolMapper.__new__(BumpsPoolMapper) + m._pool = MagicMock() + m._pool.map.return_value = map_return + m.n_workers = 2 + return m + + def test_maps_2d_population_and_chunks_across_workers(self): + pop = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + mapper = self._mapper([1.0, 2.0, 3.0]) + + result = mapper(pop) + + assert result == [1.0, 2.0, 3.0] + # 3 points across 2 workers => ceil(3/2) = 2 points per IPC task, + # amortizing per-task multiprocessing overhead over the generation. + assert mapper._pool.map.call_args.kwargs.get('chunksize') == 2 + + def test_reshapes_1d_point_to_single_row(self): + mapper = self._mapper([7.0]) + + result = mapper(np.array([1.0, 2.0])) + + assert result == [7.0] + points_arg = mapper._pool.map.call_args[0][1] + assert len(points_arg) == 1 + # list(atleast_2d([1., 2.])) produces 1D rows, one per chain member + np.testing.assert_array_equal(points_arg[0], np.array([1.0, 2.0])) + + def test_raises_on_result_count_mismatch(self): + mapper = self._mapper([42.0]) # one result for two points + with pytest.raises( + RuntimeError, match='Mapper returned 1 results for 2 population points' + ): + mapper(np.array([[1.0], [2.0]])) + + +# =================================================================== +# BumpsPoolMapper — __init__ (serialization + pool creation) +# =================================================================== + + +class TestBumpsPoolMapperInit: + @staticmethod + def _patch_pickler(monkeypatch, written_bytes=b'fake_problem'): + class _FakePickler: + def __init__(self, buf): + self._buf = buf + + def dump(self, obj): + self._buf.write(written_bytes) + + 
monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, + '_problem_pickler_class', + lambda: _FakePickler, + ) + + def test_creates_spawn_pool_with_correct_args(self, monkeypatch): + self._patch_pickler(monkeypatch) + mock_pool = MagicMock() + mock_context = MagicMock() + mock_context.Pool.return_value = mock_pool + monkeypatch.setattr(_bumps_mod.mp, 'get_context', lambda _: mock_context) + + mapper = BumpsPoolMapper(MagicMock(), n_workers=3) + + assert mapper._pool is mock_pool + mock_context.Pool.assert_called_once_with( + processes=3, + initializer=_bumps_mod._init_bumps_worker, + initargs=(b'fake_problem',), + ) + + def test_serialization_failure_raises_fit_error(self, monkeypatch): + class _BadPickler: + def __init__(self, buf): + pass + + def dump(self, obj): + raise TypeError('not serializable') + + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, + '_problem_pickler_class', + lambda: _BadPickler, + ) + with pytest.raises(FitError, match='serializable'): + BumpsPoolMapper(MagicMock(), n_workers=2) + + +# =================================================================== +# Bumps.mcmc_sample() — n_workers wiring +# =================================================================== + + +class TestBumpsSampleNWorkers: + @pytest.fixture + def minimizer(self) -> Bumps: + return Bumps( + obj='obj', + fit_function='fit_function', + minimizer_enum=MagicMock(package='bumps', method='amoeba'), + ) + + @pytest.fixture(autouse=True) + def _mock_bumps_internals(self, monkeypatch): + import bumps.fitters + import bumps.names + + monkeypatch.setattr(bumps.fitters, 'DreamFit', MagicMock()) + monkeypatch.setattr(bumps.names, 'FitProblem', MagicMock(return_value=MagicMock())) + monkeypatch.setattr( + Bumps, '_make_model', MagicMock(return_value=MagicMock(return_value=MagicMock())) + ) + + def _setup_driver(self, monkeypatch): + from easyscience import global_object + + global_object.stack.enabled = False + + mock_state = MagicMock() + 
mock_state.draw.return_value.points = np.array([[1.0]]) + mock_state.logp = None + mock_driver = MagicMock() + mock_driver.clip = MagicMock() + mock_driver.fit.return_value = (np.array([1.0]), 0.0) + mock_driver.fitter.state = mock_state + + mock_FitDriver = MagicMock(return_value=mock_driver) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, 'FitDriver', mock_FitDriver + ) + return mock_FitDriver, mock_driver + + def _patch_mapper(self, monkeypatch): + mock_mapper = MagicMock() + mock_cls = MagicMock(return_value=mock_mapper) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, 'BumpsPoolMapper', mock_cls + ) + return mock_cls, mock_mapper + + def test_n_workers_zero_raises(self, minimizer, monkeypatch): + self._setup_driver(monkeypatch) + with pytest.raises(ValueError, match='n_workers must be at least 1'): + minimizer.mcmc_sample( + x=np.array([1.0]), y=np.array([0.1]), weights=np.array([1.0]), n_workers=0 + ) + + def test_n_workers_one_does_not_create_mapper(self, minimizer, monkeypatch): + mock_FitDriver, _ = self._setup_driver(monkeypatch) + mock_cls, _ = self._patch_mapper(monkeypatch) + + minimizer.mcmc_sample( + x=np.array([1.0]), y=np.array([0.1]), weights=np.array([1.0]), n_workers=1 + ) + + mock_cls.assert_not_called() + assert mock_FitDriver.call_args.kwargs['mapper'] is None + + def test_n_workers_two_creates_mapper_and_passes_to_driver(self, minimizer, monkeypatch): + mock_FitDriver, _ = self._setup_driver(monkeypatch) + mock_cls, mock_mapper = self._patch_mapper(monkeypatch) + + minimizer.mcmc_sample( + x=np.array([1.0]), + y=np.array([0.1]), + weights=np.array([1.0]), + n_workers=2, + population=5, + ) + + mock_cls.assert_called_once() + assert mock_FitDriver.call_args.kwargs['mapper'] is mock_mapper + mock_mapper.close.assert_called_once() + + def test_n_workers_clipped_to_population_size(self, minimizer, monkeypatch): + self._setup_driver(monkeypatch) + mock_cls, _ = self._patch_mapper(monkeypatch) + + 
minimizer.mcmc_sample( + x=np.array([1.0]), + y=np.array([0.1]), + weights=np.array([1.0]), + n_workers=8, + population=3, + ) + + assert mock_cls.call_args.kwargs['n_workers'] == 3 + + def test_mapper_terminated_on_driver_exception(self, minimizer, monkeypatch): + from easyscience import global_object + + global_object.stack.enabled = False + mock_driver = MagicMock() + mock_driver.clip = MagicMock() + mock_driver.fit.side_effect = RuntimeError('driver failed') + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, + 'FitDriver', + MagicMock(return_value=mock_driver), + ) + mock_cls, mock_mapper = self._patch_mapper(monkeypatch) + + with pytest.raises(RuntimeError, match='driver failed'): + minimizer.mcmc_sample( + x=np.array([1.0]), y=np.array([0.1]), weights=np.array([1.0]), n_workers=2 + ) + + mock_mapper.terminate.assert_called_once() + mock_mapper.close.assert_not_called() + + def test_mapper_closed_on_success(self, minimizer, monkeypatch): + self._setup_driver(monkeypatch) + mock_cls, mock_mapper = self._patch_mapper(monkeypatch) + + minimizer.mcmc_sample( + x=np.array([1.0]), y=np.array([0.1]), weights=np.array([1.0]), n_workers=2 + ) + + mock_mapper.close.assert_called_once() + mock_mapper.terminate.assert_not_called() + + def test_sample_conflicting_population_raises(self, minimizer, monkeypatch): + self._setup_driver(monkeypatch) + with pytest.raises(ValueError, match='Conflicting population'): + minimizer.mcmc_sample( + x=np.array([1.0]), + y=np.array([0.1]), + weights=np.array([1.0]), + chains=5, + population=10, + samples=10, + burn=0, + thin=1, + ) diff --git a/tests/unit/fitting/test_fitter.py b/tests/unit/fitting/test_fitter.py index 634492c5..14377be6 100644 --- a/tests/unit/fitting/test_fitter.py +++ b/tests/unit/fitting/test_fitter.py @@ -314,6 +314,37 @@ def test_basic(self, fitter: Fitter): assert kw['progress_callback'] is None assert fitter._dependent_dims == 'dims' + def test_chains_seed_and_n_workers_forwarded(self, 
fitter: Fitter): + """chains, seed and n_workers are forwarded to minimizer.mcmc_sample().""" + fitter._precompute_reshaping = MagicMock( + return_value=('x_fit', 'x_new', 'y_new', 'w_new', 'dims') + ) + fitter._fit_function_wrapper = MagicMock(return_value='wrapped') + fitter._minimizer = MagicMock() + fitter._minimizer.package = 'bumps' + fitter._minimizer.mcmc_sample = MagicMock( + return_value={ + 'draws': [], + 'param_names': [], + 'internal_bumps_object': None, + 'logp': None, + } + ) + + fitter.mcmc_sample( + x=np.array([1.0]), + y=np.array([0.1]), + weights=np.array([1.0]), + chains=7, + seed=42, + n_workers=2, + ) + + kw = fitter._minimizer.mcmc_sample.call_args.kwargs + assert kw['chains'] == 7 + assert kw['seed'] == 42 + assert kw['n_workers'] == 2 + def test_raises_if_not_bumps(self, fitter: Fitter): """RuntimeError raised when the active minimizer is not BUMPS.""" fitter._precompute_reshaping = MagicMock( diff --git a/tools/benchmarks/sampling_mpi.py b/tools/benchmarks/sampling_mpi.py new file mode 100644 index 00000000..54c35902 --- /dev/null +++ b/tools/benchmarks/sampling_mpi.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +"""Quick benchmark: Bayesian DREAM sampling with multiprocessing. + +Runs the same small sampling problem sequentially and with process workers, +printing wall-clock times so you can judge the speedup. +""" + +import time +import warnings + +import numpy as np + +from easyscience import ObjBase +from easyscience import Parameter +from easyscience.fitting.multi_fitter import MultiFitter + +# -- simple test model -------------------------------------------------------- + +# Simulate an expensive model by adding a configurable CPU burn per evaluation. +# Set to 0.0 for the trivial model; try 0.02–0.1 to see multiprocessing speedup. 
_MODEL_DELAY = 0.09  # seconds of CPU work per model call


class Line(ObjBase):
    """Straight-line model ``y = m * x + c`` with an optional CPU burn.

    When ``_MODEL_DELAY`` is positive, every call spins for that many
    seconds before returning, so that the cost of one evaluation resembles
    an expensive model.
    """

    m: Parameter
    c: Parameter

    def __init__(self, m_val: float, c_val: float):
        """Create the model with initial slope ``m_val`` and intercept ``c_val``."""
        super().__init__(
            'line',
            m=Parameter('m', m_val),
            c=Parameter('c', c_val),
        )

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the line at ``x`` after the configured busy-wait."""
        if _MODEL_DELAY > 0:
            # Busy-loop (not sleep) so the delay consumes CPU like a real
            # physics model would; the computed value is deliberately unused.
            started = time.perf_counter()
            while time.perf_counter() - started < _MODEL_DELAY:
                np.sum(np.sin(x) ** 2 + np.cos(x) ** 2)
        return self.m.value * x + self.c.value


# -- helpers ------------------------------------------------------------------


def run_sample(n_workers: int | None, **sample_kwargs) -> tuple[dict, float]:
    """Run one DREAM sampling call and return (result_dict, wall_seconds).

    The data set is synthetic and reproducible: 60 points on ``[0, 10]``
    drawn from ``y = 2.5 * x + 1.3`` with Gaussian noise (sigma 0.3, RNG
    seed 42). Only the ``mcmc_sample`` call itself is timed.

    Args:
        n_workers: Forwarded to ``mcmc_sample`` as ``n_workers``.
        **sample_kwargs: Extra keyword arguments forwarded to ``mcmc_sample``.

    Returns:
        The value returned by ``mcmc_sample`` and the elapsed wall-clock
        time in seconds.
    """
    sigma = 0.3
    x = np.linspace(0, 10, 60)
    y_true = 2.5 * x + 1.3
    rng = np.random.default_rng(42)
    y = y_true + rng.normal(0, sigma, size=x.shape)
    weights = np.full_like(x, 1.0 / sigma)

    model = Line(2.0, 1.0)
    model.m.fixed = False
    model.c.fixed = False

    fitter = MultiFitter([model], [model])
    fitter.switch_minimizer('Bumps')

    started = time.perf_counter()
    result = fitter.mcmc_sample(
        x=[x],
        y=[y],
        weights=[weights],
        n_workers=n_workers,
        **sample_kwargs,
    )
    elapsed = time.perf_counter() - started
    return result, elapsed


def summarise(label: str, result: dict, elapsed: float) -> None:
    """Print a one-line report for a finished sampling run.

    Args:
        label: Short name of the configuration that was run.
        result: Mapping with at least the keys ``'draws'`` and
            ``'param_names'``.
        elapsed: Wall-clock duration of the run in seconds.
    """
    # np.shape accepts arrays and plain sequences alike, so a list-valued
    # 'draws' entry is reported instead of raising AttributeError.
    draws_shape = np.shape(result['draws'])
    print(f' {label:>12s} {elapsed:6.2f} s '
          f'draws shape {draws_shape} '
          f'params: {result["param_names"]}')


def _run_quietly(n_workers: int | None, **sample_kwargs) -> tuple[dict, float]:
    """Call ``run_sample`` with every warning suppressed for its duration."""
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        return run_sample(n_workers=n_workers, **sample_kwargs)


def _speedup_tag(t_seq: float, t_val: float) -> str:
    """Describe ``t_val`` relative to the sequential time ``t_seq``.

    Returns ``'n/a'`` when either timing is not positive, because the ratio
    (or its inverse) would otherwise divide by zero.
    """
    if t_seq <= 0 or t_val <= 0:
        return 'n/a'
    ratio = t_seq / t_val
    return f'{ratio:.1f}× speedup' if ratio > 1 else f'{1 / ratio:.1f}× slower'


# -- main ---------------------------------------------------------------------


def main() -> None:
    """Time the same sampling problem under several worker counts.

    Runs sequentially (``n_workers=None``), then with 1, 2 and 4 workers,
    printing a summary line per run and finally the speed-up of the 2- and
    4-worker runs relative to the sequential one.
    """
    sample_kwargs = dict(samples=200, burn=50, thin=2, population=5, seed=123)

    print('Bayesian multiprocessing quick test')
    print('-----------------------------------')
    print(f' config: {sample_kwargs}')
    print()

    # None is the default sequential path; 1 requests it explicitly; 2 and 4
    # are the configurations that actually use worker processes.
    configurations: tuple[tuple[str, int | None], ...] = (
        ('sequential', None),
        ('n_workers=1', 1),
        ('n_workers=2', 2),
        ('n_workers=4', 4),
    )
    timings: dict[str, float] = {}
    for label, n_workers in configurations:
        result, elapsed = _run_quietly(n_workers, **sample_kwargs)
        summarise(label, result, elapsed)
        timings[label] = elapsed

    print()
    t_seq = timings['sequential']
    for label in ('n_workers=2', 'n_workers=4'):
        t_val = timings[label]
        tag = _speedup_tag(t_seq, t_val)
        print(f' {label:>12s} {tag} (seq {t_seq:.2f}s → {t_val:.2f}s)')


if __name__ == '__main__':
    main()