diff --git a/src/runpod_flash/core/resources/constants.py b/src/runpod_flash/core/resources/constants.py index 725a6d73..4ba8a9bc 100644 --- a/src/runpod_flash/core/resources/constants.py +++ b/src/runpod_flash/core/resources/constants.py @@ -119,6 +119,20 @@ def get_image_name( f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{FLASH_IMAGE_TAG}", ) +# Base images for process injection (no flash-worker baked in) +FLASH_GPU_BASE_IMAGE = os.environ.get( + "FLASH_GPU_BASE_IMAGE", "pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime" +) +FLASH_CPU_BASE_IMAGE = os.environ.get("FLASH_CPU_BASE_IMAGE", "python:3.11-slim") + +# Worker tarball for process injection +FLASH_WORKER_VERSION = os.environ.get("FLASH_WORKER_VERSION", "1.1.1") +FLASH_WORKER_TARBALL_URL_TEMPLATE = os.environ.get( + "FLASH_WORKER_TARBALL_URL", + "https://github.com/runpod-workers/flash/releases/download/" + "v{version}/flash-worker-v{version}-py3.11-linux-x86_64.tar.gz", +) + # Worker configuration defaults DEFAULT_WORKERS_MIN = 0 DEFAULT_WORKERS_MAX = 1 diff --git a/src/runpod_flash/core/resources/injection.py b/src/runpod_flash/core/resources/injection.py new file mode 100644 index 00000000..9ce72ec2 --- /dev/null +++ b/src/runpod_flash/core/resources/injection.py @@ -0,0 +1,54 @@ +"""Process injection utilities for flash-worker tarball delivery.""" + +from .constants import FLASH_WORKER_TARBALL_URL_TEMPLATE, FLASH_WORKER_VERSION + + +def build_injection_cmd( + worker_version: str = FLASH_WORKER_VERSION, + tarball_url: str | None = None, +) -> str: + """Build the dockerArgs command that downloads, extracts, and runs flash-worker. + + Supports remote URLs (curl/wget) and local file paths (file://) for testing. + Includes version-based caching to skip re-extraction on warm workers. + Network volume caching stores extracted tarball at /runpod-volume/.flash-worker/v{version}. + """ + if tarball_url is None: + tarball_url = FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=worker_version) + + if tarball_url.startswith("file://"): + local_path = tarball_url[7:] + return ( + "bash -c '" + "set -e; FW_DIR=/opt/flash-worker; " + "mkdir -p $FW_DIR; " + f"tar xzf {local_path} -C $FW_DIR --strip-components=1; " + "exec $FW_DIR/bootstrap.sh'" + ) + + return ( + "bash -c '" + f"set -e; FW_DIR=/opt/flash-worker; FW_VER={worker_version}; " + # Network volume cache check + 'NV_CACHE="/runpod-volume/.flash-worker/v$FW_VER"; ' + 'if [ -d "$NV_CACHE" ] && [ -f "$NV_CACHE/.version" ]; then ' + 'cp -r "$NV_CACHE" "$FW_DIR"; ' + # Local cache check (container disk persistence between restarts) + 'elif [ -f "$FW_DIR/.version" ] && [ "$(cat $FW_DIR/.version)" = "$FW_VER" ]; then ' + "true; " + "else " + "mkdir -p $FW_DIR; " + f'DL_URL="{tarball_url}"; ' + "dl() { " + '(command -v curl >/dev/null 2>&1 && curl -sSL "$1" || ' + 'command -v wget >/dev/null 2>&1 && wget -qO- "$1" || ' + 'python3 -c "import urllib.request,sys;sys.stdout.buffer.write(urllib.request.urlopen(sys.argv[1]).read())" "$1"); ' + "}; " + 'dl "$DL_URL" ' + "| tar xz -C $FW_DIR --strip-components=1; " + # Cache to network volume if available + "if [ -d /runpod-volume ]; then " + 'mkdir -p "$NV_CACHE" && cp -r "$FW_DIR"/* "$NV_CACHE/" 2>/dev/null || true; fi; ' + "fi; " + "exec $FW_DIR/bootstrap.sh'" + ) diff --git a/src/runpod_flash/core/resources/live_serverless.py b/src/runpod_flash/core/resources/live_serverless.py index 41dbe008..c9eb3f47 100644 --- a/src/runpod_flash/core/resources/live_serverless.py +++ b/src/runpod_flash/core/resources/live_serverless.py @@ -1,18 +1,22 @@ # Ship serverless code as you write it. No builds, no deploys -- just run. from typing import Any, ClassVar +# Ship serverless code as you write it. No builds, no deploys — just run. from pydantic import model_validator from .constants import ( - DEFAULT_PYTHON_VERSION, + GPU_BASE_IMAGE_PYTHON_VERSION, get_image_name, + local_python_version, ) +from .injection import build_injection_cmd from .load_balancer_sls_resource import ( CpuLoadBalancerSlsResource, LoadBalancerSlsResource, ) from .serverless import ServerlessEndpoint from .serverless_cpu import CpuServerlessEndpoint +from .template import PodTemplate class LiveServerlessMixin: @@ -27,6 +31,12 @@ class LiveServerlessMixin: concrete subclass (see ``_apply_default_live_image``); reads and writes of ``imageName`` go through the normal Pydantic field machinery so model serialization, drift detection, and ``setattr`` all stay consistent. + """Configures process injection via dockerArgs for any base image. + + Sets a default base image (user can override via imageName) and generates + dockerArgs to download, extract, and run the flash-worker tarball at container + start time. QB vs LB mode is determined by FLASH_ENDPOINT_TYPE env var at + runtime, not by the Docker image. """ _image_type: ClassVar[str] = ( @@ -53,50 +63,74 @@ def _apply_default_live_image(data: Any, image_type: str): data["imageName"] = get_image_name(image_type, python_version) return data + def _create_new_template(self) -> PodTemplate: + """Create template with dockerArgs for process injection.""" + template = super()._create_new_template() # type: ignore[misc] + template.dockerArgs = build_injection_cmd() + return template + + def _configure_existing_template(self) -> None: + """Configure existing template, adding dockerArgs for injection if not user-set.""" + super()._configure_existing_template() # type: ignore[misc] + if self.template is not None and not self.template.dockerArgs: # type: ignore[attr-defined] + self.template.dockerArgs = build_injection_cmd() # type: ignore[attr-defined] + class LiveServerless(LiveServerlessMixin, ServerlessEndpoint): """GPU-only live serverless endpoint.""" - _image_type: ClassVar[str] = "gpu" - @model_validator(mode="before") @classmethod def set_live_serverless_template(cls, data: dict): """Default to the GPU Flash runtime image when none is supplied.""" return _apply_default_live_image(data, "gpu") + """Set default GPU image for Live Serverless.""" + if "imageName" not in data: + python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION + data["imageName"] = get_image_name("gpu", python_version) + return data class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint): """CPU-only live serverless endpoint with automatic disk sizing.""" - _image_type: ClassVar[str] = "cpu" - @model_validator(mode="before") @classmethod def set_live_serverless_template(cls, data: dict): """Default to the CPU Flash runtime image when none is supplied.""" return _apply_default_live_image(data, "cpu") + """Set default CPU image for Live Serverless.""" + if "imageName" not in data: + python_version = data.get("python_version") or local_python_version() + data["imageName"] = get_image_name("cpu", python_version) + return data class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource): """Live load-balanced endpoint.""" - _image_type: ClassVar[str] = "lb" - @model_validator(mode="before") @classmethod def set_live_lb_template(cls, data: dict): """Default to the LB Flash runtime image when none is supplied.""" return _apply_default_live_image(data, "lb") + """Set default image for Live Load-Balanced endpoint.""" + if "imageName" not in data: + python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION + data["imageName"] = get_image_name("lb", python_version) + return data class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource): """CPU-only live load-balanced endpoint.""" - _image_type: ClassVar[str] = "lb-cpu" - @model_validator(mode="before") @classmethod def set_live_cpu_lb_template(cls, data: dict): """Default to the CPU LB Flash runtime image when none is supplied.""" return _apply_default_live_image(data, "lb-cpu") + """Set default CPU image for Live Load-Balanced endpoint.""" + if "imageName" not in data: + python_version = data.get("python_version") or local_python_version() + data["imageName"] = get_image_name("lb-cpu", python_version) + return data diff --git a/tests/integration/test_cpu_disk_sizing.py b/tests/integration/test_cpu_disk_sizing.py index d7070668..41eb31af 100644 --- a/tests/integration/test_cpu_disk_sizing.py +++ b/tests/integration/test_cpu_disk_sizing.py @@ -125,11 +125,11 @@ def test_live_serverless_cpu_integration(self): ) # Verify integration: - # 1. Uses CPU image (locked) + # 1. Uses CPU base image (default) # 2. CPU utilities calculate minimum disk size # 3. Template creation with auto-sizing # 4. Validation passes - assert "flash-cpu:" in live_serverless.imageName + assert "runpod/flash-cpu:" in live_serverless.imageName assert live_serverless.instanceIds == [ CpuInstanceType.CPU5C_1_2, CpuInstanceType.CPU5C_2_4, @@ -253,9 +253,18 @@ def test_live_serverless_image_consistency(self): cpu_live = CpuLiveServerless(name="cpu-live") # Verify different default images are used per resource type. +class TestLiveServerlessImageDefaultsIntegration: + """Test image defaults in live serverless variants.""" + + def test_live_serverless_image_defaults(self): + """Test that LiveServerless variants use correct base images.""" + gpu_live = LiveServerless(name="gpu-live") + cpu_live = CpuLiveServerless(name="cpu-live") + + # Verify different base images are used assert gpu_live.imageName != cpu_live.imageName - assert "flash:" in gpu_live.imageName - assert "flash-cpu:" in cpu_live.imageName + assert "runpod/flash:" in gpu_live.imageName + assert "runpod/flash-cpu:" in cpu_live.imageName def test_live_serverless_image_override_via_constructor(self): """Caller-supplied imageName overrides the Flash runtime default (AE-3153).""" @@ -266,6 +275,11 @@ def test_live_serverless_image_override_via_constructor(self): assert gpu_live.imageName == "custom/image:latest" assert cpu_live.imageName == "custom/cpu-image:latest" + # Verify images can be overridden (BYOI) + custom_gpu = LiveServerless( + name="custom-gpu", imageName="nvidia/cuda:12.8.0-runtime" + ) + assert custom_gpu.imageName == "nvidia/cuda:12.8.0-runtime" def test_live_serverless_template_integration(self): """Test live serverless template integration with disk sizing.""" diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 943c7830..1b8f836b 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -112,8 +112,9 @@ async def echo(message: str): return {"echo": message} # Verify resource is correctly configured - assert lb.name == "test-live-api" - assert "flash-lb" in lb.imageName + # Note: name may have "-fb" appended by flash boot validator + assert "test-live-api" in lb.name + assert "runpod/flash-lb:" in lb.imageName # GPU LB base image assert echo.__remote_config__["method"] == "POST" def test_live_load_balancer_image_default_and_override(self): @@ -131,6 +132,15 @@ def test_live_load_balancer_image_default_and_override(self): # Guard against a future regression where both paths collapse to the # same default (e.g. the override branch reverting to a no-op). assert default_lb.imageName != custom_lb.imageName + def test_live_load_balancer_default_image(self): + """Test that LiveLoadBalancer uses GPU LB base image by default.""" + lb = LiveLoadBalancer(name="test-api") + assert "runpod/flash-lb:" in lb.imageName + + def test_live_load_balancer_allows_custom_image(self): + """Test that LiveLoadBalancer allows user to set custom image (BYOI).""" + lb = LiveLoadBalancer(name="test-api", imageName="custom-image:latest") + assert lb.imageName == "custom-image:latest" def test_load_balancer_vs_queue_based_endpoints(self): """Test that LB and QB endpoints have different characteristics.""" @@ -186,7 +196,7 @@ def get_status(): with tempfile.TemporaryDirectory() as tmpdir: project_dir = Path(tmpdir) - py_file = project_dir / "api_worker.py" + py_file = project_dir / "test_api.py" py_file.write_text(code) scanner = RuntimeScanner(project_dir) @@ -201,7 +211,9 @@ def get_status(): assert "LoadBalancerSlsResource" in resource_types # Verify resource configs were extracted - assert "test-api" in scanner.resource_types - assert scanner.resource_types["test-api"] == "LiveLoadBalancer" - assert "deployed-api" in scanner.resource_types - assert scanner.resource_types["deployed-api"] == "LoadBalancerSlsResource" + assert "test-api-fb" in scanner.resource_types + assert scanner.resource_types["test-api-fb"] == "LiveLoadBalancer" + assert "deployed-api-fb" in scanner.resource_types + assert ( + scanner.resource_types["deployed-api-fb"] == "LoadBalancerSlsResource" + ) diff --git a/tests/unit/resources/test_injection.py b/tests/unit/resources/test_injection.py new file mode 100644 index 00000000..d8cf05de --- /dev/null +++ b/tests/unit/resources/test_injection.py @@ -0,0 +1,83 @@ +"""Unit tests for process injection utilities.""" + +from runpod_flash.core.resources.injection import build_injection_cmd + + +class TestBuildInjectionCmd: + """Test build_injection_cmd() output format.""" + + def test_default_remote_url(self): + """Test default remote URL generation.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + assert cmd.startswith("bash -c '") + assert "FW_VER=1.1.1" in cmd + assert "runpod-workers/flash/releases/download/v1.1.1/" in cmd + assert "bootstrap.sh'" in cmd + + def test_custom_tarball_url(self): + """Test custom tarball URL.""" + url = "https://example.com/worker.tar.gz" + cmd = build_injection_cmd(worker_version="2.0.0", tarball_url=url) + + assert "FW_VER=2.0.0" in cmd + assert url in cmd + + def test_file_url_for_local_testing(self): + """Test file:// URL generates local extraction command.""" + cmd = build_injection_cmd( + worker_version="1.0.0", + tarball_url="file:///tmp/flash-worker.tar.gz", + ) + + assert "tar xzf /tmp/flash-worker.tar.gz" in cmd + assert "curl" not in cmd + assert "wget" not in cmd + assert "bootstrap.sh'" in cmd + + def test_version_caching_logic(self): + """Test that version-based cache check is included.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + # Should check .version file + assert ".version" in cmd + assert "FW_VER" in cmd + + def test_network_volume_caching(self): + """Test network volume cache path is included.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + assert "/runpod-volume/.flash-worker/" in cmd + assert "NV_CACHE" in cmd + + def test_curl_wget_python_fallback(self): + """Test curl/wget/python3 fallback chain.""" + cmd = build_injection_cmd(worker_version="1.0.0") + + assert "curl -sSL" in cmd + assert "wget -qO-" in cmd + assert "urllib.request" in cmd + + def test_default_uses_constants(self): + """Test that calling with no args uses module-level constants.""" + from runpod_flash.core.resources.constants import FLASH_WORKER_VERSION + + cmd = build_injection_cmd() + + assert f"FW_VER={FLASH_WORKER_VERSION}" in cmd + assert f"v{FLASH_WORKER_VERSION}" in cmd + + def test_strip_components_in_remote_extraction(self): + """Test tar uses --strip-components=1 for remote downloads.""" + cmd = build_injection_cmd(worker_version="1.0.0") + + assert "--strip-components=1" in cmd + + def test_strip_components_in_local_extraction(self): + """Test tar uses --strip-components=1 for local file extraction.""" + cmd = build_injection_cmd( + worker_version="1.0.0", + tarball_url="file:///tmp/fw.tar.gz", + ) + + assert "--strip-components=1" in cmd diff --git a/tests/unit/resources/test_live_load_balancer.py b/tests/unit/resources/test_live_load_balancer.py index e459758a..d706cd94 100644 --- a/tests/unit/resources/test_live_load_balancer.py +++ b/tests/unit/resources/test_live_load_balancer.py @@ -1,11 +1,13 @@ -""" -Unit tests for LiveLoadBalancer class and template serialization. -""" +"""Unit tests for LiveLoadBalancer class and template serialization.""" +import importlib import os import pytest - +from runpod_flash.core.resources.constants import ( + GPU_BASE_IMAGE_PYTHON_VERSION, + local_python_version, +) from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.live_serverless import ( CpuLiveLoadBalancer, @@ -23,7 +25,6 @@ def test_live_load_balancer_creation_with_local_tag(self, monkeypatch): """Test LiveLoadBalancer creates with local image tag.""" monkeypatch.setenv("FLASH_IMAGE_TAG", "local") # Need to reload modules to pick up new env var - import importlib import runpod_flash.core.resources.constants as const_module import runpod_flash.core.resources.live_serverless as ls_module @@ -42,22 +43,30 @@ def test_live_load_balancer_default_image_tag(self): os.environ.pop("FLASH_IMAGE_TAG", None) lb = LiveLoadBalancer(name="test-lb") - - assert "runpod/flash-lb:" in lb.imageName + assert f"py{GPU_BASE_IMAGE_PYTHON_VERSION}" in lb.imageName assert lb.template is not None assert lb.template.imageName == lb.imageName + def test_live_load_balancer_user_can_override_image(self): + """Test user can set custom imageName (BYOI).""" + lb = LiveLoadBalancer(name="test-lb", imageName="custom/image:v1") + assert lb.imageName == "custom/image:v1" + def test_live_load_balancer_template_creation(self): """Test LiveLoadBalancer creates proper template from imageName.""" lb = LiveLoadBalancer(name="cpu_processor") - # Should have a template created from imageName assert lb.template is not None assert lb.template.imageName == lb.imageName - # Template name uses resource IDs, not the original name assert "LiveLoadBalancer" in lb.template.name assert "PodTemplate" in lb.template.name + def test_live_load_balancer_template_has_docker_args(self): + """Test LiveLoadBalancer template has process injection dockerArgs.""" + lb = LiveLoadBalancer(name="test-lb") + assert lb.template.dockerArgs + assert "bootstrap.sh" in lb.template.dockerArgs + def test_live_load_balancer_template_env_variables(self): """Test LiveLoadBalancer template includes environment variables.""" lb = LiveLoadBalancer( @@ -69,7 +78,6 @@ def test_live_load_balancer_template_env_variables(self): assert lb.template.env is not None assert len(lb.template.env) > 0 - # Check for custom env var custom_vars = [kv for kv in lb.template.env if kv.key == "CUSTOM_VAR"] assert len(custom_vars) == 1 assert custom_vars[0].value == "custom_value" @@ -78,14 +86,11 @@ def test_live_load_balancer_payload_serialization(self): """Test LiveLoadBalancer serializes correctly for GraphQL deployment.""" lb = LiveLoadBalancer(name="data_processor") - # Generate payload as would be sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Template must be in payload (not imageName since that's in _input_only) assert "template" in payload assert "imageName" not in payload - # Template must have all required fields template = payload["template"] assert "imageName" in template assert "name" in template @@ -94,14 +99,11 @@ def test_live_load_balancer_payload_serialization(self): def test_live_load_balancer_type_is_lb(self): """Test LiveLoadBalancer has type=LB.""" lb = LiveLoadBalancer(name="test-lb") - assert lb.type.value == "LB" - assert str(lb.type) == "ServerlessType.LB" def test_live_load_balancer_scaler_is_request_count(self): """Test LiveLoadBalancer uses REQUEST_COUNT scaler.""" lb = LiveLoadBalancer(name="test-lb") - assert lb.scalerType.value == "REQUEST_COUNT" @@ -147,21 +149,15 @@ def test_live_load_balancer_serialization_roundtrip(self): env={"API_KEY": "secret123"}, ) - # Simulate what gets sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Verify GraphQL payload has template assert "template" in payload, "Template must be in GraphQL payload" assert payload["template"]["imageName"] is not None assert payload["template"]["name"] is not None - - # Verify imageName is NOT in payload (it's in _input_only) assert "imageName" not in payload - # Verify the template has the correct image - assert "flash-lb:" in payload["template"]["imageName"], ( - "Must have load-balancer image" - ) + # dockerArgs must contain injection command + assert "bootstrap.sh" in payload["template"]["dockerArgs"] def test_template_env_serialization(self): """Test template environment variables serialize correctly.""" @@ -176,7 +172,6 @@ def test_template_env_serialization(self): assert isinstance(template_env, list) assert len(template_env) >= 2 - # Check env vars are serialized as {key, value} objects var_keys = {kv["key"] for kv in template_env} assert "VAR1" in var_keys assert "VAR2" in var_keys @@ -189,7 +184,6 @@ def test_cpu_live_load_balancer_creation_with_local_tag(self, monkeypatch): """Test CpuLiveLoadBalancer creates with local image tag.""" monkeypatch.setenv("FLASH_IMAGE_TAG", "local") # Need to reload modules to pick up new env var - import importlib import runpod_flash.core.resources.constants as const_module import runpod_flash.core.resources.live_serverless as ls_module @@ -208,16 +202,18 @@ def test_cpu_live_load_balancer_default_image_tag(self): os.environ.pop("FLASH_IMAGE_TAG", None) lb = CpuLiveLoadBalancer(name="test-lb") - - assert "runpod/flash-lb-cpu:" in lb.imageName + assert f"py{local_python_version()}" in lb.imageName assert lb.template is not None assert lb.template.imageName == lb.imageName + def test_cpu_live_load_balancer_user_can_override_image(self): + """Test CpuLiveLoadBalancer allows user image override.""" + lb = CpuLiveLoadBalancer(name="test-lb", imageName="python:3.11-slim") + assert lb.imageName == "python:3.11-slim" + def test_cpu_live_load_balancer_defaults(self): """Test CpuLiveLoadBalancer defaults to CPU3G_2_8.""" lb = CpuLiveLoadBalancer(name="test-lb") - - # Should default to CPU3G_2_8 assert lb.instanceIds == [CpuInstanceType.CPU3G_2_8] def test_cpu_live_load_balancer_with_specific_cpu_instances(self): @@ -226,34 +222,27 @@ def test_cpu_live_load_balancer_with_specific_cpu_instances(self): name="test-lb", instanceIds=[CpuInstanceType.CPU3G_1_4], ) - assert lb.instanceIds == [CpuInstanceType.CPU3G_1_4] def test_cpu_live_load_balancer_type_is_lb(self): """Test CpuLiveLoadBalancer has type=LB.""" lb = CpuLiveLoadBalancer(name="test-lb") - assert lb.type.value == "LB" - assert str(lb.type) == "ServerlessType.LB" def test_cpu_live_load_balancer_scaler_is_request_count(self): """Test CpuLiveLoadBalancer uses REQUEST_COUNT scaler.""" lb = CpuLiveLoadBalancer(name="test-lb") - assert lb.scalerType.value == "REQUEST_COUNT" def test_cpu_live_load_balancer_payload_serialization(self): """Test CpuLiveLoadBalancer serializes correctly for GraphQL deployment.""" lb = CpuLiveLoadBalancer(name="data_processor") - # Generate payload as would be sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Template must be in payload (not imageName since that's in _input_only) assert "template" in payload assert "imageName" not in payload - # Template must have all required fields template = payload["template"] assert "imageName" in template assert "name" in template @@ -265,7 +254,12 @@ def test_cpu_live_load_balancer_excludes_gpu_fields(self): payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # GPU-specific fields should not be in payload assert "gpus" not in payload assert "gpuIds" not in payload assert "cudaVersions" not in payload + + def test_cpu_live_load_balancer_template_has_docker_args(self): + """Test CpuLiveLoadBalancer template has process injection dockerArgs.""" + lb = CpuLiveLoadBalancer(name="test-lb") + assert lb.template.dockerArgs + assert "bootstrap.sh" in lb.template.dockerArgs diff --git a/tests/unit/resources/test_live_serverless.py b/tests/unit/resources/test_live_serverless.py index d1c4de0a..8bfcc036 100644 --- a/tests/unit/resources/test_live_serverless.py +++ b/tests/unit/resources/test_live_serverless.py @@ -1,6 +1,4 @@ -""" -Unit tests for LiveServerless and CpuLiveServerless classes. -""" +"""Unit tests for LiveServerless, CpuLiveServerless, and LiveServerlessMixin.""" import pytest from runpod_flash.core.resources.constants import ( @@ -8,10 +6,10 @@ ) from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.live_serverless import ( - LiveServerless, + CpuLiveLoadBalancer, CpuLiveServerless, LiveLoadBalancer, - CpuLiveLoadBalancer, + LiveServerless, ) from runpod_flash.core.resources.template import PodTemplate @@ -34,16 +32,12 @@ def test_live_serverless_idle_timeout_rejects_zero(self): LiveServerless(name="broken", idleTimeout=0) def test_live_serverless_gpu_defaults(self): - """Test LiveServerless uses GPU image and defaults.""" - live_serverless = LiveServerless( - name="example_gpu_live_serverless", - ) + """Test LiveServerless uses GPU base image and defaults.""" + live_serverless = LiveServerless(name="example_gpu_live_serverless") - # Should not have CPU instances, uses default 64GB assert live_serverless.instanceIds is None assert live_serverless.template is not None assert live_serverless.template.containerDiskInGb == 64 - assert "flash:" in live_serverless.imageName # GPU image def test_live_serverless_image_override_via_constructor(self): """LiveServerless accepts a caller-supplied imageName (AE-3153).""" @@ -59,6 +53,12 @@ def test_live_serverless_image_default_unchanged(self): """LiveServerless still defaults to the Flash GPU runtime image.""" live_serverless = LiveServerless(name="example_gpu_live_serverless") assert "flash:" in live_serverless.imageName + def test_live_serverless_user_can_override_image(self): + """Test user can set custom imageName (BYOI).""" + live_serverless = LiveServerless( + name="test", imageName="nvidia/cuda:12.8.0-runtime-ubuntu22.04" + ) + assert live_serverless.imageName == "nvidia/cuda:12.8.0-runtime-ubuntu22.04" def test_live_serverless_with_custom_template(self): """Test LiveServerless with custom template.""" @@ -67,31 +67,30 @@ def test_live_serverless_with_custom_template(self): imageName="test/image:v1", containerDiskInGb=100, ) - live_serverless = LiveServerless( name="example_gpu_live_serverless", template=template, ) - - # Should preserve custom template settings assert live_serverless.template.containerDiskInGb == 100 + def test_live_serverless_template_has_docker_args(self): + """Test that the template includes dockerArgs for process injection.""" + live_serverless = LiveServerless(name="test") + assert live_serverless.template is not None + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs + class TestCpuLiveServerless: """Test CpuLiveServerless class behavior.""" def test_cpu_live_serverless_defaults(self): """Test CpuLiveServerless uses CPU image and auto-sizing.""" - live_serverless = CpuLiveServerless( - name="example_cpu_live_serverless", - ) + live_serverless = CpuLiveServerless(name="example_cpu_live_serverless") - # Should default to CPU3G_2_8 assert live_serverless.instanceIds == [CpuInstanceType.CPU3G_2_8] assert live_serverless.template is not None - # Default disk size should be 20GB for CPU3G_2_8 assert live_serverless.template.containerDiskInGb == 20 - assert "flash-cpu:" in live_serverless.imageName # CPU image def test_cpu_live_serverless_custom_instances(self): """Test CpuLiveServerless with custom CPU instances.""" @@ -99,7 +98,6 @@ def test_cpu_live_serverless_custom_instances(self): name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], ) - assert live_serverless.instanceIds == [CpuInstanceType.CPU3G_1_4] assert live_serverless.template is not None assert live_serverless.template.containerDiskInGb == 10 @@ -110,9 +108,8 @@ def test_cpu_live_serverless_multiple_instances(self): name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4, CpuInstanceType.CPU5C_2_4], ) - assert live_serverless.template is not None - assert live_serverless.template.containerDiskInGb == 10 # Min of 10 and 30 + assert live_serverless.template.containerDiskInGb == 10 def test_cpu_live_serverless_image_override_via_constructor(self): """CpuLiveServerless accepts a caller-supplied imageName (AE-3153).""" @@ -131,15 +128,18 @@ def test_cpu_live_serverless_image_default_unchanged(self): instanceIds=[CpuInstanceType.CPU3G_1_4], ) assert "flash-cpu:" in live_serverless.imageName + def test_cpu_live_serverless_user_can_override_image(self): + """Test CpuLiveServerless allows user to set custom image.""" + live_serverless = CpuLiveServerless(name="test", imageName="python:3.11-slim") + assert live_serverless.imageName == "python:3.11-slim" def test_cpu_live_serverless_validation_failure(self): """Test CpuLiveServerless validation fails with excessive disk size.""" template = PodTemplate( name="custom", imageName="test/image:v1", - containerDiskInGb=50, # Exceeds 10GB limit + containerDiskInGb=50, ) - with pytest.raises(ValueError, match="Container disk size 50GB exceeds"): CpuLiveServerless( name="example_cpu_live_serverless", @@ -150,48 +150,53 @@ def test_cpu_live_serverless_validation_failure(self): def test_cpu_live_serverless_with_existing_template_default_size(self): """Test CpuLiveServerless auto-sizes existing template with default disk size.""" template = PodTemplate(name="existing", imageName="test/image:v1") - # Template uses default size - live_serverless = CpuLiveServerless( name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], template=template, ) - - assert live_serverless.template.containerDiskInGb == 10 # Should be auto-sized + assert live_serverless.template.containerDiskInGb == 10 def test_cpu_live_serverless_preserves_custom_disk_size(self): """Test CpuLiveServerless preserves custom disk size in template.""" template = PodTemplate( name="existing", imageName="test/image:v1", - containerDiskInGb=5, # Custom size within limits + containerDiskInGb=5, ) - live_serverless = CpuLiveServerless( name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], template=template, ) + assert live_serverless.template.containerDiskInGb == 5 - assert ( - live_serverless.template.containerDiskInGb == 5 - ) # Should preserve custom size + def test_cpu_live_serverless_template_has_docker_args(self): + """Test CpuLiveServerless template includes dockerArgs.""" + live_serverless = CpuLiveServerless(name="test") + assert live_serverless.template is not None + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs class TestLiveServerlessMixin: """Test LiveServerlessMixin functionality.""" - def test_live_image_property_gpu(self): - """Test LiveServerless _live_image property.""" + def test_docker_args_set_on_new_template(self): + """Test dockerArgs is set when creating a new template.""" live_serverless = LiveServerless(name="test") - assert "flash:" in live_serverless._live_image - assert "cpu" not in live_serverless._live_image + assert live_serverless.template.dockerArgs + assert "bash -c" in live_serverless.template.dockerArgs - def test_live_image_property_cpu(self): - """Test CpuLiveServerless _live_image property.""" - live_serverless = CpuLiveServerless(name="test") - assert "flash-cpu:" in live_serverless._live_image + def test_docker_args_set_on_existing_template(self): + """Test dockerArgs is set when configuring an existing template.""" + template = PodTemplate( + name="existing", + imageName="test/image:v1", + ) + live_serverless = LiveServerless(name="test", template=template) + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs def test_image_name_property_gpu(self): """LiveServerless defaults imageName to the Flash runtime image when none supplied.""" @@ -202,6 +207,35 @@ def test_image_name_property_cpu(self): """CpuLiveServerless defaults imageName to the Flash runtime image when none supplied.""" live_serverless = CpuLiveServerless(name="test") assert live_serverless.imageName == live_serverless._live_image + def test_all_live_classes_have_docker_args(self): + """Test all Live* classes set dockerArgs on their templates.""" + classes_and_kwargs = [ + (LiveServerless, {}), + (CpuLiveServerless, {}), + (LiveLoadBalancer, {}), + (CpuLiveLoadBalancer, {}), + ] + for cls, extra_kwargs in classes_and_kwargs: + resource = cls(name=f"test-{cls.__name__}", **extra_kwargs) + assert resource.template is not None, f"{cls.__name__} has no template" + assert resource.template.dockerArgs, f"{cls.__name__} has no dockerArgs" + assert "bootstrap.sh" in resource.template.dockerArgs, ( + f"{cls.__name__} missing bootstrap.sh in dockerArgs" + ) + + def test_live_load_balancer_defaults(self): + """Test LiveLoadBalancer uses GPU image.""" + lb = LiveLoadBalancer(name="test-lb") + assert lb.imageName is not None + assert lb.template is not None + assert lb.template.dockerArgs + + def test_cpu_live_load_balancer_defaults(self): + """Test CpuLiveLoadBalancer uses CPU image.""" + lb = CpuLiveLoadBalancer(name="test-lb-cpu") + assert lb.imageName is not None + assert lb.template is not None + assert lb.template.dockerArgs def test_image_name_override_gpu(self): """LiveServerless honors caller-supplied imageName (AE-3153).""" @@ -228,6 +262,15 @@ def test_default_image_validator_passes_through_non_dict(self): original = LiveServerless(name="test", imageName="byo/image:v1") revalidated = LiveServerless.model_validate(original) assert revalidated.imageName == "byo/image:v1" + def test_live_serverless_byoi_gpu(self): + """Test LiveServerless respects user-provided imageName.""" + live_serverless = LiveServerless(name="test", imageName="custom/gpu:v1") + assert live_serverless.imageName == "custom/gpu:v1" + + def test_live_serverless_byoi_cpu(self): + """Test CpuLiveServerless respects user-provided imageName.""" + live_serverless = CpuLiveServerless(name="test", imageName="custom/cpu:v1") + assert live_serverless.imageName == "custom/cpu:v1" class TestLiveServerlessPythonVersion: