diff --git a/amlb/benchmark.py b/amlb/benchmark.py index 6dc326442..60009279f 100644 --- a/amlb/benchmark.py +++ b/amlb/benchmark.py @@ -65,6 +65,7 @@ class SetupMode(Enum): force = 2 only = 3 script = 4 + clean = 5 class Benchmark: diff --git a/amlb/resources.py b/amlb/resources.py index 8a37977c4..b2a93e35a 100644 --- a/amlb/resources.py +++ b/amlb/resources.py @@ -12,7 +12,7 @@ import random import re import sys -from functools import cache, cached_property +from functools import lru_cache, cached_property from amlb.benchmarks.parser import benchmark_load from amlb.frameworks import default_tag, load_framework_definitions @@ -171,7 +171,7 @@ def _frameworks(self): frameworks_file = self.config.frameworks.definition_file return load_framework_definitions(frameworks_file, self.config) - @cache + @lru_cache(maxsize=None) def constraint_definition(self, name: str) -> TaskConstraint: """ :param name: name of the benchmark constraint definition as defined in the constraints file diff --git a/amlb/runners/container.py b/amlb/runners/container.py index 94d283b9f..2793c015e 100644 --- a/amlb/runners/container.py +++ b/amlb/runners/container.py @@ -9,6 +9,7 @@ from abc import abstractmethod import logging +import os import re from typing import cast @@ -17,6 +18,7 @@ from ..frameworks.definitions import Framework from ..job import Job from ..resources import config as rconfig, get as rget +from ..utils import run_cmd from ..__version__ import __version__, _dev_version as dev @@ -78,6 +80,7 @@ def _validate(self): self.parallel_jobs = max_parallel_jobs def setup(self, mode, upload=False): + self.setup_mode = mode if mode == SetupMode.skip: return @@ -92,12 +95,31 @@ def setup(self, mode, upload=False): self._upload_image(self.image) def cleanup(self): - pass + if hasattr(self, "setup_mode") and self.setup_mode == SetupMode.clean: + if self.image: + log.info(f"Cleaning up docker image {self.image}.") + run_cmd(f"docker rmi -f {self.image}") + if hasattr(self, "_script") and os.path.exists(self._script): + log.info("Cleaning up generated script") + os.remove(self._script) + if hasattr(self, "task_defs"): + import openml + + for task_def in self.task_defs: + try: + if hasattr(task_def, "openml_task_id"): + openml.tasks.delete_task_cache(task_def.openml_task_id) + elif hasattr(task_def, "openml_dataset_id"): + openml.datasets.delete_dataset_cache( + task_def.openml_dataset_id + ) + except Exception as e: + log.warning(f"Failed to clean up OpenML cache: {e}") def run( self, tasks: str | list[str] | None = None, folds: int | list[int] | None = None ): - self._get_task_defs(tasks) # validates tasks + self.task_defs = self._get_task_defs(tasks) # validates tasks if self.parallel_jobs > 1 or not self.minimize_instances: return super().run(tasks, folds) else: diff --git a/amlb/utils/modules.py b/amlb/utils/modules.py index 83b018396..c2822ae6f 100644 --- a/amlb/utils/modules.py +++ b/amlb/utils/modules.py @@ -1,42 +1,26 @@ import logging import sys -import types - -try: - from pip._internal import main as pip_main -except ImportError: - from pip import main as pip_main +import subprocess log = logging.getLogger(__name__) __no_export = set(dir()) # all variables defined above this are not exported -def register_module(module_name): - if module_name not in sys.modules: - mod = types.ModuleType(module_name) - sys.modules[module_name] = mod - return sys.modules[module_name] - - -def register_submodule(mod, name): - fullname = ".".join([mod.__name__, name]) - module = register_module(fullname) - setattr(mod, name, module) - - def pip_install(module_or_requirements, is_requirements=False): try: + cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir"] if is_requirements: - pip_main(["install", "--no-cache-dir", "-r", module_or_requirements]) + cmd.extend(["-r", module_or_requirements]) else: - pip_main(["install", "--no-cache-dir", module_or_requirements]) - except SystemExit as se: + cmd.append(module_or_requirements) + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: log.error( "Error when trying to install python modules %s.", module_or_requirements ) - log.exception(se) + log.exception(e) __all__ = [s for s in dir() if not s.startswith("_") and s not in __no_export] diff --git a/runbenchmark.py b/runbenchmark.py index 431e865ef..15151d10f 100644 --- a/runbenchmark.py +++ b/runbenchmark.py @@ -127,7 +127,7 @@ parser.add_argument( "-s", "--setup", - choices=["auto", "skip", "force", "only"], + choices=["auto", "skip", "force", "only", "clean"], default="auto", help="Framework/platform setup mode. Available values are:" "\n• auto: setup is executed only if strictly necessary."