diff --git a/docs/source/reference/modules_inference_server.rst b/docs/source/reference/modules_inference_server.rst index 7f67eed7b08..60cbd0383da 100644 --- a/docs/source/reference/modules_inference_server.rst +++ b/docs/source/reference/modules_inference_server.rst @@ -17,6 +17,8 @@ Core API :template: rl_template_noinherit.rst InferenceServer + InferenceServerConfig + InferenceDeviceConfig ProcessInferenceServer InferenceClient InferenceTransport diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 774edd15c89..c7e2160b574 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -19,7 +19,9 @@ from torchrl.modules.inference_server import ( InferenceClient, + InferenceDeviceConfig, InferenceServer, + InferenceServerConfig, InferenceTransport, MPTransport, ProcessInferenceServer, @@ -257,6 +259,23 @@ def test_stats_accounting(self): assert stats["avg_batch_size"] > 0 assert stats["p95_forward_ms"] >= 0 + def test_structured_config(self): + transport = ThreadingTransport() + policy = _make_policy() + server_config = InferenceServerConfig(max_batch_size=2, timeout=0.001) + device_config = InferenceDeviceConfig(policy_device="cpu", output_device="cpu") + with InferenceServer( + policy, + transport, + server_config=server_config, + device_config=device_config, + ) as server: + client = transport.client() + result = client(TensorDict({"observation": torch.randn(4)})) + stats = server.stats() + assert result["action"].device.type == "cpu" + assert stats["requests"] == 1 + @pytest.mark.gpu @pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") def test_cuda_policy_cpu_output(self): @@ -1004,6 +1023,30 @@ def test_process_server_backend_smoke(self): collector.shutdown() assert total >= 20 + def test_device_config_and_server_config(self): + """Collector accepts structured device and server config objects.""" + collector = AsyncBatchedCollector( + create_env_fn=[_counting_env_factory] * 2, + policy=_make_counting_policy(), + frames_per_batch=10, + total_frames=20, + server_config=InferenceServerConfig(max_batch_size=2), + device_config=InferenceDeviceConfig( + policy_device="cpu", + output_device="cpu", + env_device="cpu", + storing_device="cpu", + ), + ) + total = 0 + for batch in collector: + assert batch.device is None or batch.device.type == "cpu" + total += batch.numel() + stats = collector.server_stats() + collector.shutdown() + assert total >= 20 + assert stats["requests"] > 0 + # ============================================================================= # Tests: SlotTransport diff --git a/torchrl/collectors/_async_batched.py b/torchrl/collectors/_async_batched.py index 72d39c52f7f..e38d8fc7ba6 100644 --- a/torchrl/collectors/_async_batched.py +++ b/torchrl/collectors/_async_batched.py @@ -17,7 +17,9 @@ from torchrl.collectors._base import BaseCollector from torchrl.envs import AsyncEnvPool, EnvBase from torchrl.modules.inference_server import ( + InferenceDeviceConfig, InferenceServer, + InferenceServerConfig, ProcessInferenceServer, ThreadingTransport, ) @@ -72,15 +74,17 @@ def _env_loop( client: Callable | None, result_queue: queue.Queue, shutdown_event: threading.Event, + env_device: torch.device | None, + storing_device: torch.device | None, ): - """Per-env worker thread using pool slot for env execution and InferenceServer for policy. + """Per-env worker thread using a pool slot and inference-server policy. Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and one inference client. The pool handles the actual environment execution in whatever backend it was configured with (threading, multiprocessing, etc.), while this thread coordinates the send/recv cycle and inference submission. - reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ... + reset -> infer -> step_send -> step_recv -> put transition -> infer -> ... """ if client is None: client = transport.client() @@ -91,9 +95,13 @@ def _env_loop( action_td = client(obs) while not shutdown_event.is_set(): + if env_device is not None: + action_td = action_td.to(env_device) pool.async_step_and_maybe_reset_send(action_td, env_index=env_id) cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id) cur_td.set(_ENV_IDX_KEY, env_id) + if storing_device is not None: + cur_td = cur_td.to(storing_device) result_queue.put(cur_td) if shutdown_event.is_set(): break @@ -104,7 +112,11 @@ def _env_loop( class AsyncBatchedCollector(BaseCollector): - """Asynchronous collector that pairs per-env threads with an :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`. + """Asynchronous collector with env slots and a policy server. + + The collector pairs per-env coordinator threads with an + :class:`~torchrl.envs.AsyncEnvPool` and an + :class:`~torchrl.modules.InferenceServer`. Unlike :class:`~torchrl.collectors.Collector`, this collector fully decouples environment stepping from policy inference: @@ -166,6 +178,18 @@ class AsyncBatchedCollector(BaseCollector): output_device (torch.device or str, optional): device where action TensorDicts are moved before being sent back to env workers. Defaults to ``None``. + env_device (torch.device or str, optional): device used by env workers + for action TensorDicts before stepping. Defaults to ``None``. + storing_device (torch.device or str, optional): device used for + collected transitions yielded by this collector. Defaults to + ``None``. + server_config (InferenceServerConfig, optional): structured server + batching and stats configuration. Mutually exclusive with + non-default batching keyword arguments. + device_config (InferenceDeviceConfig, optional): structured device + placement configuration. Mutually exclusive with ``device``, + ``policy_device``, ``output_device``, ``env_device``, and + ``storing_device``. backend (str, optional): global default backend for both environments and policy inference. Specific overrides ``env_backend`` and ``policy_backend`` take precedence when set. @@ -258,6 +282,10 @@ def __init__( create_env_kwargs: dict | list[dict] | None = None, policy_device: torch.device | str | None = None, output_device: torch.device | str | None = None, + env_device: torch.device | str | None = None, + storing_device: torch.device | str | None = None, + server_config: InferenceServerConfig | None = None, + device_config: InferenceDeviceConfig | None = None, server_backend: Literal["thread", "process"] = "thread", ): if policy is not None and policy_factory is not None: @@ -269,6 +297,40 @@ def __init__( "server_backend='process' requires policy_factory so the policy " "can be constructed inside the server process." ) + if server_config is not None: + if (max_batch_size, min_batch_size, server_timeout) != (64, 1, 0.01): + raise ValueError( + "server_config is mutually exclusive with non-default " + "batching keyword arguments." + ) + max_batch_size = server_config.max_batch_size + min_batch_size = server_config.min_batch_size + server_timeout = server_config.timeout + if device_config is not None: + if ( + device is not None + or policy_device is not None + or output_device is not None + or env_device is not None + or storing_device is not None + ): + raise ValueError( + "device_config is mutually exclusive with device, " + "policy_device, output_device, env_device, and " + "storing_device." + ) + policy_device = device_config.policy_device + output_device = device_config.server_output_device() + env_device = device_config.env_device + storing_device = device_config.storing_device + else: + policy_device = device if policy_device is None else policy_device + if output_device is None and env_device is not None: + output_device = env_device + self._env_device = torch.device(env_device) if env_device is not None else None + self._storing_device = ( + torch.device(storing_device) if storing_device is not None else None + ) # ---- resolve policy --------------------------------------------------- self._policy_factory = policy_factory @@ -312,7 +374,6 @@ def __init__( self._transport = transport # ---- build inference server ------------------------------------------- - policy_device = device if policy_device is None else policy_device if server_backend == "process": self._server = ProcessInferenceServer( policy_factory=policy_factory, @@ -324,6 +385,12 @@ def __init__( output_device=output_device, weight_sync=weight_sync, weight_sync_model_id=weight_sync_model_id, + collect_stats=( + True if server_config is None else server_config.collect_stats + ), + stats_window_size=( + 1024 if server_config is None else server_config.stats_window_size + ), ) else: self._server = InferenceServer( @@ -336,6 +403,12 @@ def __init__( output_device=output_device, weight_sync=weight_sync, weight_sync_model_id=weight_sync_model_id, + collect_stats=( + True if server_config is None else server_config.collect_stats + ), + stats_window_size=( + 1024 if server_config is None else server_config.stats_window_size + ), ) # ---- collector settings ----------------------------------------------- @@ -404,6 +477,8 @@ def _ensure_started(self) -> None: "client": self._clients[i], "result_queue": self._result_queue, "shutdown_event": self._shutdown_event, + "env_device": self._env_device, + "storing_device": self._storing_device, }, daemon=True, name=f"AsyncBatchedCollector-env-{i}", diff --git a/torchrl/modules/inference_server/__init__.py b/torchrl/modules/inference_server/__init__.py index 213e30e1914..f543ceaaffd 100644 --- a/torchrl/modules/inference_server/__init__.py +++ b/torchrl/modules/inference_server/__init__.py @@ -3,6 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from torchrl.modules.inference_server._config import ( + InferenceDeviceConfig, + InferenceServerConfig, +) from torchrl.modules.inference_server._monarch import MonarchTransport from torchrl.modules.inference_server._mp import MPTransport from torchrl.modules.inference_server._ray import RayTransport @@ -17,7 +21,9 @@ __all__ = [ "InferenceClient", + "InferenceDeviceConfig", "InferenceServer", + "InferenceServerConfig", "InferenceTransport", "MonarchTransport", "MPTransport", diff --git a/torchrl/modules/inference_server/_config.py b/torchrl/modules/inference_server/_config.py new file mode 100644 index 00000000000..398c7d702ca --- /dev/null +++ b/torchrl/modules/inference_server/_config.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +def _as_device(device: torch.device | str | None) -> torch.device | None: + if device is None: + return None + return torch.device(device) + + +@dataclass +class InferenceDeviceConfig: + """Device placement for asynchronous policy-server collection. + + This config separates the devices used by the environment, the remote + policy, the actor-side action TensorDict, and the returned collector batch. + + Args: + policy_device (torch.device or str, optional): device that owns the + policy and receives batched server inputs. + output_device (torch.device or str, optional): device for inference + results returned by the server. + env_device (torch.device or str, optional): device used by env workers + when stepping environments. If ``output_device`` is omitted, this is + the natural device for returned actions. + storing_device (torch.device or str, optional): device used for + collected transitions yielded by the collector. + + Examples: + >>> import torch + >>> import torch.nn as nn + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.inference_server import ( + ... InferenceDeviceConfig, + ... InferenceServer, + ... ThreadingTransport, + ... ) + >>> policy = TensorDictModule( + ... nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"] + ... ) + >>> transport = ThreadingTransport() + >>> device_config = InferenceDeviceConfig( + ... policy_device="cpu", output_device="cpu" + ... ) + >>> with InferenceServer(policy, transport, device_config=device_config): + ... client = transport.client() + ... result = client(TensorDict({"observation": torch.randn(4)})) + >>> result["action"].device.type + 'cpu' + """ + + policy_device: torch.device | str | None = None + output_device: torch.device | str | None = None + env_device: torch.device | str | None = None + storing_device: torch.device | str | None = None + + def __post_init__(self) -> None: + self.policy_device = _as_device(self.policy_device) + self.output_device = _as_device(self.output_device) + self.env_device = _as_device(self.env_device) + self.storing_device = _as_device(self.storing_device) + + def server_output_device(self) -> torch.device | None: + """Return the actor-side device expected from the policy server.""" + if self.output_device is not None: + return self.output_device + return self.env_device + + +@dataclass +class InferenceServerConfig: + """Server-side batching, timeout, and instrumentation settings. + + Args: + max_batch_size (int, optional): maximum number of requests per forward + pass. Defaults to ``64``. + min_batch_size (int, optional): minimum number of requests to + accumulate after the first request arrives. Defaults to ``1``. + timeout (float, optional): seconds to wait for more requests before + flushing a partial batch. Defaults to ``0.01``. + collect_stats (bool, optional): whether to collect lightweight + throughput and latency stats. Defaults to ``True``. + stats_window_size (int, optional): number of recent timing samples kept + for percentile stats. Defaults to ``1024``. + + Examples: + >>> from torchrl.modules.inference_server import InferenceServerConfig + >>> config = InferenceServerConfig(max_batch_size=8, timeout=0.001) + >>> config.max_batch_size + 8 + """ + + max_batch_size: int = 64 + min_batch_size: int = 1 + timeout: float = 0.01 + collect_stats: bool = True + stats_window_size: int = 1024 diff --git a/torchrl/modules/inference_server/_server.py b/torchrl/modules/inference_server/_server.py index 0fed93fec5e..2a48b89fcbb 100644 --- a/torchrl/modules/inference_server/_server.py +++ b/torchrl/modules/inference_server/_server.py @@ -18,6 +18,10 @@ from tensordict.base import TensorDictBase from torch import nn +from torchrl.modules.inference_server._config import ( + InferenceDeviceConfig, + InferenceServerConfig, +) from torchrl.modules.inference_server._transport import InferenceTransport @@ -68,6 +72,12 @@ class InferenceServer: weight_sync_model_id (str, optional): the model identifier used when initialising the weight sync scheme on the receiver side. Default: ``"policy"``. + server_config (InferenceServerConfig, optional): structured server + configuration. Mutually exclusive with non-default batching and + stats keyword arguments. + device_config (InferenceDeviceConfig, optional): structured device + placement configuration. Mutually exclusive with ``device``, + ``policy_device``, and ``output_device``. Example: >>> from tensordict.nn import TensorDictModule @@ -103,8 +113,39 @@ def __init__( stats_window_size: int = 1024, weight_sync=None, weight_sync_model_id: str = "policy", + server_config: InferenceServerConfig | None = None, + device_config: InferenceDeviceConfig | None = None, shutdown_event: threading.Event | MPEvent | None = None, ): + if server_config is not None: + if ( + max_batch_size, + min_batch_size, + timeout, + collect_stats, + stats_window_size, + ) != (64, 1, 0.01, True, 1024): + raise ValueError( + "server_config is mutually exclusive with non-default " + "batching and stats keyword arguments." + ) + max_batch_size = server_config.max_batch_size + min_batch_size = server_config.min_batch_size + timeout = server_config.timeout + collect_stats = server_config.collect_stats + stats_window_size = server_config.stats_window_size + if device_config is not None: + if ( + device is not None + or policy_device is not None + or output_device is not None + ): + raise ValueError( + "device_config is mutually exclusive with device, " + "policy_device, and output_device." + ) + policy_device = device_config.policy_device + output_device = device_config.server_output_device() self.model = model self.transport = transport self.max_batch_size = max_batch_size @@ -395,9 +436,10 @@ class ProcessInferenceServer: """Dedicated-process wrapper around :class:`InferenceServer`. This server is intended for actor/env workers that communicate through a - queue-based transport such as :class:`~torchrl.modules.inference_server.MPTransport`. - Clients must be created from the transport before :meth:`start` so that the - child process inherits their response queues. + queue-based transport such as + :class:`~torchrl.modules.inference_server.MPTransport`. Clients must be + created from the transport before :meth:`start` so that the child process + inherits their response queues. Args: policy_factory (Callable[[], nn.Module]): picklable factory that creates @@ -417,6 +459,12 @@ class ProcessInferenceServer: stats_window_size (int, optional): forwarded to :class:`InferenceServer`. weight_sync: optional weight synchronization scheme. weight_sync_model_id (str, optional): model id for weight sync. + server_config (InferenceServerConfig, optional): structured server + configuration. Mutually exclusive with non-default batching and + stats keyword arguments. + device_config (InferenceDeviceConfig, optional): structured device + placement configuration. Mutually exclusive with ``device``, + ``policy_device``, and ``output_device``. mp_context: multiprocessing context or start-method name. Defaults to ``"spawn"``. @@ -457,8 +505,39 @@ def __init__( stats_window_size: int = 1024, weight_sync=None, weight_sync_model_id: str = "policy", + server_config: InferenceServerConfig | None = None, + device_config: InferenceDeviceConfig | None = None, mp_context: str | mp.context.BaseContext | None = None, ) -> None: + if server_config is not None: + if ( + max_batch_size, + min_batch_size, + timeout, + collect_stats, + stats_window_size, + ) != (64, 1, 0.01, True, 1024): + raise ValueError( + "server_config is mutually exclusive with non-default " + "batching and stats keyword arguments." + ) + max_batch_size = server_config.max_batch_size + min_batch_size = server_config.min_batch_size + timeout = server_config.timeout + collect_stats = server_config.collect_stats + stats_window_size = server_config.stats_window_size + if device_config is not None: + if ( + device is not None + or policy_device is not None + or output_device is not None + ): + raise ValueError( + "device_config is mutually exclusive with device, " + "policy_device, and output_device." + ) + policy_device = device_config.policy_device + output_device = device_config.server_output_device() self.policy_factory = policy_factory self.transport = transport if isinstance(mp_context, str):