Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/modules_inference_server.rst

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is worth a paragraph in the doc somewhere

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Core API
:template: rl_template_noinherit.rst

InferenceServer
InferenceServerConfig
InferenceDeviceConfig
ProcessInferenceServer
InferenceClient
InferenceTransport
Expand Down
43 changes: 43 additions & 0 deletions test/test_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from torchrl.modules.inference_server import (
InferenceClient,
InferenceDeviceConfig,
InferenceServer,
InferenceServerConfig,
InferenceTransport,
MPTransport,
ProcessInferenceServer,
Expand Down Expand Up @@ -257,6 +259,23 @@ def test_stats_accounting(self):
assert stats["avg_batch_size"] > 0
assert stats["p95_forward_ms"] >= 0

def test_structured_config(self):
    """Server accepts structured server and device config objects.

    Builds an ``InferenceServer`` from an ``InferenceServerConfig`` and an
    ``InferenceDeviceConfig``, sends a single request through a threading
    transport client, and checks the action device and the request count.
    """
    transport = ThreadingTransport()
    with InferenceServer(
        _make_policy(),
        transport,
        server_config=InferenceServerConfig(max_batch_size=2, timeout=0.001),
        device_config=InferenceDeviceConfig(
            policy_device="cpu", output_device="cpu"
        ),
    ) as server:
        request = TensorDict({"observation": torch.randn(4)})
        action_td = transport.client()(request)
        # Read the stats while the server context is still open.
        server_stats = server.stats()
    assert action_td["action"].device.type == "cpu"
    assert server_stats["requests"] == 1

@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_cuda_policy_cpu_output(self):
Expand Down Expand Up @@ -1004,6 +1023,30 @@ def test_process_server_backend_smoke(self):
collector.shutdown()
assert total >= 20

def test_device_config_and_server_config(self):
    """Collector accepts structured device and server config objects.

    Iterates the collector to exhaustion, checking that every yielded batch
    is either device-less or on CPU, then verifies the frame count and that
    the server recorded at least one request.
    """
    collector = AsyncBatchedCollector(
        create_env_fn=[_counting_env_factory] * 2,
        policy=_make_counting_policy(),
        frames_per_batch=10,
        total_frames=20,
        server_config=InferenceServerConfig(max_batch_size=2),
        device_config=InferenceDeviceConfig(
            policy_device="cpu",
            output_device="cpu",
            env_device="cpu",
            storing_device="cpu",
        ),
    )
    total = 0
    try:
        for batch in collector:
            assert batch.device is None or batch.device.type == "cpu"
            total += batch.numel()
        # Stats are read before shutdown, as in the original ordering.
        stats = collector.server_stats()
    finally:
        # Shut the collector down even when an assertion above fails, so a
        # failing test does not leave the collector running.
        collector.shutdown()
    assert total >= 20
    assert stats["requests"] > 0


# =============================================================================
# Tests: SlotTransport
Expand Down
83 changes: 79 additions & 4 deletions torchrl/collectors/_async_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from torchrl.collectors._base import BaseCollector
from torchrl.envs import AsyncEnvPool, EnvBase
from torchrl.modules.inference_server import (
InferenceDeviceConfig,
InferenceServer,
InferenceServerConfig,
ProcessInferenceServer,
ThreadingTransport,
)
Expand Down Expand Up @@ -72,15 +74,17 @@ def _env_loop(
client: Callable | None,
result_queue: queue.Queue,
shutdown_event: threading.Event,
env_device: torch.device | None,
storing_device: torch.device | None,
):
"""Per-env worker thread using pool slot for env execution and InferenceServer for policy.
"""Per-env worker thread using a pool slot and inference-server policy.

Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and
one inference client. The pool handles the actual environment execution in
whatever backend it was configured with (threading, multiprocessing, etc.),
while this thread coordinates the send/recv cycle and inference submission.

reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ...
reset -> infer -> step_send -> step_recv -> put transition -> infer -> ...
"""
if client is None:
client = transport.client()
Expand All @@ -91,9 +95,13 @@ def _env_loop(
action_td = client(obs)

while not shutdown_event.is_set():
if env_device is not None:
action_td = action_td.to(env_device)
pool.async_step_and_maybe_reset_send(action_td, env_index=env_id)
cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id)
cur_td.set(_ENV_IDX_KEY, env_id)
if storing_device is not None:
cur_td = cur_td.to(storing_device)
result_queue.put(cur_td)
if shutdown_event.is_set():
break
Expand All @@ -104,7 +112,11 @@ def _env_loop(


class AsyncBatchedCollector(BaseCollector):
"""Asynchronous collector that pairs per-env threads with an :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`.
"""Asynchronous collector with env slots and a policy server.

The collector pairs per-env coordinator threads with an
:class:`~torchrl.envs.AsyncEnvPool` and an
:class:`~torchrl.modules.InferenceServer`.

Unlike :class:`~torchrl.collectors.Collector`, this collector fully
decouples environment stepping from policy inference:
Expand Down Expand Up @@ -166,6 +178,18 @@ class AsyncBatchedCollector(BaseCollector):
output_device (torch.device or str, optional): device where action
TensorDicts are moved before being sent back to env workers.
Defaults to ``None``.
env_device (torch.device or str, optional): device used by env workers
for action TensorDicts before stepping. Defaults to ``None``.
storing_device (torch.device or str, optional): device used for
collected transitions yielded by this collector. Defaults to
``None``.
server_config (InferenceServerConfig, optional): structured server
batching and stats configuration. Mutually exclusive with
non-default batching keyword arguments.
device_config (InferenceDeviceConfig, optional): structured device
placement configuration. Mutually exclusive with ``device``,
``policy_device``, ``output_device``, ``env_device``, and
``storing_device``.
backend (str, optional): global default backend for both
environments and policy inference. Specific overrides
``env_backend`` and ``policy_backend`` take precedence when set.
Expand Down Expand Up @@ -255,6 +279,10 @@ def __init__(
create_env_kwargs: dict | list[dict] | None = None,
policy_device: torch.device | str | None = None,
output_device: torch.device | str | None = None,
env_device: torch.device | str | None = None,
storing_device: torch.device | str | None = None,
server_config: InferenceServerConfig | None = None,
device_config: InferenceDeviceConfig | None = None,
server_backend: Literal["thread", "process"] = "thread",
):
if policy is not None and policy_factory is not None:
Expand All @@ -266,6 +294,40 @@ def __init__(
"server_backend='process' requires policy_factory so the policy "
"can be constructed inside the server process."
)
if server_config is not None:
if (max_batch_size, min_batch_size, server_timeout) != (64, 1, 0.01):
raise ValueError(
"server_config is mutually exclusive with non-default "
"batching keyword arguments."
)
max_batch_size = server_config.max_batch_size
min_batch_size = server_config.min_batch_size
server_timeout = server_config.timeout
if device_config is not None:
if (
device is not None
or policy_device is not None
or output_device is not None
or env_device is not None
or storing_device is not None
):
raise ValueError(
"device_config is mutually exclusive with device, "
"policy_device, output_device, env_device, and "
"storing_device."
)
policy_device = device_config.policy_device
output_device = device_config.server_output_device()
env_device = device_config.env_device
storing_device = device_config.storing_device
else:
policy_device = device if policy_device is None else policy_device
if output_device is None and env_device is not None:
output_device = env_device
self._env_device = torch.device(env_device) if env_device is not None else None
self._storing_device = (
torch.device(storing_device) if storing_device is not None else None
)

# ---- resolve policy ---------------------------------------------------
self._policy_factory = policy_factory
Expand Down Expand Up @@ -309,7 +371,6 @@ def __init__(
self._transport = transport

# ---- build inference server -------------------------------------------
policy_device = device if policy_device is None else policy_device
if server_backend == "process":
self._server = ProcessInferenceServer(
policy_factory=policy_factory,
Expand All @@ -321,6 +382,12 @@ def __init__(
output_device=output_device,
weight_sync=weight_sync,
weight_sync_model_id=weight_sync_model_id,
collect_stats=(
True if server_config is None else server_config.collect_stats
),
stats_window_size=(
1024 if server_config is None else server_config.stats_window_size
),
)
else:
self._server = InferenceServer(
Expand All @@ -333,6 +400,12 @@ def __init__(
output_device=output_device,
weight_sync=weight_sync,
weight_sync_model_id=weight_sync_model_id,
collect_stats=(
True if server_config is None else server_config.collect_stats
),
stats_window_size=(
1024 if server_config is None else server_config.stats_window_size
),
)

# ---- collector settings -----------------------------------------------
Expand Down Expand Up @@ -401,6 +474,8 @@ def _ensure_started(self) -> None:
"client": self._clients[i],
"result_queue": self._result_queue,
"shutdown_event": self._shutdown_event,
"env_device": self._env_device,
"storing_device": self._storing_device,
},
daemon=True,
name=f"AsyncBatchedCollector-env-{i}",
Expand Down
6 changes: 6 additions & 0 deletions torchrl/modules/inference_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torchrl.modules.inference_server._config import (
InferenceDeviceConfig,
InferenceServerConfig,
)
from torchrl.modules.inference_server._monarch import MonarchTransport
from torchrl.modules.inference_server._mp import MPTransport
from torchrl.modules.inference_server._ray import RayTransport
Expand All @@ -17,7 +21,9 @@

__all__ = [
"InferenceClient",
"InferenceDeviceConfig",
"InferenceServer",
"InferenceServerConfig",
"InferenceTransport",
"MonarchTransport",
"MPTransport",
Expand Down
88 changes: 88 additions & 0 deletions torchrl/modules/inference_server/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass

import torch


def _as_device(device: torch.device | str | None) -> torch.device | None:
    """Normalize an optional device specification to a ``torch.device``.

    Args:
        device: a ``torch.device``, a device string such as ``"cpu"``, or
            ``None``.

    Returns:
        ``None`` when ``device`` is ``None``; otherwise the result of
        ``torch.device(device)``.
    """
    if device is None:
        return None
    return torch.device(device)


@dataclass
class InferenceDeviceConfig:
    """Device placement for asynchronous policy-server collection.

    The config keeps four device roles apart: the device of the remote policy,
    the device of the actor-side action TensorDict, the device used by the
    environment, and the device of the returned collector batch. Each field
    accepts a ``torch.device``, a device string, or ``None``; values that are
    not ``None`` are converted to :class:`torch.device` after construction.

    Args:
        policy_device (torch.device or str, optional): device that owns the
            policy and receives batched server inputs.
        output_device (torch.device or str, optional): device for inference
            results returned by the server.
        env_device (torch.device or str, optional): device used by env workers
            when stepping environments. If ``output_device`` is omitted, this is
            the natural device for returned actions.
        storing_device (torch.device or str, optional): device used for
            collected transitions yielded by the collector.

    Examples:
        >>> from torchrl.modules.inference_server import InferenceDeviceConfig
        >>> config = InferenceDeviceConfig(policy_device="cpu", env_device="cpu")
        >>> config.policy_device
        device(type='cpu')
    """

    policy_device: torch.device | str | None = None
    output_device: torch.device | str | None = None
    env_device: torch.device | str | None = None
    storing_device: torch.device | str | None = None

    def __post_init__(self) -> None:
        # Convert each field in declaration order; ``None`` stays ``None``.
        for field_name in (
            "policy_device",
            "output_device",
            "env_device",
            "storing_device",
        ):
            setattr(self, field_name, _as_device(getattr(self, field_name)))

    def server_output_device(self) -> torch.device | None:
        """Return the actor-side device expected from the policy server.

        ``output_device`` wins when it is set; otherwise ``env_device`` is
        returned, which may itself be ``None``.
        """
        explicit = self.output_device
        return explicit if explicit is not None else self.env_device


@dataclass
class InferenceServerConfig:
    """Server-side batching, timeout, and instrumentation settings.

    Values are validated on construction so that an inconsistent batching
    configuration is reported where it is written rather than later.

    Args:
        max_batch_size (int, optional): maximum number of requests per forward
            pass. Must be at least ``1``. Defaults to ``64``.
        min_batch_size (int, optional): minimum number of requests to
            accumulate after the first request arrives. Must lie between ``1``
            and ``max_batch_size``. Defaults to ``1``.
        timeout (float, optional): seconds to wait for more requests before
            flushing a partial batch. Must not be negative. Defaults to
            ``0.01``.
        collect_stats (bool, optional): whether to collect lightweight
            throughput and latency stats. Defaults to ``True``.
        stats_window_size (int, optional): number of recent timing samples kept
            for percentile stats. Must be at least ``1`` when
            ``collect_stats`` is ``True``. Defaults to ``1024``.

    Raises:
        ValueError: if any of the constraints above is violated.

    Examples:
        >>> from torchrl.modules.inference_server import InferenceServerConfig
        >>> config = InferenceServerConfig(max_batch_size=8, timeout=0.001)
        >>> config.max_batch_size
        8
    """

    max_batch_size: int = 64
    min_batch_size: int = 1
    timeout: float = 0.01
    collect_stats: bool = True
    stats_window_size: int = 1024

    def __post_init__(self) -> None:
        if self.max_batch_size < 1:
            raise ValueError(
                f"max_batch_size must be >= 1, got {self.max_batch_size}."
            )
        if not 1 <= self.min_batch_size <= self.max_batch_size:
            raise ValueError(
                "min_batch_size must be between 1 and max_batch_size "
                f"({self.max_batch_size}), got {self.min_batch_size}."
            )
        if self.timeout < 0:
            raise ValueError(f"timeout must be >= 0, got {self.timeout}.")
        # The window size is only checked when stats are actually collected.
        if self.collect_stats and self.stats_window_size < 1:
            raise ValueError(
                "stats_window_size must be >= 1 when collect_stats is True, "
                f"got {self.stats_window_size}."
            )
Loading
Loading