Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,8 @@ def run(
common parameters shown below will be added and can be passed through the `override` parameter of this method.

- ``"output_dir"``: the path to save mlflow tracking outputs locally, default to "<bundle root>/eval".
- ``"tracking_uri"``: uri to save mlflow tracking outputs, default to "/output_dir/mlruns".
- ``"tracking_uri"``: uri to save mlflow tracking outputs, default to a local SQLite database
at "<output_dir>/mlruns.db" with run artifacts kept under "<output_dir>/mlruns".
- ``"experiment_name"``: experiment name for this run, default to "monai_experiment".
- ``"run_name"``: the name of current run.
- ``"save_execute_config"``: whether to save the executed config files. It can be `False`, `/path/to/artifacts`
Expand Down
6 changes: 4 additions & 2 deletions monai/bundle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@
"configs": {
# if no "output_dir" in the bundle config, default to "<bundle root>/eval"
"output_dir": "$@bundle_root + '/eval'",
# use URI to support linux, mac and windows os
"tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlruns'",
# MLflow 3.13+ rejects the filesystem (file store) tracking backend, so default tracking
# to a local SQLite database. The handler keeps run artifacts under "<output_dir>/mlruns"
# (next to the db). A URI is used so the path is valid on linux, mac and windows os.
"tracking_uri": "$monai.utils.path_to_sqlite_uri(@output_dir + '/mlruns.db')",
"experiment_name": "monai_experiment",
"run_name": None,
# may fill it at runtime
Expand Down
87 changes: 80 additions & 7 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@
from torch.utils.data import Dataset

from monai.apps.utils import get_logger
from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import
from monai.utils import (
CommonKeys,
IgniteInfo,
ensure_tuple,
flatten_dict,
min_version,
optional_import,
path_to_sqlite_uri,
path_to_uri,
)

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
Expand Down Expand Up @@ -68,7 +77,14 @@ class MLFlowHandler:
tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment
variable to have MLflow find a URI from there. in both cases, the URI can either be
an HTTP/HTTPS URI for a remote server, a database connection string, or a local path
to log data to a directory. The URI defaults to path `mlruns`.
to log data to a directory. When no ``tracking_uri`` is provided and the
``MLFLOW_TRACKING_URI`` environment variable is unset, the handler now
defaults to a local SQLite database backend at ``sqlite:///<cwd>/mlruns.db`` with
artifacts stored under ``<cwd>/mlruns``. The default was changed from the filesystem
(file store) backend because MLflow 3.13+ raises an exception for the file store unless
``MLFLOW_ALLOW_FILE_STORE=true`` is set; SQLite is the backend MLflow recommends and it
does not raise. Any explicitly provided ``tracking_uri`` (including a local file path or
``file://`` URI) is passed through unchanged.
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.
iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.
``iteration_log`` can be also a function and it will be interpreted as an event filter
Expand Down Expand Up @@ -113,6 +129,10 @@ class MLFlowHandler:
optimizer_param_names: parameter names in the optimizer that need to be recorded during running the
workflow, default to `'lr'`.
close_on_complete: whether to close the mlflow run in `complete` phase in workflow, default to False.
artifact_location: the location to store run artifacts in, passed to MLflow when the experiment is
created. When ``None`` and a local SQLite ``tracking_uri`` is used, it defaults to an
``mlruns`` directory next to the database file; for other backends ``None`` lets MLflow
decide based on the ``tracking_uri``. Has no effect if the experiment already exists.

For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html.

Expand Down Expand Up @@ -141,6 +161,7 @@ def __init__(
artifacts: str | Sequence[Path] | None = None,
optimizer_param_names: str | Sequence[str] = "lr",
close_on_complete: bool = False,
artifact_location: str | None = None,
) -> None:
self.iteration_log = iteration_log
self.epoch_log = epoch_log
Expand All @@ -156,6 +177,24 @@ def __init__(
self.experiment_param = experiment_param
self.artifacts = ensure_tuple(artifacts)
self.optimizer_param_names = ensure_tuple(optimizer_param_names)
# When no tracking_uri is provided, default to a local SQLite backend instead of the
# filesystem (file store) backend. MLflow 3.13+ raises for the file store unless
# `MLFLOW_ALLOW_FILE_STORE=true` is set, while SQLite is the recommended backend and does
# not raise. Artifacts cannot live inside a database, so by default they are stored under
# the `./mlruns` directory (where the previous file store default kept them) via the
# experiment `artifact_location`. Any explicitly provided tracking_uri is left unchanged.
self.artifact_location = artifact_location
# Only fall back to the SQLite default when the caller gave no tracking_uri and the
# `MLFLOW_TRACKING_URI` environment variable is unset, so that env-var configuration keeps
# working. When it is set, `tracking_uri` stays None and MLflow resolves the env var.
if not tracking_uri and not os.environ.get("MLFLOW_TRACKING_URI"):
tracking_uri = path_to_sqlite_uri(os.path.join(os.getcwd(), "mlruns.db"))
# For a local SQLite backend, keep run artifacts in an `mlruns` directory next to the
# database file (mirroring the previous file-store layout) unless the caller set
# `artifact_location`. Other backends (e.g. a remote server) are left to MLflow to decide.
if self.artifact_location is None and tracking_uri and tracking_uri.startswith("sqlite:///"):
db_path = Path(tracking_uri[len("sqlite:///") :])
self.artifact_location = path_to_uri(db_path.parent / "mlruns")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None)
self.run_finish_status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
self.close_on_complete = close_on_complete
Expand Down Expand Up @@ -245,7 +284,12 @@ def _set_experiment(self):
try:
experiment = self.client.get_experiment_by_name(self.experiment_name)
if not experiment:
experiment_id = self.client.create_experiment(self.experiment_name)
# pass an explicit artifact_location (set for the default SQLite backend, or
# by the caller) so artifacts land in the intended directory; when it is
# None MLflow decides based on the tracking_uri.
experiment_id = self.client.create_experiment(
self.experiment_name, artifact_location=self.artifact_location
)
experiment = self.client.get_experiment(experiment_id)
break
except MlflowException as e:
Expand Down Expand Up @@ -336,14 +380,43 @@ def complete(self) -> None:
for artifact in artifact_list:
self.client.log_artifact(self.cur_run.info.run_id, artifact)

def _dispose_sqlite_store(self) -> None:
"""
Release MLflow's SQLAlchemy engine when a local SQLite tracking backend is used.

MLflow keeps the SQLite connection open for the lifetime of the client, which on
Windows prevents the database file from being deleted. MLflow exposes no public
client close/dispose API, so this reaches into its internals defensively to release
the engine. It is a no-op for non-SQLite backends.
"""
tracking_uri = getattr(self.client, "tracking_uri", "")
if not isinstance(tracking_uri, str) or not tracking_uri.startswith("sqlite:"):
return
store = getattr(getattr(self.client, "_tracking_client", None), "store", None)
if store is None:
return
dispose = getattr(store, "_dispose_engine", None)
if callable(dispose):
dispose()
else:
engine = getattr(store, "engine", None)
if engine is not None:
engine.dispose()
read_engine = getattr(store, "read_engine", None)
if read_engine is not None:
read_engine.dispose()

def close(self) -> None:
"""
Stop current running logger of MLFlow.
Stop current running logger of MLFlow and release local SQLite resources.

"""
if self.cur_run:
self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status)
self.cur_run = None
try:
if self.cur_run:
self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status)
self.cur_run = None
finally:
self._dispose_sqlite_store()

def epoch_completed(self, engine: Engine) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
is_sqrt,
issequenceiterable,
list_to_dict,
path_to_sqlite_uri,
path_to_uri,
pprint_edges,
progress_bar,
Expand Down
16 changes: 16 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"save_obj",
"label_union",
"path_to_uri",
"path_to_sqlite_uri",
"pprint_edges",
"check_key_duplicates",
"CheckKeyDuplicatesYamlLoader",
Expand Down Expand Up @@ -727,6 +728,21 @@ def path_to_uri(path: PathLike) -> str:
return Path(path).absolute().as_uri()


def path_to_sqlite_uri(path: PathLike) -> str:
"""
Convert a database file path to a SQLite connection URI, e.g. for use as an MLflow
``tracking_uri``. If not an absolute path, it is converted to an absolute path first.

A forward-slash (POSIX) path is used so the URI is valid on Windows as well as POSIX:
on Windows this yields ``sqlite:///C:/path/db.sqlite`` and on POSIX ``sqlite:////path/db.sqlite``.

Args:
path: input database file path, can be a string or `Path` object.

"""
return f"sqlite:///{Path(path).absolute().as_posix()}"


def pprint_edges(val: Any, n_lines: int = 20) -> str:
"""
Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines.
Expand Down
18 changes: 11 additions & 7 deletions tests/fl/monai_algo/test_fl_monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems
from monai.fl.utils.exchange_object import ExchangeObject
from monai.utils import path_to_uri
from monai.utils import path_to_sqlite_uri
from tests.test_utils import SkipIfNoModule

_root_dir = Path(__file__).resolve().parents[2]
Expand Down Expand Up @@ -79,7 +79,7 @@
"save_execute_config": f"{_data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
"tracking_uri": path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_override.db")),
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
},
Expand All @@ -103,7 +103,7 @@
workflow_type="train",
logging_file=_logging_file,
tracking="mlflow",
tracking_uri=path_to_uri(_data_dir) + "/mlflow_1",
tracking_uri=path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_1.db")),
experiment_name="monai_eval1",
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
Expand All @@ -119,7 +119,7 @@
],
"eval_kwargs": {
"tracking": "mlflow",
"tracking_uri": path_to_uri(_data_dir) + "/mlflow_2",
"tracking_uri": path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_2.db")),
"experiment_name": "monai_eval2",
},
"eval_workflow_name": "training",
Expand Down Expand Up @@ -202,8 +202,10 @@ def test_train(self, input_params):

# test experiment management
if "save_execute_config" in algo.train_workflow.parser:
self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
shutil.rmtree(f"{_data_dir}/mlflow_override")
self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override.db"))
os.remove(f"{_data_dir}/mlflow_override.db")
if os.path.isdir(f"{_data_dir}/mlruns"):
shutil.rmtree(f"{_data_dir}/mlruns")
self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
os.remove(f"{_data_dir}/config_executed.json")

Expand All @@ -227,7 +229,9 @@ def test_evaluate(self, input_params):
if "save_execute_config" in algo.eval_workflow.parser:
self.assertGreater(len(list(glob.glob(f"{_data_dir}/mlflow_*"))), 0)
for f in list(glob.glob(f"{_data_dir}/mlflow_*")):
shutil.rmtree(f)
shutil.rmtree(f) if os.path.isdir(f) else os.remove(f)
if os.path.isdir(f"{_data_dir}/mlruns"):
shutil.rmtree(f"{_data_dir}/mlruns")
self.assertGreater(len(list(glob.glob(f"{_data_dir}/eval/config_*"))), 0)
for f in list(glob.glob(f"{_data_dir}/eval/config_*")):
os.remove(f)
Expand Down
Loading
Loading