Skip to content

Commit 023ff19

Browse files
junpenglaoclaude
andauthored
Make fastprogress an optional dependency (#853)
* Make fastprogress an optional dependency fastprogress 1.1.0+ introduced an unconditional top-level IPython import and a heavy python-fasthtml transitive dependency (starlette, uvicorn, etc.), breaking imports in environments without IPython installed. Fixes #812, addresses PR #852: - Remove fastprogress and ipython from hard dependencies - Add fastprogress>=1.0.0,<1.1 as optional extra: pip install "blackjax[progress]" - Lazy-import fastprogress inside _update_bar() so it is only required when progress_bar=True is actually used Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * ci: install blackjax[progress] in test env Tests that use progress_bar=True need fastprogress, which is now an optional dependency. Install it via the [progress] extra. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ffca7c8 commit 023ff19

3 files changed

Lines changed: 11 additions & 4 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
- name: Set up test environment
3535
run: |
3636
python -m pip install --upgrade pip
37-
pip install .
37+
pip install ".[progress]"
3838
less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {}
3939
- name: Run tests
4040
run: |

blackjax/progress_bar.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""
1717
from threading import Lock
1818

19-
from fastprogress.fastprogress import progress_bar
2019
from jax import lax
2120
from jax.experimental import io_callback
2221
from jax.numpy import array
@@ -42,6 +41,13 @@ def _calc_chain_idx(iter_num):
4241
return idx
4342

4443
def _update_bar(arg, chain_id):
44+
try:
45+
from fastprogress.fastprogress import progress_bar
46+
except ImportError as e:
47+
raise ImportError(
48+
"fastprogress is required to use progress bars. "
49+
"Install it with: pip install fastprogress"
50+
) from e
4551
chain_id = int(chain_id)
4652
if arg == 0:
4753
chain_id = _calc_chain_idx(arg)

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ classifiers = [
3131
"Topic :: Scientific/Engineering :: Mathematics",
3232
]
3333
dependencies = [
34-
"fastprogress>=1.0.0",
35-
"ipython",
3634
"jax>=0.9.0",
3735
"jaxlib>=0.9.0",
3836
"numpy>=1.25",
@@ -42,6 +40,9 @@ dependencies = [
4240
]
4341
dynamic = ["version"]
4442

43+
[project.optional-dependencies]
44+
progress = ["fastprogress>=1.0.0,<1.1"]
45+
4546
[project.urls]
4647
homepage = "https://github.com/blackjax-devs/blackjax"
4748
documentation = "https://blackjax-devs.github.io/blackjax/"

0 commit comments

Comments
 (0)