diff --git a/skore-hub-project/src/skore_hub_project/__init__.py b/skore-hub-project/src/skore_hub_project/__init__.py index c54509ef08..cec54c4f5b 100644 --- a/skore-hub-project/src/skore_hub_project/__init__.py +++ b/skore-hub-project/src/skore_hub_project/__init__.py @@ -6,14 +6,11 @@ from logging import basicConfig, getLogger from matplotlib import pyplot as plt -from rich.console import Console -from rich.theme import Theme __all__ = [ "Payload", "b64_str_to_bytes", "bytes_to_b64_str", - "console", "switch_plt_backend", ] @@ -21,17 +18,6 @@ basicConfig() logger = getLogger(__name__) -console = Console( - width=88, - theme=Theme( - { - "repr.str": "cyan", - "rule.line": "orange1", - "repr.url": "orange1", - } - ), -) - def b64_str_to_bytes(literal: str) -> bytes: """Decode the Base64 str object ``literal`` in a bytes.""" diff --git a/skore-hub-project/src/skore_hub_project/artifact/upload.py b/skore-hub-project/src/skore_hub_project/artifact/upload.py index 554a3e0b3f..8ea880c9fc 100644 --- a/skore-hub-project/src/skore_hub_project/artifact/upload.py +++ b/skore-hub-project/src/skore_hub_project/artifact/upload.py @@ -2,11 +2,12 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor, as_completed from math import ceil from pathlib import Path from typing import TYPE_CHECKING +from joblib import Parallel, delayed + from ..client.client import Client, HUBClient from .serializer import Serializer @@ -99,7 +100,6 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str: Serializer(content) as serializer, HUBClient() as hub_client, Client() as standard_client, - ThreadPoolExecutor() as pool, ): # Ask for upload urls. response = hub_client.post( @@ -116,7 +116,8 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str: # An empty response means that an artifact with the same checksum already # exists. The content doesn't have to be re-uploaded. if urls := response.json(): - task_to_chunk_id = {} + chunk_ids = [] + tasks = [] # Upload each chunk of the serialized content to the artifacts storage, # using a disk temporary file. @@ -126,38 +127,35 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str: # # Use `threading` over `asyncio` to ensure compatibility with Jupyter # notebooks, where the event loop is already running. + for url in urls: chunk_id = url["chunk_id"] or 1 - task = pool.submit( - upload_chunk, - filepath=serializer.filepath, - client=standard_client, - url=url["upload_url"], - offset=((chunk_id - 1) * CHUNK_SIZE), - length=CHUNK_SIZE, - content_type=( - content_type if len(urls) == 1 else "application/octet-stream" - ), - ) - - task_to_chunk_id[task] = chunk_id - try: - etags = dict( - sorted( - ( - task_to_chunk_id[task], - task.result(), - ) - for task in as_completed(task_to_chunk_id) + chunk_ids.append(chunk_id) + tasks.append( + delayed(upload_chunk)( + filepath=serializer.filepath, + client=standard_client, + url=url["upload_url"], + offset=((chunk_id - 1) * CHUNK_SIZE), + length=CHUNK_SIZE, + content_type=( + content_type + if len(urls) == 1 + else "application/octet-stream" + ), ) ) - except BaseException: - # Cancel all remaining tasks, especially on `KeyboardInterrupt`. - for task in task_to_chunk_id: - task.cancel() - raise + etags = dict( + sorted( + zip( + chunk_ids, + Parallel(backend="threading")(tasks), + strict=True, + ) + ) + ) # Acknowledge the upload, to let the hub/storage rebuild the whole. hub_client.post( diff --git a/skore-hub-project/src/skore_hub_project/authentication/login.py b/skore-hub-project/src/skore_hub_project/authentication/login.py index a819f3bc57..5242e038db 100644 --- a/skore-hub-project/src/skore_hub_project/authentication/login.py +++ b/skore-hub-project/src/skore_hub_project/authentication/login.py @@ -6,8 +6,8 @@ from rich.align import Align from rich.live import Live from rich.panel import Panel +from skore import THREADABLE, console -from skore_hub_project import console from skore_hub_project.authentication.apikey import APIKey from skore_hub_project.authentication.token import Token from skore_hub_project.authentication.uri import URI @@ -44,7 +44,7 @@ def login(*, timeout: int = 600) -> None: try: credentials = APIKey() except KeyError: - with Live(console=console, auto_refresh=False) as live: + with Live(console=console, auto_refresh=THREADABLE) as live: credentials = Token(timeout=timeout, live=live) live.update( diff --git a/skore-hub-project/src/skore_hub_project/authentication/token.py b/skore-hub-project/src/skore_hub_project/authentication/token.py index 3d63578e1d..04e89e74fc 100644 --- a/skore-hub-project/src/skore_hub_project/authentication/token.py +++ b/skore-hub-project/src/skore_hub_project/authentication/token.py @@ -12,8 +12,8 @@ from rich.align import Align from rich.live import Live from rich.panel import Panel +from skore import console -from skore_hub_project import console from skore_hub_project.authentication.uri import URI diff --git a/skore-hub-project/src/skore_hub_project/client/client.py b/skore-hub-project/src/skore_hub_project/client/client.py index 019a1a0428..f92b3ead6d 100644 --- a/skore-hub-project/src/skore_hub_project/client/client.py +++ b/skore-hub-project/src/skore_hub_project/client/client.py @@ -179,6 +179,7 @@ def __init__( if JUPYTERLITE and (transport is None): from httpx import Request from js import XMLHttpRequest + from pyodide.ffi import to_js from pyodide.http.pyxhr import XHRResponse class JupyterliteTransport(BaseTransport): @@ -188,11 +189,9 @@ def handle_request(self, request: Request) -> Response: req.withCredentials = True for name, value in request.headers.items(): - if name.lower() == "host": - continue req.setRequestHeader(name, value) - req.send(request.content) + req.send(to_js(request.read())) xhr = XHRResponse(req) diff --git a/skore-hub-project/src/skore_hub_project/project/project.py b/skore-hub-project/src/skore_hub_project/project/project.py index 600feeb42a..430585d95a 100644 --- a/skore-hub-project/src/skore_hub_project/project/project.py +++ b/skore-hub-project/src/skore_hub_project/project/project.py @@ -17,8 +17,9 @@ from joblib import load as joblib_load from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn from sklearn.utils.validation import _check_pos_label_consistency +from skore import THREADABLE, console -from skore_hub_project import console, switch_plt_backend +from skore_hub_project import switch_plt_backend from skore_hub_project.client.client import Client, HUBClient from skore_hub_project.exception import ForbiddenException, NotFoundException from skore_hub_project.json import dumps @@ -294,6 +295,7 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None TextColumn("{task.description}"), TimeElapsedColumn(), console=console, + auto_refresh=THREADABLE, ) as progress, ): task = progress.add_task( diff --git a/skore-hub-project/src/skore_hub_project/report/report.py b/skore-hub-project/src/skore_hub_project/report/report.py index 13a6887b02..338cd4fdcd 100644 --- a/skore-hub-project/src/skore_hub_project/report/report.py +++ b/skore-hub-project/src/skore_hub_project/report/report.py @@ -3,14 +3,15 @@ from __future__ import annotations from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed from functools import cached_property, partial +from operator import methodcaller from typing import ClassVar, Generic, TypeVar, cast +from joblib import Parallel, delayed from pydantic import BaseModel, ConfigDict, Field, computed_field from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn +from skore import THREADABLE, console -from skore_hub_project import console from skore_hub_project.artifact.media.media import Media from skore_hub_project.artifact.pickle import Pickle from skore_hub_project.metric.metric import Metric @@ -29,6 +30,7 @@ TimeElapsedColumn(), console=console, transient=True, + auto_refresh=THREADABLE, ) Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport)) @@ -109,22 +111,17 @@ def metrics(self) -> list[Metric[Report]]: self.report.cache_predictions() metrics = [metric_cls(report=self.report) for metric_cls in self.METRICS] + tasks = list(map(delayed(methodcaller("compute")), metrics)) - with SkinnedProgress() as progress, ThreadPoolExecutor() as pool: - tasks = [ - pool.submit(lambda metric: metric.compute(), metric) - for metric in metrics - ] - - for task in progress.track( - as_completed(tasks), + with SkinnedProgress() as progress: + for _ in progress.track( + Parallel(backend="threading")(tasks), description=( f"Computing {self.report.__class__.__name__} " f"#{self.report._hash} metrics" ), total=len(tasks), ): - task.result() progress.refresh() return [metric for metric in metrics if metric.value is not None] diff --git a/skore/src/skore/__init__.py b/skore/src/skore/__init__.py index 14e06e4fc4..61c348255c 100644 --- a/skore/src/skore/__init__.py +++ b/skore/src/skore/__init__.py @@ -2,6 +2,7 @@ from importlib.metadata import version from logging import INFO, NullHandler, getLogger +from threading import Thread from warnings import warn from joblib import __version__ as joblib_version @@ -37,6 +38,7 @@ from skore._sklearn._plot.inspection.permutation_importance import ( PermutationImportanceDisplay, ) +from skore._utils._environment import is_environment_notebook_like from skore._utils._patch import setup_jupyter_display from skore._utils._show_versions import show_versions @@ -54,28 +56,31 @@ ) +__version__ = version("skore") __all__ = [ "Check", - "CoefficientsDisplay", - "DiagnosticDisplay", "CheckNotApplicable", + "CoefficientsDisplay", "ComparisonReport", - "compare", "ConfusionMatrixDisplay", "CrossValidationReport", + "DiagnosticDisplay", "Display", "EstimatorReport", - "evaluate", "ImpurityDecreaseDisplay", - "TrainTestSplit", "MetricsSummaryDisplay", "PermutationImportanceDisplay", "PrecisionRecallCurveDisplay", "PredictionErrorDisplay", "Project", "RocCurveDisplay", + "THREADABLE", "TableReportDisplay", + "TrainTestSplit", + "compare", "configuration", + "console", + "evaluate", "login", "show_versions", "train_test_split", @@ -87,14 +92,19 @@ logger.setLevel(INFO) -skore_console_theme = Theme( - { - "repr.str": "cyan", - "rule.line": "orange1", - "repr.url": "orange1", - } +console = Console( + width=88, + theme=Theme({"repr.str": "cyan", "rule.line": "orange1", "repr.url": "orange1"}), + # ... + force_jupyter=(is_environment_notebook_like() or None), ) -console = Console(theme=skore_console_theme, width=88) -__version__ = version("skore") +try: + thread = Thread() + thread.start() + thread.join() +except Exception: + THREADABLE = False +else: + THREADABLE = True diff --git a/skore/src/skore/_utils/_environment.py b/skore/src/skore/_utils/_environment.py index db6b81ce35..905790b550 100644 --- a/skore/src/skore/_utils/_environment.py +++ b/skore/src/skore/_utils/_environment.py @@ -22,16 +22,21 @@ def get_environment_info() -> dict[str, Any]: try: # get_ipython() is defined when running in Jupyter or IPython # there is no need to import IPython here - shell = get_ipython().__class__.__name__ # type: ignore + ipython = get_ipython() # type: ignore[name-defined] + except NameError: + pass + else: + shell = ipython.__class__.__name__ + env_info["details"]["ipython_shell"] = shell - if shell == "ZMQInteractiveShell": # Jupyter notebook/lab + # Jupyter notebook/lab or Jupyterlite + if (shell == "ZMQInteractiveShell") or ("pyodide" in str(ipython.__class__)): env_info["is_jupyter"] = True env_info["environment_name"] = "jupyter" - elif shell == "TerminalInteractiveShell": # IPython terminal + # IPython terminal + elif shell == "TerminalInteractiveShell": env_info["environment_name"] = "ipython_terminal" - except NameError: - pass if "VSCODE_PID" in os.environ: env_info["is_vscode"] = True diff --git a/skore/src/skore/_utils/_progress_bar.py b/skore/src/skore/_utils/_progress_bar.py index 4dfc6bf264..3c6c4bdff7 100644 --- a/skore/src/skore/_utils/_progress_bar.py +++ b/skore/src/skore/_utils/_progress_bar.py @@ -2,14 +2,7 @@ from operator import length_hint from typing import Any, TypeVar -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, -) - -from skore._config import configuration +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn T = TypeVar("T") @@ -18,6 +11,9 @@ class ProgressBar: """Simplified progress bar based on ``rich.Progress``.""" def __init__(self, description: str, total: float | None): + from skore import THREADABLE, console + from skore._config import configuration + progress = Progress( SpinnerColumn(), TextColumn("[bold cyan]{task.description}"), @@ -27,8 +23,10 @@ def __init__(self, description: str, total: float | None): pulse_style="orange1", ), TextColumn("[orange1]{task.percentage:>3.0f}%"), + console=console, expand=False, transient=True, + auto_refresh=THREADABLE, disable=(not configuration.show_progress), )