Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/modules_inference_server.rst

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is worth a paragraph in the doc somewhere

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Core API
:template: rl_template_noinherit.rst

InferenceServer
InferenceServerConfig
InferenceDeviceConfig
ProcessInferenceServer
InferenceClient
InferenceTransport
Expand Down
43 changes: 43 additions & 0 deletions test/test_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from torchrl.modules.inference_server import (
InferenceClient,
InferenceDeviceConfig,
InferenceServer,
InferenceServerConfig,
InferenceTransport,
MPTransport,
ProcessInferenceServer,
Expand Down Expand Up @@ -257,6 +259,23 @@ def test_stats_accounting(self):
assert stats["avg_batch_size"] > 0
assert stats["p95_forward_ms"] >= 0

def test_structured_config(self):
transport = ThreadingTransport()
policy = _make_policy()
server_config = InferenceServerConfig(max_batch_size=2, timeout=0.001)
device_config = InferenceDeviceConfig(policy_device="cpu", output_device="cpu")
with InferenceServer(
policy,
transport,
server_config=server_config,
device_config=device_config,
) as server:
client = transport.client()
result = client(TensorDict({"observation": torch.randn(4)}))
stats = server.stats()
assert result["action"].device.type == "cpu"
assert stats["requests"] == 1

@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_cuda_policy_cpu_output(self):
Expand Down Expand Up @@ -1004,6 +1023,30 @@ def test_process_server_backend_smoke(self):
collector.shutdown()
assert total >= 20

def test_device_config_and_server_config(self):
"""Collector accepts structured device and server config objects."""
collector = AsyncBatchedCollector(
create_env_fn=[_counting_env_factory] * 2,
policy=_make_counting_policy(),
frames_per_batch=10,
total_frames=20,
server_config=InferenceServerConfig(max_batch_size=2),
device_config=InferenceDeviceConfig(
policy_device="cpu",
output_device="cpu",
env_device="cpu",
storing_device="cpu",
),
)
total = 0
for batch in collector:
assert batch.device is None or batch.device.type == "cpu"
total += batch.numel()
stats = collector.server_stats()
collector.shutdown()
assert total >= 20
assert stats["requests"] > 0


# =============================================================================
# Tests: SlotTransport
Expand Down
83 changes: 79 additions & 4 deletions torchrl/collectors/_async_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from torchrl.collectors._base import BaseCollector
from torchrl.envs import AsyncEnvPool, EnvBase
from torchrl.modules.inference_server import (
InferenceDeviceConfig,
InferenceServer,
InferenceServerConfig,
ProcessInferenceServer,
ThreadingTransport,
)
Expand Down Expand Up @@ -72,15 +74,17 @@ def _env_loop(
client: Callable | None,
result_queue: queue.Queue,
shutdown_event: threading.Event,
env_device: torch.device | None,
storing_device: torch.device | None,
):
"""Per-env worker thread using pool slot for env execution and InferenceServer for policy.
"""Per-env worker thread using a pool slot and inference-server policy.

Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and
one inference client. The pool handles the actual environment execution in
whatever backend it was configured with (threading, multiprocessing, etc.),
while this thread coordinates the send/recv cycle and inference submission.

reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ...
reset -> infer -> step_send -> step_recv -> put transition -> infer -> ...
"""
if client is None:
client = transport.client()
Expand All @@ -91,9 +95,13 @@ def _env_loop(
action_td = client(obs)

while not shutdown_event.is_set():
if env_device is not None:
action_td = action_td.to(env_device)
pool.async_step_and_maybe_reset_send(action_td, env_index=env_id)
cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id)
cur_td.set(_ENV_IDX_KEY, env_id)
if storing_device is not None:
cur_td = cur_td.to(storing_device)
result_queue.put(cur_td)
if shutdown_event.is_set():
break
Expand All @@ -104,7 +112,11 @@ def _env_loop(


class AsyncBatchedCollector(BaseCollector):
"""Asynchronous collector that pairs per-env threads with an :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`.
"""Asynchronous collector with env slots and a policy server.

The collector pairs per-env coordinator threads with an
:class:`~torchrl.envs.AsyncEnvPool` and an
:class:`~torchrl.modules.InferenceServer`.

Unlike :class:`~torchrl.collectors.Collector`, this collector fully
decouples environment stepping from policy inference:
Expand Down Expand Up @@ -166,6 +178,18 @@ class AsyncBatchedCollector(BaseCollector):
output_device (torch.device or str, optional): device where action
TensorDicts are moved before being sent back to env workers.
Defaults to ``None``.
env_device (torch.device or str, optional): device used by env workers
for action TensorDicts before stepping. Defaults to ``None``.
storing_device (torch.device or str, optional): device used for
collected transitions yielded by this collector. Defaults to
``None``.
server_config (InferenceServerConfig, optional): structured server
batching and stats configuration. Mutually exclusive with
non-default batching keyword arguments.
device_config (InferenceDeviceConfig, optional): structured device
placement configuration. Mutually exclusive with ``device``,
``policy_device``, ``output_device``, ``env_device``, and
``storing_device``.
backend (str, optional): global default backend for both
environments and policy inference. Specific overrides
``env_backend`` and ``policy_backend`` take precedence when set.
Expand Down Expand Up @@ -255,6 +279,10 @@ def __init__(
create_env_kwargs: dict | list[dict] | None = None,
policy_device: torch.device | str | None = None,
output_device: torch.device | str | None = None,
env_device: torch.device | str | None = None,
storing_device: torch.device | str | None = None,
server_config: InferenceServerConfig | None = None,
device_config: InferenceDeviceConfig | None = None,
server_backend: Literal["thread", "process"] = "thread",
):
if policy is not None and policy_factory is not None:
Expand All @@ -266,6 +294,40 @@ def __init__(
"server_backend='process' requires policy_factory so the policy "
"can be constructed inside the server process."
)
if server_config is not None:
if (max_batch_size, min_batch_size, server_timeout) != (64, 1, 0.01):
raise ValueError(
"server_config is mutually exclusive with non-default "
"batching keyword arguments."
)
max_batch_size = server_config.max_batch_size
min_batch_size = server_config.min_batch_size
server_timeout = server_config.timeout
if device_config is not None:
if (
device is not None
or policy_device is not None
or output_device is not None
or env_device is not None
or storing_device is not None
):
raise ValueError(
"device_config is mutually exclusive with device, "
"policy_device, output_device, env_device, and "
"storing_device."
)
policy_device = device_config.policy_device
output_device = device_config.server_output_device()
env_device = device_config.env_device
storing_device = device_config.storing_device
else:
policy_device = device if policy_device is None else policy_device
if output_device is None and env_device is not None:
output_device = env_device
self._env_device = torch.device(env_device) if env_device is not None else None
self._storing_device = (
torch.device(storing_device) if storing_device is not None else None
)

# ---- resolve policy ---------------------------------------------------
self._policy_factory = policy_factory
Expand Down Expand Up @@ -309,7 +371,6 @@ def __init__(
self._transport = transport

# ---- build inference server -------------------------------------------
policy_device = device if policy_device is None else policy_device
if server_backend == "process":
self._server = ProcessInferenceServer(
policy_factory=policy_factory,
Expand All @@ -321,6 +382,12 @@ def __init__(
output_device=output_device,
weight_sync=weight_sync,
weight_sync_model_id=weight_sync_model_id,
collect_stats=(
True if server_config is None else server_config.collect_stats
),
stats_window_size=(
1024 if server_config is None else server_config.stats_window_size
),
)
else:
self._server = InferenceServer(
Expand All @@ -333,6 +400,12 @@ def __init__(
output_device=output_device,
weight_sync=weight_sync,
weight_sync_model_id=weight_sync_model_id,
collect_stats=(
True if server_config is None else server_config.collect_stats
),
stats_window_size=(
1024 if server_config is None else server_config.stats_window_size
),
)

# ---- collector settings -----------------------------------------------
Expand Down Expand Up @@ -401,6 +474,8 @@ def _ensure_started(self) -> None:
"client": self._clients[i],
"result_queue": self._result_queue,
"shutdown_event": self._shutdown_event,
"env_device": self._env_device,
"storing_device": self._storing_device,
},
daemon=True,
name=f"AsyncBatchedCollector-env-{i}",
Expand Down
6 changes: 6 additions & 0 deletions torchrl/modules/inference_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torchrl.modules.inference_server._config import (
InferenceDeviceConfig,
InferenceServerConfig,
)
from torchrl.modules.inference_server._monarch import MonarchTransport
from torchrl.modules.inference_server._mp import MPTransport
from torchrl.modules.inference_server._ray import RayTransport
Expand All @@ -17,7 +21,9 @@

__all__ = [
"InferenceClient",
"InferenceDeviceConfig",
"InferenceServer",
"InferenceServerConfig",
"InferenceTransport",
"MonarchTransport",
"MPTransport",
Expand Down
88 changes: 88 additions & 0 deletions torchrl/modules/inference_server/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass

import torch


def _as_device(device: torch.device | str | None) -> torch.device | None:
if device is None:
return None
return torch.device(device)


@dataclass
class InferenceDeviceConfig:
    """Device placement settings for asynchronous policy-server collection.

    Four placements are configured independently: where the environment runs,
    where the remote policy lives, where the actor-side action TensorDict
    ends up, and where the returned collector batch is stored.

    Each field is normalized through ``_as_device`` right after construction,
    so string specifications are stored as :class:`torch.device` instances and
    ``None`` values are left untouched.

    Args:
        policy_device (torch.device or str, optional): device that owns the
            policy and receives batched server inputs.
        output_device (torch.device or str, optional): device for inference
            results returned by the server.
        env_device (torch.device or str, optional): device used by env workers
            when stepping environments. When ``output_device`` is omitted,
            :meth:`server_output_device` falls back to this value.
        storing_device (torch.device or str, optional): device used for
            collected transitions yielded by the collector.

    Examples:
        >>> from torchrl.modules.inference_server import InferenceDeviceConfig
        >>> config = InferenceDeviceConfig(policy_device="cpu", env_device="cpu")
        >>> config.policy_device
        device(type='cpu')
    """

    policy_device: torch.device | str | None = None
    output_device: torch.device | str | None = None
    env_device: torch.device | str | None = None
    storing_device: torch.device | str | None = None

    def __post_init__(self) -> None:
        # Normalize every field in declaration order.
        for field_name in (
            "policy_device",
            "output_device",
            "env_device",
            "storing_device",
        ):
            setattr(self, field_name, _as_device(getattr(self, field_name)))

    def server_output_device(self) -> torch.device | None:
        """Return the actor-side device expected from the policy server.

        Returns:
            ``output_device`` when it is set, otherwise ``env_device`` (which
            may itself be ``None``).
        """
        if self.output_device is None:
            return self.env_device
        return self.output_device


@dataclass
class InferenceServerConfig:
    """Server-side batching, timeout, and instrumentation settings.

    Values are validated at construction time so that a contradictory
    configuration is reported where it is written rather than later, when the
    settings are consumed.

    Args:
        max_batch_size (int, optional): maximum number of requests per forward
            pass. Must be at least ``1``. Defaults to ``64``.
        min_batch_size (int, optional): minimum number of requests to
            accumulate after the first request arrives. Must be at least ``1``
            and no greater than ``max_batch_size``. Defaults to ``1``.
        timeout (float, optional): seconds to wait for more requests before
            flushing a partial batch. Must be non-negative. Defaults to
            ``0.01``.
        collect_stats (bool, optional): whether to collect lightweight
            throughput and latency stats. Defaults to ``True``.
        stats_window_size (int, optional): number of recent timing samples kept
            for percentile stats. Must be at least ``1`` when ``collect_stats``
            is ``True``. Defaults to ``1024``.

    Raises:
        ValueError: if any of the constraints listed above is violated.

    Examples:
        >>> from torchrl.modules.inference_server import InferenceServerConfig
        >>> config = InferenceServerConfig(max_batch_size=8, timeout=0.001)
        >>> config.max_batch_size
        8
    """

    max_batch_size: int = 64
    min_batch_size: int = 1
    timeout: float = 0.01
    collect_stats: bool = True
    stats_window_size: int = 1024

    def __post_init__(self) -> None:
        if self.max_batch_size < 1:
            raise ValueError(
                f"max_batch_size must be at least 1, got {self.max_batch_size}."
            )
        if self.min_batch_size < 1:
            raise ValueError(
                f"min_batch_size must be at least 1, got {self.min_batch_size}."
            )
        # A minimum above the maximum can never be satisfied by one batch.
        if self.min_batch_size > self.max_batch_size:
            raise ValueError(
                f"min_batch_size ({self.min_batch_size}) cannot exceed "
                f"max_batch_size ({self.max_batch_size})."
            )
        if self.timeout < 0:
            raise ValueError(f"timeout must be non-negative, got {self.timeout}.")
        # The window size only matters when stats are actually collected.
        if self.collect_stats and self.stats_window_size < 1:
            raise ValueError(
                "stats_window_size must be at least 1 when collect_stats is "
                f"True, got {self.stats_window_size}."
            )
Loading
Loading