diff --git a/pixi.lock b/pixi.lock index 6b3a38bc..167028d6 100644 --- a/pixi.lock +++ b/pixi.lock @@ -190,7 +190,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -501,7 +501,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zeromq-4.3.5-h4818236_10.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -802,7 +802,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zeromq-4.3.5-h507cc87_10.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -1125,7 +1125,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/06/8a/5e156e31ba656ce93c1cc895dd8f051ec351cb382940dca655aaec475005/python_bidi-0.6.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -1430,7 +1430,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zeromq-4.3.5-h4818236_10.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/07/c7/deb8c5e604404dbf10a3808a858946ca3547692ff6316b698945bb72177e/python_socketio-5.16.1-py3-none-any.whl @@ -1725,7 +1725,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zeromq-4.3.5-h507cc87_10.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/02/05/d60c732b56da5085175c07c74b2df4e6d181b0c9a61e1691474f06ef4b39/lxml-6.1.0-cp311-cp311-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -2048,7 +2048,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.23.1-pyhcf101f3_0.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -2359,7 +2359,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zeromq-4.3.5-h4818236_10.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -2660,7 +2660,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zeromq-4.3.5-h507cc87_10.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: . - - pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 + - pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 - pypi: https://files.pythonhosted.org/packages/00/bb/90ba423612b6aa0adccc6b1874bcd4a9b44b660c0c16f346611e00f64ac3/backrefs-7.0-py313-none-any.whl - pypi: https://files.pythonhosted.org/packages/01/7c/fa07d3da2b6253eb8474be16eab2eadf670460e364ccc895ca7ff388ee30/oscrypto-1.3.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl @@ -8301,8 +8301,9 @@ packages: name: easyreflectometry requires_dist: - bumps - - easyscience @ git+https://github.com/easyscience/corelib@develop + - easyscience @ git+https://github.com/easyscience/corelib.git@develop - orsopy + - plotly - pooch - refl1d>=1.0.0 - refnx @@ -8344,7 +8345,7 @@ packages: - validate-pyproject[all] ; extra == 'dev' - versioningit ; extra == 'dev' requires_python: '>=3.11' -- pypi: git+https://github.com/easyscience/corelib?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 +- pypi: git+https://github.com/easyscience/corelib.git?rev=develop#aadbd4891b94f6aa18187d48be8c2ab6f81113b0 name: easyscience version: 2.3.1+dev8 requires_dist: diff --git a/pyproject.toml b/pyproject.toml index ac183b21..596b7f27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ classifiers = [ ] requires-python = '>=3.11' dependencies = [ - 'easyscience @ git+https://github.com/easyscience/corelib.git@develop', + # 'easyscience @ git+https://github.com/easyscience/corelib.git@develop', + 'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_extend', # 'easyscience', 'scipp', 'refnx', diff --git a/src/easyreflectometry/analysis/bayesian.py b/src/easyreflectometry/analysis/bayesian.py index 39389a84..17ab252c 100644 --- a/src/easyreflectometry/analysis/bayesian.py +++ b/src/easyreflectometry/analysis/bayesian.py @@ -1332,8 +1332,9 @@ def load_posterior(path: str, skip: int = 0) -> 'PosteriorResults': """Reload a trace saved by :func:`save_posterior` into a :class:`PosteriorResults`. - The returned object's ``sampler_state`` can be fed back into - ``MultiFitter.mcmc_sample(..., resume_state=...)`` to extend the chain. + The returned object's ``sampler_state`` can be fed back into the core + ``Sampler`` (via ``Sampler.load_state(...)`` / ``Sampler.extend(...)``) + to extend the chain. :param path: File path prefix used in :func:`save_posterior`. :type path: str diff --git a/src/easyreflectometry/fitting.py b/src/easyreflectometry/fitting.py index 83a3f6eb..f90211ff 100644 --- a/src/easyreflectometry/fitting.py +++ b/src/easyreflectometry/fitting.py @@ -10,6 +10,7 @@ import scipp as sc from easyscience.fitting import AvailableMinimizers from easyscience.fitting import FitResults +from easyscience.fitting import Sampler from easyscience.fitting.multi_fitter import MultiFitter as EasyScienceMultiFitter from easyreflectometry.data import DataSet1D @@ -419,14 +420,22 @@ def mcmc_sample( y.append(y_eff) dy.append(weights) - # Delegate the actual BUMPS/DREAM sampling to the core MultiFitter + # Delegate the actual BUMPS/DREAM sampling to the core ``Sampler``. + # The core API moved from ``MultiFitter.mcmc_sample()`` to a dedicated + # ``Sampler`` class: construct it with the configured fitter and the + # bound data, then call ``sample()``. ``Sampler`` handles the + # multi-dataset reshaping internally. sampler_kwargs = {} if initializer is not None: sampler_kwargs['init'] = initializer - return self.easy_science_multi_fitter.mcmc_sample( + + sampler = Sampler( + self.easy_science_multi_fitter, x=x, y=y, weights=dy, + ) + results = sampler.sample( samples=samples, burn=burn, thin=thin, @@ -435,6 +444,12 @@ def mcmc_sample( progress_callback=progress_callback, abort_test=abort_test, ) + return { + 'draws': results.draws, + 'param_names': results.param_names, + 'state': results.state, + 'logp': results.logp, + } @property def chi2(self) -> float | None: diff --git a/tests/test_fitting.py b/tests/test_fitting.py index e4f938c8..2c519867 100644 --- a/tests/test_fitting.py +++ b/tests/test_fitting.py @@ -4,6 +4,7 @@ import os from unittest.mock import MagicMock +from unittest.mock import patch import numpy as np import pytest @@ -808,6 +809,44 @@ def _fake_fit(*, x, y, weights): # --------------------------------------------------------------------------- +def _fake_sampling_results(draws=None, param_names=None, state=None, logp=None): + """Build a stand-in for the core ``SamplingResults`` returned by ``Sampler.sample``.""" + res = MagicMock() + res.draws = np.ones((10, 2)) if draws is None else draws + res.param_names = ['a', 'b'] if param_names is None else param_names + res.state = state + res.logp = logp + return res + + +def _patch_sampler(capture, results=None): + """Patch ``easyreflectometry.fitting.Sampler`` and capture its call args. + + Records the constructor's ``(x, y, weights)`` and the ``sample()`` + hyperparameters into the ``capture`` dict, and returns ``results`` (a + fake ``SamplingResults``) from ``sample()``. + """ + results = results if results is not None else _fake_sampling_results() + + def _ctor(fitter, *, x, y, weights, **kwargs): + capture['fitter'] = fitter + capture['x'] = x + capture['y'] = y + capture['weights'] = weights + capture.update(kwargs) # e.g. sampler_kwargs if passed to the ctor + instance = MagicMock() + + def _sample(**sample_kwargs): + capture.update(sample_kwargs) + return results + + instance.sample = MagicMock(side_effect=_sample) + capture['instance'] = instance + return instance + + return patch('easyreflectometry.fitting.Sampler', side_effect=_ctor) + + class TestMCMCSampleRequiresBumpsEngine: """mcmc_sample() must raise when the core engine is not a BUMPS instance.""" @@ -824,138 +863,129 @@ def test_raises_runtime_error_when_not_bumps(self): with pytest.raises(RuntimeError, match='Bayesian sampling requires a BUMPS minimizer'): fitter.mcmc_sample(data) - def test_wrapper_check_runs_before_core_mcmc_sample(self): - """The wrapper-level guard must fire before delegating to the core sampler. + def test_wrapper_check_runs_before_sampler(self): + """The wrapper-level guard must fire before constructing the core ``Sampler``. - Replace the core ``mcmc_sample`` with a sentinel that would record any call; - the guard should raise without invoking it. + Patch ``Sampler`` with a sentinel that would record any instantiation; + the guard should raise without ever building it. """ model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) # default minimizer is LMFit, not BUMPS - core_called = {'count': 0} - - def _should_not_be_called(**_kwargs): - core_called['count'] += 1 - return {'draws': np.empty((0, 0)), 'param_names': [], 'state': None, 'logp': None} - - fitter.easy_science_multi_fitter.mcmc_sample = _should_not_be_called + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - with pytest.raises(RuntimeError, match='Bayesian sampling requires a BUMPS minimizer'): - fitter.mcmc_sample(data) - assert core_called['count'] == 0 + with _patch_sampler(capture) as sampler_cls: + with pytest.raises(RuntimeError, match='Bayesian sampling requires a BUMPS minimizer'): + fitter.mcmc_sample(data) + sampler_cls.assert_not_called() class TestMCMCSampleBasic: """Basic mcmc_sample() dispatch and return-value forwarding.""" - def test_returns_core_result_dict(self): - """mcmc_sample() returns whatever the core MultiFitter.mcmc_sample() returns.""" + def test_returns_result_dict_from_sampler(self): + """mcmc_sample() returns a dict built from the core Sampler's SamplingResults.""" model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) - # Mock the core MultiFitter.mcmc_sample to return a known dict - fake_result = {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(return_value=fake_result) + + draws = np.ones((10, 2)) + sentinel_state = object() + logp = np.zeros(10) + results = _fake_sampling_results(draws=draws, param_names=['a', 'b'], state=sentinel_state, logp=logp) + + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - result = fitter.mcmc_sample(data, samples=100, burn=20, thin=2, population=5) - assert result is fake_result + with _patch_sampler(capture, results=results): + result = fitter.mcmc_sample(data, samples=100, burn=20, thin=2, population=5) + + # The fitter passed to Sampler is the core MultiFitter + assert capture['fitter'] is fitter.easy_science_multi_fitter + assert result['draws'] is draws + assert result['param_names'] == ['a', 'b'] + assert result['state'] is sentinel_state + assert result['logp'] is logp - def test_forwards_hyperparams_to_core(self): - """Samples, burn, thin, population, chains are forwarded to core.""" + def test_forwards_hyperparams_to_sampler(self): + """Samples, burn, thin, population are forwarded to Sampler.sample().""" model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) - captured = {} - - def _fake_mcmc_sample(*, x, y, weights, samples, burn, thin, population, **kwargs): - captured['samples'] = samples - captured['burn'] = burn - captured['thin'] = thin - captured['population'] = population - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) + + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - fitter.mcmc_sample(data, samples=500, burn=100, thin=5, population=8) - assert captured['samples'] == 500 - assert captured['burn'] == 100 - assert captured['thin'] == 5 - assert captured['population'] == 8 + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=500, burn=100, thin=5, population=8) + assert capture['samples'] == 500 + assert capture['burn'] == 100 + assert capture['thin'] == 5 + assert capture['population'] == 8 - def test_forwards_population_to_core(self): - """'population' argument is forwarded to core.""" + def test_forwards_population_to_sampler(self): + """'population' argument is forwarded to Sampler.sample().""" model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) - captured = {} - - def _fake_mcmc_sample(*, x, y, weights, population, **kwargs): - captured['population'] = population - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) + + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - fitter.mcmc_sample(data, samples=100, burn=20, thin=2, population=6) - assert captured['population'] == 6 + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=100, burn=20, thin=2, population=6) + assert capture['population'] == 6 class TestMCMCSampleInitializer: """initializer parameter is forwarded via sampler_kwargs.""" def test_initializer_passed_as_sampler_kwargs_init(self): - """initializer='lhs' should be passed as sampler_kwargs={'init': 'lhs'} to core.""" + """initializer='lhs' should be passed as sampler_kwargs={'init': 'lhs'} to Sampler.sample().""" model = Model() model.interface = CalculatorFactory() fitter = MultiFitter(model) - captured = {} - - def _fake_mcmc_sample(*, sampler_kwargs, **kwargs): - captured['sampler_kwargs'] = sampler_kwargs - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) + + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - fitter.mcmc_sample(data, samples=100, burn=20, thin=2, initializer='lhs') - assert captured['sampler_kwargs'] == {'init': 'lhs'} + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=100, burn=20, thin=2, initializer='lhs') + assert capture['sampler_kwargs'] == {'init': 'lhs'} def test_initializer_none_omits_sampler_kwargs(self): """When initializer is None, sampler_kwargs should be None, not an empty dict.""" @@ -963,23 +993,19 @@ def test_initializer_none_omits_sampler_kwargs(self): model.interface = CalculatorFactory() fitter = MultiFitter(model) - captured = {} - - def _fake_mcmc_sample(*, sampler_kwargs, **kwargs): - captured['sampler_kwargs'] = sampler_kwargs - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} - fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) + + capture = {} data = sc.DataGroup({ 'coords': {'Qz_0': sc.array(dims=['Qz_0'], values=np.linspace(0.01, 0.3, 10))}, 'data': {'R_0': sc.array(dims=['Qz_0'], values=np.ones(10), variances=np.ones(10) * 0.01)}, }) - fitter.mcmc_sample(data, samples=100, burn=20, thin=2) - assert captured['sampler_kwargs'] is None + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=100, burn=20, thin=2) + assert capture['sampler_kwargs'] is None class TestMCMCSampleZeroVariance: @@ -994,17 +1020,10 @@ def test_hybrid_transforms_zero_variance_points(self): # Use legacy_mask so zero-variance points are dropped fitter = MultiFitter(model, objective='legacy_mask') - captured = {} - - def _fake_mcmc_sample(*, x, y, weights, **kwargs): - captured['x'] = x - captured['y'] = y - captured['weights'] = weights - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + capture = {} fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) qz = np.linspace(0.01, 0.3, 10) r = np.exp(-qz * 50) @@ -1018,12 +1037,13 @@ def _fake_mcmc_sample(*, x, y, weights, **kwargs): with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - fitter.mcmc_sample(data, samples=100, burn=20, thin=2) + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=100, burn=20, thin=2) # legacy_mask should drop the 2 zero-variance points - assert len(captured['x'][0]) == 8 - assert len(captured['y'][0]) == 8 - assert len(captured['weights'][0]) == 8 + assert len(capture['x'][0]) == 8 + assert len(capture['y'][0]) == 8 + assert len(capture['weights'][0]) == 8 mask_warnings = [str(ww.message) for ww in w if 'Masked' in str(ww.message)] assert len(mask_warnings) == 1 @@ -1037,16 +1057,10 @@ def test_per_call_objective_override(self): model.interface = CalculatorFactory() fitter = MultiFitter(model, objective='legacy_mask') # default - captured = {} - - def _fake_mcmc_sample(*, x, y, weights, **kwargs): - captured['x'] = x - captured['y'] = y - return {'draws': np.ones((10, 2)), 'param_names': ['a', 'b'], 'state': None, 'logp': None} + capture = {} fitter.easy_science_multi_fitter = MagicMock() fitter.easy_science_multi_fitter.minimizer.package = 'bumps' - fitter.easy_science_multi_fitter.mcmc_sample = MagicMock(side_effect=_fake_mcmc_sample) qz = np.linspace(0.01, 0.3, 10) r = np.exp(-qz * 50) @@ -1061,6 +1075,7 @@ def _fake_mcmc_sample(*, x, y, weights, **kwargs): # Override to hybrid — should keep all 10 points with warnings.catch_warnings(record=True): warnings.simplefilter('always') - fitter.mcmc_sample(data, samples=100, burn=20, thin=2, objective='hybrid') + with _patch_sampler(capture): + fitter.mcmc_sample(data, samples=100, burn=20, thin=2, objective='hybrid') - assert len(captured['x'][0]) == 10 # all points kept (Mighell-substituted) + assert len(capture['x'][0]) == 10 # all points kept (Mighell-substituted)