Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions skore-hub-project/src/skore_hub_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from logging import basicConfig, getLogger
from threading import Thread

from matplotlib import pyplot as plt
from rich.console import Console
Expand Down Expand Up @@ -64,3 +65,17 @@ def switch_plt_backend(backend: str = "agg") -> Iterator[None]:
finally:
plt.close("all")
plt.switch_backend(original)


def threadable() -> bool:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering that instead to make fail a thread to check the platform with something like

def _can_start_thread() -> bool:
    if sys.platform == "emscripten":
        return sys._emscripten_info.pthreads
    return platform.machine() not in ("wasm32", "wasm64")

try:
thread = Thread()
thread.start()
thread.join()
except Exception:
return False
else:
return True


THREADABLE = threadable()
56 changes: 27 additions & 29 deletions skore-hub-project/src/skore_hub_project/artifact/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, as_completed
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING

from joblib import Parallel, delayed

from ..client.client import Client, HUBClient
from .serializer import Serializer

Expand Down Expand Up @@ -99,7 +100,6 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
Serializer(content) as serializer,
HUBClient() as hub_client,
Client() as standard_client,
ThreadPoolExecutor() as pool,
):
# Ask for upload urls.
response = hub_client.post(
Expand All @@ -116,7 +116,8 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
# An empty response means that an artifact with the same checksum already
# exists. The content doesn't have to be re-uploaded.
if urls := response.json():
task_to_chunk_id = {}
chunk_ids = []
tasks = []

# Upload each chunk of the serialized content to the artifacts storage,
# using a disk temporary file.
Expand All @@ -126,38 +127,35 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str:
#
# Use `threading` over `asyncio` to ensure compatibility with Jupyter
# notebooks, where the event loop is already running.

for url in urls:
chunk_id = url["chunk_id"] or 1
task = pool.submit(
upload_chunk,
filepath=serializer.filepath,
client=standard_client,
url=url["upload_url"],
offset=((chunk_id - 1) * CHUNK_SIZE),
length=CHUNK_SIZE,
content_type=(
content_type if len(urls) == 1 else "application/octet-stream"
),
)

task_to_chunk_id[task] = chunk_id

try:
etags = dict(
sorted(
(
task_to_chunk_id[task],
task.result(),
)
for task in as_completed(task_to_chunk_id)
chunk_ids.append(chunk_id)
tasks.append(
delayed(upload_chunk)(
filepath=serializer.filepath,
client=standard_client,
url=url["upload_url"],
offset=((chunk_id - 1) * CHUNK_SIZE),
length=CHUNK_SIZE,
content_type=(
content_type
if len(urls) == 1
else "application/octet-stream"
),
)
)
except BaseException:
# Cancel all remaining tasks, especially on `KeyboardInterrupt`.
for task in task_to_chunk_id:
task.cancel()

raise
etags = dict(
sorted(
zip(
chunk_ids,
Parallel(backend="threading")(tasks),
strict=True,
)
)
)

# Acknowledge the upload, to let the hub/storage rebuild the whole.
hub_client.post(
Expand Down
5 changes: 2 additions & 3 deletions skore-hub-project/src/skore_hub_project/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(
if JUPYTERLITE and (transport is None):
from httpx import Request
from js import XMLHttpRequest
from pyodide.ffi import to_js
from pyodide.http.pyxhr import XHRResponse

class JupyterliteTransport(BaseTransport):
Expand All @@ -188,11 +189,9 @@ def handle_request(self, request: Request) -> Response:
req.withCredentials = True

for name, value in request.headers.items():
if name.lower() == "host":
continue
req.setRequestHeader(name, value)

req.send(request.content)
req.send(to_js(request.read()))

xhr = XHRResponse(req)

Expand Down
3 changes: 2 additions & 1 deletion skore-hub-project/src/skore_hub_project/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from sklearn.utils.validation import _check_pos_label_consistency

from skore_hub_project import console, switch_plt_backend
from skore_hub_project import THREADABLE, console, switch_plt_backend
from skore_hub_project.client.client import Client, HUBClient
from skore_hub_project.exception import ForbiddenException, NotFoundException
from skore_hub_project.json import dumps
Expand Down Expand Up @@ -294,6 +294,7 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None
TextColumn("{task.description}"),
TimeElapsedColumn(),
console=console,
auto_refresh=THREADABLE,
) as progress,
):
task = progress.add_task(
Expand Down
19 changes: 8 additions & 11 deletions skore-hub-project/src/skore_hub_project/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import cached_property, partial
from operator import methodcaller
from typing import ClassVar, Generic, TypeVar, cast

from joblib import Parallel, delayed
from pydantic import BaseModel, ConfigDict, Field, computed_field
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn

from skore_hub_project import console
from skore_hub_project import THREADABLE, console
from skore_hub_project.artifact.media.media import Media
from skore_hub_project.artifact.pickle import Pickle
from skore_hub_project.metric.metric import Metric
Expand All @@ -29,6 +30,7 @@
TimeElapsedColumn(),
console=console,
transient=True,
auto_refresh=THREADABLE,
)

Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport))
Expand Down Expand Up @@ -109,22 +111,17 @@ def metrics(self) -> list[Metric[Report]]:
self.report.cache_predictions()

metrics = [metric_cls(report=self.report) for metric_cls in self.METRICS]
tasks = list(map(delayed(methodcaller("compute")), metrics))

with SkinnedProgress() as progress, ThreadPoolExecutor() as pool:
tasks = [
pool.submit(lambda metric: metric.compute(), metric)
for metric in metrics
]

for task in progress.track(
as_completed(tasks),
with SkinnedProgress() as progress:
for _ in progress.track(
Parallel(backend="threading")(tasks),
description=(
f"Computing {self.report.__class__.__name__} "
f"#{self.report._hash} metrics"
),
total=len(tasks),
):
task.result()
progress.refresh()

return [metric for metric in metrics if metric.value is not None]
Expand Down
16 changes: 16 additions & 0 deletions skore/src/skore/_utils/_progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterable
from operator import length_hint
from threading import Thread
from typing import Any, TypeVar

from rich.progress import (
Expand All @@ -14,6 +15,20 @@
T = TypeVar("T")


def threadable() -> bool:
try:
thread = Thread()
thread.start()
thread.join()
except Exception:
return False
else:
return True


THREADABLE = threadable()


class ProgressBar:
"""Simplified progress bar based on ``rich.Progress``."""

Expand All @@ -29,6 +44,7 @@ def __init__(self, description: str, total: float | None):
TextColumn("[orange1]{task.percentage:>3.0f}%"),
expand=False,
transient=True,
auto_refresh=THREADABLE,
disable=(not configuration.show_progress),
)

Expand Down
Loading