157 changes: 132 additions & 25 deletions docs/docs/tutorials/fitting-bayesian.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib widget"
"%matplotlib inline"
]
},
{
@@ -28,7 +28,7 @@
"\n",
"where $\\theta$ are the model parameters, $d$ is the observed data, $p(d \\mid \\theta)$ is the likelihood, and $p(\\theta)$ is the prior. In `easyscience`, the `min`/`max` bounds of a `Parameter` are interpreted as a **uniform prior**, and a Gaussian likelihood is constructed from the data and supplied weights.\n",
"\n",
"`easyscience` exposes a Bayesian Markov-chain Monte Carlo (MCMC) sampler through the `Fitter.mcmc_sample` method. Under the hood this uses BUMPS' DREAM sampler, so the underlying minimizer must be switched to BUMPS.\n",
"`easyscience` exposes a Bayesian Markov-chain Monte Carlo (MCMC) sampler through the `Sampler` class. Under the hood this uses BUMPS' DREAM sampler, so the underlying minimizer must be switched to BUMPS.\n",
"\n",
"```{note}\n",
"This tutorial focuses on Bayesian analysis with a simple QENS model for illustration. For dedicated QENS fitting with more sophisticated models, consider using [`EasyDynamics`](https://github.com/easyscience/easydynamics).\n",
@@ -45,7 +45,7 @@
"\n",
"The trade-off is computational cost: MCMC requires thousands of model evaluations, whereas an MLE fit may converge in dozens. For the simple 4-parameter model used here the difference is negligible, but for expensive models it is worth starting with MLE and only switching to MCMC when you need the richer output.\n",
"\n",
"In this tutorial we re-use the QENS dataset and the Lorentzian-with-resolution model from the [Fitting QENS](fitting-qens.ipynb) tutorial, but instead of returning a single best-fit value with a symmetric error bar we will draw thousands of samples from the posterior.\n"
"In this tutorial we re-use the QENS dataset and the Lorentzian-with-resolution model from the [Fitting QENS](fitting-qens.ipynb) tutorial, but instead of returning a single best-fit value with a symmetric error bar we will draw thousands of samples from the posterior."
]
},
{
@@ -250,20 +250,20 @@
"\n",
"We now draw samples from the posterior distribution $p(\\theta \\mid d)$ using the BUMPS DREAM (DiffeRential Evolution Adaptive Metropolis) algorithm. DREAM is an ensemble MCMC method that runs multiple chains in parallel and automatically tunes the proposal distribution.\n",
"\n",
"DREAM only works with the BUMPS minimizer. We reuse the ``mle_fitter`` created above, switch it to BUMPS, and call `Fitter.mcmc_sample`, which returns a dictionary with the following keys:\n",
"DREAM only works with the BUMPS minimizer. We reuse the ``mle_fitter`` created above, switch it to BUMPS, and create a `Sampler` instance bound to the fitter and data. Calling `sampler.sample()` returns a `SamplingResults` object with the following attributes:\n",
"\n",
"- `draws`: a `(n_samples, n_parameters)` array of posterior samples: each **row** is one complete draw from the joint posterior (one value for every parameter simultaneously), and each **column** holds all sampled values for a single parameter;\n",
"- `param_names`: the unique names of the parameters, in the same column order as `draws`;\n",
"- `internal_bumps_object`: the underlying BUMPS `MCMCDraw` object, useful for advanced diagnostics;\n",
"- `state`: the underlying BUMPS `MCMCDraw` object, useful for advanced diagnostics;\n",
"- `logp`: the log-posterior of each retained sample.\n",
"\n",
"The key sampling parameters are:\n",
"\n",
"- `samples` (4000): the total number of posterior draws to generate across all chains;\n",
"- `samples` (10000): the total number of posterior draws to generate across all chains;\n",
"- `burn` (500): the number of initial *burn-in* iterations to discard — the sampler needs time to find the typical set of the posterior, and early samples are not representative;\n",
"- `thin` (2): the *thinning* interval — only every second sample is kept, which reduces autocorrelation between consecutive draws;\n",
"\n",
"First, we switch to the BUMPS minimizer:\n"
"First, we switch to the BUMPS minimizer:"
]
},
{
@@ -285,17 +285,13 @@
"metadata": {},
"outputs": [],
"source": [
"result = mle_fitter.mcmc_sample(\n",
"    x=omega,\n",
"    y=intensity_obs,\n",
"    weights=1 / intensity_error,\n",
"    samples=10000,\n",
"    burn=500,\n",
"    thin=2,\n",
")\n",
"from easyscience.fitting import Sampler\n",
"\n",
"sampler = Sampler(mle_fitter, omega, intensity_obs, weights=1 / intensity_error)\n",
"results = sampler.sample(samples=10000, burn=500, thin=2)\n",
"\n",
"print(f'Drew {result[\"draws\"].shape[0]} samples for {result[\"draws\"].shape[1]} parameters.')\n",
"print('parameters:', result['param_names'])"
"print(f'Drew {results.draws.shape[0]} samples for {results.draws.shape[1]} parameters.')\n",
"print('parameters:', results.param_names)"
]
},
{
@@ -320,12 +316,9 @@
"metadata": {},
"outputs": [],
"source": [
"draws = result['draws']\n",
"logp = result['logp']\n",
"if callable(logp):\n",
"    _, logp = logp()\n",
"    logp = logp.flatten()\n",
"name_to_col = {name: idx for idx, name in enumerate(result['param_names'])}\n",
"draws = results.draws\n",
"logp = results.logp\n",
"name_to_col = {name: idx for idx, name in enumerate(results.param_names)}\n",
"\n",
"\n",
"def column_for(parameter):\n",
@@ -367,7 +360,7 @@
"The MLE finds the single point that maximises the likelihood, while the Bayesian median is the central value of the *posterior* — which also accounts for the prior $p(\\theta)$. If a parameter's posterior is skewed (asymmetric), the median and the mode (which approximates the MLE) will not coincide. The table below shows both so you can spot any such differences.\n",
"\n",
"\n",
"Note that the columns of `result['draws']` are ordered by `result['param_names']` (which use the parameters' `unique_name`), so it is worth building a small helper to look up a column by friendly name.\n"
"Note that the columns of `results.draws` are ordered by `results.param_names` (which use the parameters' `unique_name`), so it is worth building a small helper to look up a column by friendly name."
]
},
{
@@ -492,11 +485,125 @@
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "03339658",
"metadata": {},
"source": [
"## Extend the chain and check convergence\n",
"\n",
"The original run used ``samples=10000`` (5000 retained after thinning). Here we **extend the chain by 5000 more raw samples** using ``sampler.extend()`` — DREAM continues from the previous in-memory state instead of starting cold.\n",
"\n",
"``sampler.extend(additional_samples=5000)`` does the ring-buffer arithmetic for you automatically: it reads the number of stored generations from the saved state and sizes the new buffer to ``old_generations + additional_samples`` so no existing draws are lost.\n",
"\n",
"After the extension we compare the posterior summaries and check convergence with Gelman-Rubin R-hat."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "293b140b",
"metadata": {},
"outputs": [],
"source": [
"# Extend the existing chain by 5000 more raw samples.\n",
"# extend() handles the ring-buffer arithmetic automatically.\n",
"extended_results = sampler.extend(additional_samples=5000, thin=2)\n",
"\n",
"short_draws = results.draws\n",
"extended_draws = extended_results.draws\n",
"\n",
"print(f'Short chain: {short_draws.shape[0]} retained draws')\n",
"print(f'Extended chain: {extended_draws.shape[0]} retained draws')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec4302c",
"metadata": {},
"outputs": [],
"source": [
"# Build helpers to look up parameter columns.\n",
"name_to_col_ext = {name: idx for idx, name in enumerate(extended_results.param_names)}\n",
"name_to_col_orig = {name: idx for idx, name in enumerate(results.param_names)}\n",
"\n",
"\n",
"def col_orig(par):\n",
"    return results.draws[:, name_to_col_orig[par.unique_name]]\n",
"\n",
"\n",
"def col_ext(par):\n",
"    return extended_results.draws[:, name_to_col_ext[par.unique_name]]\n",
"\n",
"\n",
"# Side-by-side posterior summary: compare the first 5000 draws (before extension)\n",
"# with the FULL extended set.\n",
"print(f'{\"param\":<10s} {\"metric\":>12s} {\"first 5k\":>12s} {\"full ext\":>12s} diff')\n",
"print('-' * 65)\n",
"for label, par in (('area', area), ('gamma', gamma), ('omega_0', omega_0), ('sigma', sigma)):\n",
"    c_first = col_orig(par)  # first 5000 draws\n",
"    c_full = col_ext(par)  # full extended ~7500 draws\n",
"    for metric, fn in [\n",
"        ('mean', np.mean),\n",
"        ('std', np.std),\n",
"        ('q2.5%', lambda c: np.percentile(c, 2.5)),\n",
"        ('q50%', lambda c: np.percentile(c, 50)),\n",
"        ('q97.5%', lambda c: np.percentile(c, 97.5)),\n",
"    ]:\n",
"        vf = fn(c_first)\n",
"        vx = fn(c_full)\n",
"        diff = vx - vf\n",
"        print(f'{label:<10s} {metric:>12s} {vf:12.4g} {vx:12.4g} {diff:+.2e}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0f30be6",
"metadata": {},
"outputs": [],
"source": [
"# Visual comparison: overlay the first 5000 draws with the extended set.\n",
"fig, axes = plt.subplots(1, 4, figsize=(14, 3))\n",
"for ax, label, par in zip(\n",
"    axes, ('area', 'gamma', 'omega_0', 'sigma'), (area, gamma, omega_0, sigma)\n",
"):\n",
"    c_first = col_orig(par)\n",
"    c_full = col_ext(par)\n",
"    ax.hist(c_first, bins=40, density=True, alpha=0.5, color='C0', label='first 5k')\n",
"    ax.hist(\n",
"        c_full, bins=40, density=True, alpha=0.5, color='C3', label=f'extended ({len(c_full)})'\n",
"    )\n",
"    ax.set_title(label)\n",
"    ax.set_yticks([])\n",
"axes[0].legend(fontsize=9)\n",
"fig.suptitle('Marginal posterior: first 5000 draws vs extended chain')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3449e0a7",
"metadata": {},
"outputs": [],
"source": [
"# Convergence diagnostic: Gelman-Rubin R-hat from the extended chain state.\n",
"# Values close to 1.0 indicate good convergence.\n",
"print('Gelman-Rubin R-hat (extended chain) — values < 1.05 indicate convergence:')\n",
"rhat = extended_results.state.gelman()\n",
"for name, val in zip(extended_results.param_names, rhat):\n",
"    status = '✓' if val < 1.05 else '?' if val < 1.1 else '✗'\n",
"    print(f'    {name:<20s}  {val:.4f}  {status}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "p312",
"display_name": "era",
"language": "python",
"name": "python3"
},
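Note on the R-hat cell added in this diff: the notebook reads the diagnostic from `extended_results.state.gelman()` and treats values below 1.05 as converged. For readers who want to see what that statistic measures, here is a minimal NumPy sketch of the classic (non-split) Gelman-Rubin R-hat for a single parameter. The function name `gelman_rubin_rhat` and the `(n_chains, n_draws)` array layout are assumptions made for illustration; they are not part of `easyscience` or BUMPS, and the BUMPS implementation may differ in detail.

```python
import numpy as np


def gelman_rubin_rhat(chains):
    """Classic Gelman-Rubin R-hat for one parameter (illustrative helper).

    `chains` has shape (n_chains, n_draws): one row per chain.
    """
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape

    # W: average of the per-chain sample variances.
    within = chains.var(axis=1, ddof=1).mean()

    # B: n_draws times the sample variance of the per-chain means.
    between = n_draws * chains.mean(axis=1).var(ddof=1)

    # Pooled estimate of the posterior variance.
    pooled = (n_draws - 1) / n_draws * within + between / n_draws

    return float(np.sqrt(pooled / within))


# Synthetic chains, not output from the notebook.
rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 2000))

# Shift two of the four chains so they sit in a different region.
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])

print(f"well-mixed chains: R-hat = {gelman_rubin_rhat(mixed):.3f}")
print(f"offset chains:     R-hat = {gelman_rubin_rhat(stuck):.3f}")
```

Chains that sample the same distribution give a value close to 1, while chains stuck in different regions inflate the between-chain term and push the value well above the 1.05 and 1.1 thresholds used in the notebook cell.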