Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions skore-hub-project/src/skore_hub_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,18 @@
from logging import basicConfig, getLogger

from matplotlib import pyplot as plt
from rich.console import Console
from rich.theme import Theme

__all__ = [
"Payload",
"b64_str_to_bytes",
"bytes_to_b64_str",
"console",
"switch_plt_backend",
]


basicConfig()
logger = getLogger(__name__)

console = Console(
width=88,
theme=Theme(
{
"repr.str": "cyan",
"rule.line": "orange1",
"repr.url": "orange1",
}
),
)


def b64_str_to_bytes(literal: str) -> bytes:
"""Decode the Base64 str object ``literal`` in a bytes."""
Expand Down
56 changes: 27 additions & 29 deletions skore-hub-project/src/skore_hub_project/artifact/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, as_completed
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING

from joblib import Parallel, delayed

from ..client.client import Client, HUBClient
from .serializer import Serializer

Expand Down Expand Up @@ -99,7 +100,6 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
Serializer(content) as serializer,
HUBClient() as hub_client,
Client() as standard_client,
ThreadPoolExecutor() as pool,
):
# Ask for upload urls.
response = hub_client.post(
Expand All @@ -116,7 +116,8 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
# An empty response means that an artifact with the same checksum already
# exists. The content doesn't have to be re-uploaded.
if urls := response.json():
task_to_chunk_id = {}
chunk_ids = []
tasks = []

# Upload each chunk of the serialized content to the artifacts storage,
# using a disk temporary file.
Expand All @@ -126,38 +127,35 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
#
# Use `threading` over `asyncio` to ensure compatibility with Jupyter
# notebooks, where the event loop is already running.

for url in urls:
chunk_id = url["chunk_id"] or 1
task = pool.submit(
upload_chunk,
filepath=serializer.filepath,
client=standard_client,
url=url["upload_url"],
offset=((chunk_id - 1) * CHUNK_SIZE),
length=CHUNK_SIZE,
content_type=(
content_type if len(urls) == 1 else "application/octet-stream"
),
)

task_to_chunk_id[task] = chunk_id

try:
etags = dict(
sorted(
(
task_to_chunk_id[task],
task.result(),
)
for task in as_completed(task_to_chunk_id)
chunk_ids.append(chunk_id)
tasks.append(
delayed(upload_chunk)(
filepath=serializer.filepath,
client=standard_client,
url=url["upload_url"],
offset=((chunk_id - 1) * CHUNK_SIZE),
length=CHUNK_SIZE,
content_type=(
content_type
if len(urls) == 1
else "application/octet-stream"
),
)
)
except BaseException:
# Cancel all remaining tasks, especially on `KeyboardInterrupt`.
for task in task_to_chunk_id:
task.cancel()

raise
etags = dict(
sorted(
zip(
chunk_ids,
Parallel(backend="threading")(tasks),
strict=True,
)
)
)

# Acknowledge the upload, to let the hub/storage rebuild the whole.
hub_client.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from rich.align import Align
from rich.live import Live
from rich.panel import Panel
from skore import THREADABLE, console

from skore_hub_project import console
from skore_hub_project.authentication.apikey import APIKey
from skore_hub_project.authentication.token import Token
from skore_hub_project.authentication.uri import URI
Expand Down Expand Up @@ -44,7 +44,7 @@ def login(*, timeout: int = 600) -> None:
try:
credentials = APIKey()
except KeyError:
with Live(console=console, auto_refresh=False) as live:
with Live(console=console, auto_refresh=THREADABLE) as live:
credentials = Token(timeout=timeout, live=live)

live.update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from rich.align import Align
from rich.live import Live
from rich.panel import Panel
from skore import console

from skore_hub_project import console
from skore_hub_project.authentication.uri import URI


Expand Down
5 changes: 2 additions & 3 deletions skore-hub-project/src/skore_hub_project/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(
if JUPYTERLITE and (transport is None):
from httpx import Request
from js import XMLHttpRequest
from pyodide.ffi import to_js
from pyodide.http.pyxhr import XHRResponse

class JupyterliteTransport(BaseTransport):
Expand All @@ -188,11 +189,9 @@ def handle_request(self, request: Request) -> Response:
req.withCredentials = True

for name, value in request.headers.items():
if name.lower() == "host":
continue
req.setRequestHeader(name, value)

req.send(request.content)
req.send(to_js(request.read()))

xhr = XHRResponse(req)

Expand Down
4 changes: 3 additions & 1 deletion skore-hub-project/src/skore_hub_project/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from joblib import load as joblib_load
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from sklearn.utils.validation import _check_pos_label_consistency
from skore import THREADABLE, console

from skore_hub_project import console, switch_plt_backend
from skore_hub_project import switch_plt_backend
from skore_hub_project.client.client import Client, HUBClient
from skore_hub_project.exception import ForbiddenException, NotFoundException
from skore_hub_project.json import dumps
Expand Down Expand Up @@ -294,6 +295,7 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None
TextColumn("{task.description}"),
TimeElapsedColumn(),
console=console,
auto_refresh=THREADABLE,
) as progress,
):
task = progress.add_task(
Expand Down
19 changes: 8 additions & 11 deletions skore-hub-project/src/skore_hub_project/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import cached_property, partial
from operator import methodcaller
from typing import ClassVar, Generic, TypeVar, cast

from joblib import Parallel, delayed
from pydantic import BaseModel, ConfigDict, Field, computed_field
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn
from skore import THREADABLE, console

from skore_hub_project import console
from skore_hub_project.artifact.media.media import Media
from skore_hub_project.artifact.pickle import Pickle
from skore_hub_project.metric.metric import Metric
Expand All @@ -29,6 +30,7 @@
TimeElapsedColumn(),
console=console,
transient=True,
auto_refresh=THREADABLE,
)

Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport))
Expand Down Expand Up @@ -109,22 +111,17 @@ def metrics(self) -> list[Metric[Report]]:
self.report.cache_predictions()

metrics = [metric_cls(report=self.report) for metric_cls in self.METRICS]
tasks = list(map(delayed(methodcaller("compute")), metrics))

with SkinnedProgress() as progress, ThreadPoolExecutor() as pool:
tasks = [
pool.submit(lambda metric: metric.compute(), metric)
for metric in metrics
]

for task in progress.track(
as_completed(tasks),
with SkinnedProgress() as progress:
for _ in progress.track(
Parallel(backend="threading")(tasks),
description=(
f"Computing {self.report.__class__.__name__} "
f"#{self.report._hash} metrics"
),
total=len(tasks),
):
task.result()
progress.refresh()

return [metric for metric in metrics if metric.value is not None]
Expand Down
36 changes: 23 additions & 13 deletions skore/src/skore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from importlib.metadata import version
from logging import INFO, NullHandler, getLogger
from threading import Thread
from warnings import warn

from joblib import __version__ as joblib_version
Expand Down Expand Up @@ -37,6 +38,7 @@
from skore._sklearn._plot.inspection.permutation_importance import (
PermutationImportanceDisplay,
)
from skore._utils._environment import is_environment_notebook_like
from skore._utils._patch import setup_jupyter_display
from skore._utils._show_versions import show_versions

Expand All @@ -54,28 +56,31 @@
)


__version__ = version("skore")
__all__ = [
"Check",
"CoefficientsDisplay",
"DiagnosticDisplay",
"CheckNotApplicable",
"CoefficientsDisplay",
"ComparisonReport",
"compare",
"ConfusionMatrixDisplay",
"CrossValidationReport",
"DiagnosticDisplay",
"Display",
"EstimatorReport",
"evaluate",
"ImpurityDecreaseDisplay",
"TrainTestSplit",
"MetricsSummaryDisplay",
"PermutationImportanceDisplay",
"PrecisionRecallCurveDisplay",
"PredictionErrorDisplay",
"Project",
"RocCurveDisplay",
"THREADABLE",
"TableReportDisplay",
"TrainTestSplit",
"compare",
"configuration",
"console",
"evaluate",
"login",
"show_versions",
"train_test_split",
Expand All @@ -87,14 +92,19 @@
logger.setLevel(INFO)


skore_console_theme = Theme(
{
"repr.str": "cyan",
"rule.line": "orange1",
"repr.url": "orange1",
}
console = Console(
width=88,
theme=Theme({"repr.str": "cyan", "rule.line": "orange1", "repr.url": "orange1"}),
# ...
force_jupyter=(is_environment_notebook_like() or None),
)


console = Console(theme=skore_console_theme, width=88)
__version__ = version("skore")
try:
thread = Thread()
thread.start()
thread.join()
except Exception:
THREADABLE = False
else:
THREADABLE = True
15 changes: 10 additions & 5 deletions skore/src/skore/_utils/_environment.py
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should be superseded by Textualize/rich#4104.

Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@ def get_environment_info() -> dict[str, Any]:
try:
# get_ipython() is defined when running in Jupyter or IPython
# there is no need to import IPython here
shell = get_ipython().__class__.__name__ # type: ignore
ipython = get_ipython() # type: ignore[name-defined]
except NameError:
pass
else:
shell = ipython.__class__.__name__

env_info["details"]["ipython_shell"] = shell

if shell == "ZMQInteractiveShell": # Jupyter notebook/lab
# Jupyter notebook/lab or Jupyterlite
if (shell == "ZMQInteractiveShell") or ("pyodide" in str(ipython.__class__)):
env_info["is_jupyter"] = True
env_info["environment_name"] = "jupyter"
elif shell == "TerminalInteractiveShell": # IPython terminal
# IPython terminal
elif shell == "TerminalInteractiveShell":
env_info["environment_name"] = "ipython_terminal"
except NameError:
pass

if "VSCODE_PID" in os.environ:
env_info["is_vscode"] = True
Expand Down
14 changes: 6 additions & 8 deletions skore/src/skore/_utils/_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,7 @@
from operator import length_hint
from typing import Any, TypeVar

from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
)

from skore._config import configuration
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn

T = TypeVar("T")

Expand All @@ -18,6 +11,9 @@ class ProgressBar:
"""Simplified progress bar based on ``rich.Progress``."""

def __init__(self, description: str, total: float | None):
from skore import THREADABLE, console
from skore._config import configuration

progress = Progress(
SpinnerColumn(),
TextColumn("[bold cyan]{task.description}"),
Expand All @@ -27,8 +23,10 @@ def __init__(self, description: str, total: float | None):
pulse_style="orange1",
),
TextColumn("[orange1]{task.percentage:>3.0f}%"),
console=console,
expand=False,
transient=True,
auto_refresh=THREADABLE,
disable=(not configuration.show_progress),
)

Expand Down
Loading