From 459fa967fff5bed3ead0abfb9cca5d4399858c22 Mon Sep 17 00:00:00 2001 From: Achintya P Date: Mon, 22 Jun 2026 21:27:07 -0700 Subject: [PATCH 1/4] [Feature] OfflineToOnlineTrainer + sota script for offline->online RL Follow-up to the OfflineToOnlineReplayBuffer PR: a SAC trainer that drives the offline-pretrain -> online-finetune transition, plus a standalone sota-implementations script. - OfflineToOnlineTrainer (subclasses SACTrainer): routes collected experience to the online buffer (pre_epoch), samples a mixed offline/online batch (process_optim_batch), and anneals the offline fraction to zero over anneal_frames (post_steps). Backed by two reusable hooks: OfflineToOnlineReplayBufferHook (projects online transitions onto the offline dataset schema so the mixed-batch concat stays valid) and OfflineToOnlineAnnealHook. - sota-implementations/offline_to_online/train.py: a self-contained SAC offline->online script (offline dataset via d4rl:/minari: string). - Tests: hook + flow tests and a gated functional train() run on Pendulum. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../offline_to_online/train.py | 159 ++++++++++++ test/test_offline_to_online.py | 236 ++++++++++++++++++ torchrl/trainers/algorithms/__init__.py | 2 + .../trainers/algorithms/offline_to_online.py | 231 +++++++++++++++++ 4 files changed, 628 insertions(+) create mode 100644 sota-implementations/offline_to_online/train.py create mode 100644 torchrl/trainers/algorithms/offline_to_online.py diff --git a/sota-implementations/offline_to_online/train.py b/sota-implementations/offline_to_online/train.py new file mode 100644 index 00000000000..07404f7e483 --- /dev/null +++ b/sota-implementations/offline_to_online/train.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Offline-to-online SAC fine-tuning. + +Warm-starts SAC on an offline dataset (D4RL/Minari) and fine-tunes it online via +:class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`, sampling a mixed +offline/online batch whose offline fraction is annealed to zero over +``--anneal-frames`` collected frames. + +Example:: + + python train.py --dataset d4rl:halfcheetah-medium-v2 --env HalfCheetah-v4 + python train.py --dataset minari:mujoco/halfcheetah/expert-v0 --total-frames 200000 + +Requires the dataset backend (``pip install d4rl`` or ``pip install minari``) and +the matching MuJoCo environment. +""" + +from __future__ import annotations + +import argparse + +import torch +from tensordict.nn import NormalParamExtractor, TensorDictModule +from torch import nn + +from torchrl.collectors import Collector +from torchrl.data import OfflineToOnlineReplayBuffer +from torchrl.data.datasets.utils import load_dataset +from torchrl.envs import DoubleToFloat, GymEnv, TransformedEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.objectives import SACLoss, SoftUpdate +from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer + + +def make_sac_modules(env, num_cells, device): + obs_dim = env.observation_spec["observation"].shape[-1] + action_dim = env.action_spec.shape[-1] + + actor_net = nn.Sequential( + MLP( + in_features=obs_dim, + out_features=2 * action_dim, + num_cells=num_cells, + device=device, + ), + NormalParamExtractor(), + ) + actor = ProbabilisticActor( + module=TensorDictModule( + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] + ), + in_keys=["loc", "scale"], + spec=env.action_spec, + distribution_class=TanhNormal, + distribution_kwargs={ + "low": env.action_spec.space.low, + "high": env.action_spec.space.high, + }, + return_log_prob=True, + ) + qvalue = ValueOperator( + MLP( + in_features=obs_dim + action_dim, + out_features=1, + num_cells=num_cells, + device=device, + ), + in_keys=["observation", "action"], + out_keys=["state_action_value"], + ) + return actor, qvalue + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--env", default="HalfCheetah-v4", help="online gym env id") + parser.add_argument( + "--dataset", + default="d4rl:halfcheetah-medium-v2", + help="offline dataset id ('d4rl:' or 'minari:')", + ) + parser.add_argument("--total-frames", type=int, default=1_000_000) + parser.add_argument("--frames-per-batch", type=int, default=1000) + parser.add_argument( + "--anneal-frames", + type=int, + default=None, + help="frames over which the offline fraction decays to 0 (default: half " + "of --total-frames)", + ) + parser.add_argument("--offline-fraction", type=float, default=0.5) + parser.add_argument("--online-capacity", type=int, default=1_000_000) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--utd", type=int, default=64, help="optim steps per batch") + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--num-cells", type=int, nargs="+", default=[256, 256]) + parser.add_argument("--tau", type=float, default=0.001) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + torch.manual_seed(args.seed) + device = torch.device(args.device) + + # Online environment. + env = TransformedEnv(GymEnv(args.env, device=device), DoubleToFloat()) + env.set_seed(args.seed) + + # SAC agent. + actor, qvalue = make_sac_modules(env, args.num_cells, device) + loss = SACLoss(actor_network=actor, qvalue_network=qvalue) + loss.make_value_estimator(gamma=0.99) + target_net_updater = SoftUpdate(loss, tau=args.tau) + optimizer = torch.optim.Adam(loss.parameters(), lr=args.lr) + + # Immutable offline dataset (DoubleToFloat to match the online float32 stream) + # paired with a growing online buffer. + offline = load_dataset(args.dataset) + offline.append_transform(DoubleToFloat()) + replay_buffer = OfflineToOnlineReplayBuffer( + offline_dataset=offline, + online_capacity=args.online_capacity, + offline_fraction=args.offline_fraction, + batch_size=args.batch_size, + ) + + collector = Collector( + env, + actor, + frames_per_batch=args.frames_per_batch, + total_frames=args.total_frames, + init_random_frames=0, # the offline dataset already warm-starts learning + device=device, + ) + + anneal_frames = ( + args.anneal_frames if args.anneal_frames is not None else args.total_frames // 2 + ) + trainer = OfflineToOnlineTrainer( + collector=collector, + total_frames=args.total_frames, + frame_skip=1, + optim_steps_per_batch=args.utd, + loss_module=loss, + replay_buffer=replay_buffer, + anneal_frames=anneal_frames, + batch_size=args.batch_size, + optimizer=optimizer, + target_net_updater=target_net_updater, + clip_grad_norm=False, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/test/test_offline_to_online.py b/test/test_offline_to_online.py index cb10cd1e2cb..c809c14e3fa 100644 --- a/test/test_offline_to_online.py +++ b/test/test_offline_to_online.py @@ -14,6 +14,15 @@ from torchrl.data.datasets import utils as dataset_utils from torchrl.data.datasets.utils import load_dataset, register_dataset from torchrl.data.replay_buffers.offline_to_online import prefill_replay_buffer +from torchrl.envs.libs.gym import _has_gym + +# Running a SAC loss requires a tensordict new enough to support +# ``to_module(preserve_module_state=...)``; the offline-to-online wiring itself +# does not. +_LOSS_RUNNABLE = ( + "preserve_module_state" + in __import__("inspect").signature(TensorDict.to_module).parameters +) def _make_offline_buffer(n: int = 1000, obs_dim: int = 4, action_dim: int = 2): @@ -447,6 +456,233 @@ def fake_d4rl(dataset_id, **kwargs): assert captured["kwargs"] == {"batch_size": 32} +class _StubTrainer: + """Minimal stand-in exposing the ``collected_frames`` the anneal hook reads.""" + + def __init__(self, collected_frames: int = 0): + self.collected_frames = collected_frames + + +class TestOfflineToOnlineReplayBufferHook: + def test_extend_uses_collector_mask(self): + from torchrl.trainers.algorithms.offline_to_online import ( + OfflineToOnlineReplayBufferHook, + ) + + offline = _make_offline_buffer() + rb = OfflineToOnlineReplayBuffer( + offline_dataset=offline, online_capacity=500, batch_size=16 + ) + hook = OfflineToOnlineReplayBufferHook(rb) + mask = torch.ones(2, 5, dtype=torch.bool) + mask[0, 3:] = False # 2 invalid rows -> 8 valid + data = TensorDict( + { + "observation": torch.randn(2, 5, 4), + ("collector", "mask"): mask, + }, + batch_size=[2, 5], + ) + hook.extend(data) + assert len(rb.online_buffer) == 8 + # collector bookkeeping is not stored + assert "collector" not in rb.online_buffer.sample(4).keys() + + def test_state_dict_roundtrip(self): + from torchrl.trainers.algorithms.offline_to_online import ( + OfflineToOnlineReplayBufferHook, + ) + + rb = OfflineToOnlineReplayBuffer( + offline_dataset=_make_offline_buffer(), online_capacity=500, batch_size=16 + ) + hook = OfflineToOnlineReplayBufferHook(rb) + hook.extend(_make_online_data(20)) + + rb2 = OfflineToOnlineReplayBuffer( + offline_dataset=_make_offline_buffer(), online_capacity=500, batch_size=16 + ) + hook2 = OfflineToOnlineReplayBufferHook(rb2) + hook2.load_state_dict(hook.state_dict()) + assert len(rb2.online_buffer) == 20 + + +class TestOfflineToOnlineAnnealHook: + def test_anneal_decays_fraction(self): + from torchrl.trainers.algorithms.offline_to_online import ( + OfflineToOnlineAnnealHook, + ) + + rb = OfflineToOnlineReplayBuffer( + offline_dataset=_make_offline_buffer(), + online_capacity=500, + offline_fraction=0.8, + batch_size=16, + ) + stub = _StubTrainer() + hook = OfflineToOnlineAnnealHook(stub, rb, anneal_frames=100) + + stub.collected_frames = 0 + hook() + assert rb.offline_fraction == pytest.approx(0.8) + + stub.collected_frames = 50 + hook() + assert rb.offline_fraction == pytest.approx(0.4) + + stub.collected_frames = 100 + hook() + assert rb.offline_fraction == pytest.approx(0.0) + + # clamps at 0 past anneal_frames + stub.collected_frames = 200 + hook() + assert rb.offline_fraction == 0.0 + + +class TestOfflineToOnlineTrainer: + def test_requires_offline_to_online_buffer(self): + from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer + + plain = ReplayBuffer(storage=LazyTensorStorage(100)) + with pytest.raises(TypeError, match="OfflineToOnlineReplayBuffer"): + OfflineToOnlineTrainer( + collector=None, + total_frames=1, + frame_skip=1, + optim_steps_per_batch=1, + loss_module=None, + replay_buffer=plain, + ) + + def test_hooks_drive_offline_online_flow(self): + """The three hooks together grow the online buffer, keep the mixed batch + flat, and anneal the offline fraction -- the data path the trainer runs, + exercised without a loss so it is independent of the SAC/tensordict + version.""" + from torchrl.trainers.algorithms.offline_to_online import ( + OfflineToOnlineAnnealHook, + OfflineToOnlineReplayBufferHook, + ) + + rb = OfflineToOnlineReplayBuffer( + offline_dataset=_make_offline_buffer(), + online_capacity=500, + offline_fraction=0.5, + batch_size=16, + ) + rb_hook = OfflineToOnlineReplayBufferHook(rb, batch_size=16, device="cpu") + stub = _StubTrainer() + anneal = OfflineToOnlineAnnealHook(stub, rb, anneal_frames=100) + + for step in (20, 40, 60, 80, 100): + rb_hook.extend(_make_online_data(20)) # pre_epoch + sample = rb_hook.sample(None) # process_optim_batch + assert sample.batch_size == torch.Size([16]) + stub.collected_frames = step + anneal() # post_steps + + assert len(rb.online_buffer) == 100 + assert rb.offline_fraction == pytest.approx(0.0) + + @pytest.mark.skipif( + not (_has_gym and _LOSS_RUNNABLE), + reason="needs gym and a tensordict supporting to_module(preserve_module_state)", + ) + def test_train_grows_online_and_anneals(self, tmp_path): + import warnings + + from tensordict.nn import NormalParamExtractor, TensorDictModule + from torch import nn + + from torchrl.collectors import Collector + from torchrl.envs.libs.gym import GymEnv + from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + from torchrl.objectives import SACLoss, SoftUpdate + from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer + + torch.manual_seed(0) + env = GymEnv("Pendulum-v1") + obs_dim = env.observation_spec["observation"].shape[-1] + action_dim = env.action_spec.shape[-1] + + actor_net = nn.Sequential( + MLP(in_features=obs_dim, out_features=2 * action_dim, num_cells=[32, 32]), + NormalParamExtractor(), + ) + actor_module = TensorDictModule( + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=actor_module, + in_keys=["loc", "scale"], + spec=env.action_spec, + distribution_class=TanhNormal, + distribution_kwargs={ + "low": env.action_spec.space.low, + "high": env.action_spec.space.high, + }, + return_log_prob=True, + ) + qvalue = ValueOperator( + MLP(in_features=obs_dim + action_dim, out_features=1, num_cells=[32, 32]), + in_keys=["observation", "action"], + out_keys=["state_action_value"], + ) + + loss = SACLoss(actor_network=actor, qvalue_network=qvalue) + loss.make_value_estimator(gamma=0.99) + target_updater = SoftUpdate(loss, eps=0.99) + + total_frames = 60 + frames_per_batch = 20 + collector = Collector( + env, + actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + init_random_frames=0, + ) + + # Seed an offline dataset from the same env so its keys match the online + # transitions (no Minari/D4RL required). + offline = ReplayBuffer(storage=LazyTensorStorage(200)) + offline.extend(env.rollout(50).reshape(-1).exclude("collector")) + + rb = OfflineToOnlineReplayBuffer( + offline_dataset=offline, + online_capacity=200, + offline_fraction=0.5, + batch_size=16, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + trainer = OfflineToOnlineTrainer( + collector=collector, + total_frames=total_frames, + frame_skip=1, + optim_steps_per_batch=1, + loss_module=loss, + replay_buffer=rb, + anneal_frames=total_frames, + optimizer=torch.optim.Adam(loss.parameters(), lr=1e-3), + target_net_updater=target_updater, + progress_bar=False, + enable_logging=False, + ) + # extend (pre_epoch), sample (process_optim_batch), anneal (post_steps) + assert len(trainer._pre_epoch_ops) >= 1 + assert len(trainer._process_optim_batch_ops) >= 1 + assert len(trainer._post_steps_ops) >= 1 + + trainer.train() + + # Online experience accumulated and the offline fraction annealed away. + assert len(rb.online_buffer) > 0 + assert rb.offline_fraction < 0.5 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/trainers/algorithms/__init__.py b/torchrl/trainers/algorithms/__init__.py index cab87397c77..65d088513c8 100644 --- a/torchrl/trainers/algorithms/__init__.py +++ b/torchrl/trainers/algorithms/__init__.py @@ -9,6 +9,7 @@ from .ddpg import DDPGTrainer from .dqn import DQNTrainer from .iql import IQLTrainer +from .offline_to_online import OfflineToOnlineTrainer from .ppo import PPOTrainer from .sac import SACTrainer @@ -17,6 +18,7 @@ "DDPGTrainer", "DQNTrainer", "IQLTrainer", + "OfflineToOnlineTrainer", "PPOTrainer", "SACTrainer", ] diff --git a/torchrl/trainers/algorithms/offline_to_online.py b/torchrl/trainers/algorithms/offline_to_online.py new file mode 100644 index 00000000000..dfae7e5937c --- /dev/null +++ b/torchrl/trainers/algorithms/offline_to_online.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pathlib + +from collections.abc import Callable + +from tensordict import TensorDictBase + +from torchrl.collectors import BaseCollector +from torchrl.data.replay_buffers.offline_to_online import OfflineToOnlineReplayBuffer +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import TargetNetUpdater +from torchrl.record.loggers import Logger +from torchrl.trainers.algorithms.sac import SACTrainer +from torchrl.trainers.trainers import TrainerHookBase + +__all__ = [ + "OfflineToOnlineReplayBufferHook", + "OfflineToOnlineAnnealHook", + "OfflineToOnlineTrainer", +] + + +class OfflineToOnlineReplayBufferHook(TrainerHookBase): + """Trainer hook driving an :class:`~torchrl.data.OfflineToOnlineReplayBuffer`. + + Routes freshly collected experience to the online buffer on ``pre_epoch`` and + draws a mixed offline/online batch on ``process_optim_batch``. Online + transitions are projected onto the offline dataset's key schema before being + stored, so the offline/online concat in + :meth:`OfflineToOnlineReplayBuffer.sample` does not raise on the policy + outputs (``loc``/``scale``/``log_prob``) and ``collector`` subtree the + offline dataset lacks. + + Keyword Args: + batch_size (int, optional): batch size for :meth:`sample`; falls back to + the buffer's configured ``batch_size``. + device (device, optional): device the sampled batch is moved to. + align_to_offline_keys (bool, optional): project stored online + transitions onto the offline schema (default ``True``). + """ + + def __init__( + self, + replay_buffer: OfflineToOnlineReplayBuffer, + *, + batch_size: int | None = None, + device=None, + align_to_offline_keys: bool = True, + ) -> None: + self.replay_buffer = replay_buffer + self.batch_size = batch_size + self.device = device + self.align_to_offline_keys = align_to_offline_keys + self._offline_keys = None + + def _aligned_keys(self) -> list | None: + if not self.align_to_offline_keys: + return None + if self._offline_keys is None: + offline = self.replay_buffer.offline_buffer + if not len(offline): + return None + probe = offline.sample(1) + self._offline_keys = list(probe.keys(include_nested=True, leaves_only=True)) + return self._offline_keys + + def extend(self, batch: TensorDictBase) -> TensorDictBase: + if ("collector", "mask") in batch.keys(True): + batch = batch[batch.get(("collector", "mask"))] + else: + batch = batch.reshape(-1) + keys = self._aligned_keys() + if keys is not None: + batch = batch.select(*keys, strict=False) + elif "collector" in batch.keys(): + batch = batch.exclude("collector") + batch = batch.cpu() + self.replay_buffer.extend(batch) + return batch + + def sample(self, batch: TensorDictBase) -> TensorDictBase: + sample = self.replay_buffer.sample(self.batch_size) + return sample.to(self.device) if self.device is not None else sample + + def state_dict(self) -> dict: + return {"online_buffer": self.replay_buffer.online_buffer.state_dict()} + + def load_state_dict(self, state_dict: dict) -> None: + self.replay_buffer.online_buffer.load_state_dict(state_dict["online_buffer"]) + + def register(self, trainer, name: str = "replay_buffer") -> None: + trainer.register_op("pre_epoch", self.extend) + trainer.register_op("process_optim_batch", self.sample) + trainer.register_module(name, self) + + +class OfflineToOnlineAnnealHook(TrainerHookBase): + """Linearly decays the buffer's offline sampling fraction during training. + + Once per collected batch (``post_steps``) it calls + :meth:`OfflineToOnlineReplayBuffer.anneal` with the trainer's current + ``collected_frames``, so sampling shifts from offline-dominant to purely + online over ``anneal_frames`` frames. + """ + + def __init__( + self, + trainer, + replay_buffer: OfflineToOnlineReplayBuffer, + anneal_frames: int, + ) -> None: + self.trainer = trainer + self.replay_buffer = replay_buffer + self.anneal_frames = anneal_frames + + def __call__(self) -> None: + self.replay_buffer.anneal(self.trainer.collected_frames, self.anneal_frames) + + def state_dict(self) -> dict: + return {} + + def load_state_dict(self, state_dict: dict) -> None: + pass + + def register(self, trainer, name: str = "offline_to_online_anneal") -> None: + trainer.register_op("post_steps", self) + trainer.register_module(name, self) + + +class OfflineToOnlineTrainer(SACTrainer): + """A SAC trainer for the offline-pretrain -> online-finetune transition. + + Builds on :class:`~torchrl.trainers.algorithms.SACTrainer`, swapping the + plain replay buffer for an :class:`~torchrl.data.OfflineToOnlineReplayBuffer`. + Each collected batch is routed to the online buffer while optimization + samples a mixed batch whose offline fraction is linearly annealed to zero + over ``anneal_frames`` frames -- warm-starting the policy on offline data + and smoothly handing it over to its own online experience. All other SAC + behaviour (target-net updates, weight sync, logging) is inherited. + + Args: + collector (BaseCollector): the data collector for online interactions. + total_frames (int): total number of frames to collect. + frame_skip (int): frames skipped between policy updates. + optim_steps_per_batch (int): optimization steps per collected batch. + loss_module (LossModule): the SAC loss module. + replay_buffer (OfflineToOnlineReplayBuffer): the offline-to-online buffer. + + Keyword Args: + anneal_frames (int, optional): frames over which ``offline_fraction`` + decays to 0. Defaults to ``total_frames``; pass ``<= 0`` to keep the + fraction fixed. + batch_size (int, optional): replay-buffer sampling batch size. + + See :class:`~torchrl.trainers.algorithms.SACTrainer` for the remaining + keyword arguments. + + .. note:: Experimental/prototype feature; the API may change. + """ + + def __init__( + self, + *, + collector: BaseCollector, + total_frames: int, + frame_skip: int, + optim_steps_per_batch: int, + loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase], + replay_buffer: OfflineToOnlineReplayBuffer, + anneal_frames: int | None = None, + batch_size: int | None = None, + optimizer=None, + logger: Logger | None = None, + clip_grad_norm: bool = True, + clip_norm: float | None = None, + progress_bar: bool = True, + seed: int | None = None, + save_trainer_interval: int = 10000, + log_interval: int = 10000, + save_trainer_file: str | pathlib.Path | None = None, + enable_logging: bool = True, + target_net_updater: TargetNetUpdater | None = None, + ) -> None: + if not isinstance(replay_buffer, OfflineToOnlineReplayBuffer): + raise TypeError( + "OfflineToOnlineTrainer requires an OfflineToOnlineReplayBuffer, " + f"got {type(replay_buffer).__name__}." + ) + + # Let SACTrainer wire up everything except the replay buffer (its + # ReplayBufferTrainer assumes a sampler/priority API the offline-to-online + # buffer does not expose); we register our own RB + annealing hooks below. + super().__init__( + collector=collector, + total_frames=total_frames, + frame_skip=frame_skip, + optim_steps_per_batch=optim_steps_per_batch, + loss_module=loss_module, + optimizer=optimizer, + logger=logger, + clip_grad_norm=clip_grad_norm, + clip_norm=clip_norm, + progress_bar=progress_bar, + seed=seed, + save_trainer_interval=save_trainer_interval, + log_interval=log_interval, + save_trainer_file=save_trainer_file, + replay_buffer=None, + enable_logging=enable_logging, + target_net_updater=target_net_updater, + async_collection=False, + ) + + self.replay_buffer = replay_buffer + self.anneal_frames = total_frames if anneal_frames is None else anneal_frames + + device = getattr(replay_buffer.online_buffer.storage, "device", "cpu") + OfflineToOnlineReplayBufferHook( + replay_buffer, batch_size=batch_size, device=device + ).register(self) + + if self.anneal_frames > 0: + OfflineToOnlineAnnealHook(self, replay_buffer, self.anneal_frames).register( + self + ) From c50a08255da7b7a5e159eb90922c2225f2246ad0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jun 2026 11:12:58 -0700 Subject: [PATCH 2/4] [BugFix] Complete offline-to-online trainer wiring --- docs/source/reference/config.rst | 2 + docs/source/reference/trainers_basics.rst | 1 + .../offline_to_online/train.py | 2 +- test/test_offline_to_online.py | 37 ++++- .../trainers/algorithms/configs/__init__.py | 7 + .../trainers/algorithms/configs/trainers.py | 142 ++++++++++++++++++ .../trainers/algorithms/offline_to_online.py | 47 +++++- 7 files changed, 234 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index b2d8fcecc59..2a025aad139 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -480,6 +480,7 @@ Training and Optimization Configurations TrainerConfig PPOTrainerConfig SACTrainerConfig + OfflineToOnlineTrainerConfig DQNTrainerConfig DDPGTrainerConfig IQLTrainerConfig @@ -599,6 +600,7 @@ TorchRL currently provides configuration-driven trainers for the following algor - **PPO** (on-policy): ``PPOTrainerConfig``, ``PPOLossConfig`` - **SAC** (off-policy, continuous): ``SACTrainerConfig``, ``SACLossConfig`` +- **Offline-to-online SAC**: ``OfflineToOnlineTrainerConfig``, ``SACLossConfig`` - **DQN** (off-policy, discrete): ``DQNTrainerConfig``, ``DQNLossConfig`` - **DDPG** (off-policy, continuous): ``DDPGTrainerConfig``, ``DDPGLossConfig`` - **IQL** (offline): ``IQLTrainerConfig``, ``IQLLossConfig`` diff --git a/docs/source/reference/trainers_basics.rst b/docs/source/reference/trainers_basics.rst index 6d37464e6b5..6c64df8257e 100644 --- a/docs/source/reference/trainers_basics.rst +++ b/docs/source/reference/trainers_basics.rst @@ -26,6 +26,7 @@ Algorithm-specific trainers PPOTrainer SACTrainer + OfflineToOnlineTrainer DQNTrainer DDPGTrainer IQLTrainer diff --git a/sota-implementations/offline_to_online/train.py b/sota-implementations/offline_to_online/train.py index 07404f7e483..9a9aa323b5b 100644 --- a/sota-implementations/offline_to_online/train.py +++ b/sota-implementations/offline_to_online/train.py @@ -118,7 +118,7 @@ def main(): # Immutable offline dataset (DoubleToFloat to match the online float32 stream) # paired with a growing online buffer. - offline = load_dataset(args.dataset) + offline = load_dataset(args.dataset, batch_size=args.batch_size) offline.append_transform(DoubleToFloat()) replay_buffer = OfflineToOnlineReplayBuffer( offline_dataset=offline, diff --git a/test/test_offline_to_online.py b/test/test_offline_to_online.py index c809c14e3fa..69791cfdd83 100644 --- a/test/test_offline_to_online.py +++ b/test/test_offline_to_online.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import inspect import pytest import torch @@ -498,13 +499,21 @@ def test_state_dict_roundtrip(self): ) hook = OfflineToOnlineReplayBufferHook(rb) hook.extend(_make_online_data(20)) + rb.anneal(step=50, total_steps=100) rb2 = OfflineToOnlineReplayBuffer( - offline_dataset=_make_offline_buffer(), online_capacity=500, batch_size=16 + offline_dataset=_make_offline_buffer(), + online_capacity=500, + offline_fraction=0.8, + batch_size=16, ) hook2 = OfflineToOnlineReplayBufferHook(rb2) hook2.load_state_dict(hook.state_dict()) assert len(rb2.online_buffer) == 20 + assert rb2.offline_fraction == pytest.approx(0.25) + + rb2.anneal(step=50, total_steps=100) + assert rb2.offline_fraction == pytest.approx(0.25) class TestOfflineToOnlineAnnealHook: @@ -555,6 +564,32 @@ def test_requires_offline_to_online_buffer(self): replay_buffer=plain, ) + def test_constructor_exposes_sac_key_and_logging_kwargs(self): + from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer + + parameters = inspect.signature(OfflineToOnlineTrainer).parameters + for name in ( + "log_rewards", + "log_actions", + "log_observations", + "log_timings", + "auto_log_optim_steps", + "done_key", + "terminated_key", + "reward_key", + "episode_reward_key", + "action_key", + "observation_key", + ): + assert name in parameters + + def test_config_registered(self): + from torchrl.trainers.algorithms.configs import OfflineToOnlineTrainerConfig + + assert OfflineToOnlineTrainerConfig._target_.endswith( + "_make_offline_to_online_trainer" + ) + def test_hooks_drive_offline_online_flow(self): """The three hooks together grow the online buffer, keep the mixed batch flat, and anneal the offline fraction -- the data path the trainer runs, diff --git a/torchrl/trainers/algorithms/configs/__init__.py b/torchrl/trainers/algorithms/configs/__init__.py index 538f918e436..c0f9332ae95 100644 --- a/torchrl/trainers/algorithms/configs/__init__.py +++ b/torchrl/trainers/algorithms/configs/__init__.py @@ -123,6 +123,7 @@ DDPGTrainerConfig, DQNTrainerConfig, IQLTrainerConfig, + OfflineToOnlineTrainerConfig, PPOTrainerConfig, SACTrainerConfig, TD3TrainerConfig, @@ -397,6 +398,7 @@ "DDPGTrainerConfig", "DQNTrainerConfig", "IQLTrainerConfig", + "OfflineToOnlineTrainerConfig", "PPOTrainerConfig", "SACTrainerConfig", "TD3TrainerConfig", @@ -671,6 +673,11 @@ def _register_configs(): cs.store(group="trainer", name="ddpg", node=DDPGTrainerConfig) cs.store(group="trainer", name="dqn", node=DQNTrainerConfig) cs.store(group="trainer", name="iql", node=IQLTrainerConfig) + cs.store( + group="trainer", + name="offline_to_online", + node=OfflineToOnlineTrainerConfig, + ) cs.store(group="trainer", name="ppo", node=PPOTrainerConfig) cs.store(group="trainer", name="sac", node=SACTrainerConfig) cs.store(group="trainer", name="td3", node=TD3TrainerConfig) diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index 224c84430a6..765162b8746 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -22,6 +22,7 @@ from torchrl.trainers.algorithms.ddpg import DDPGTrainer from torchrl.trainers.algorithms.dqn import DQNTrainer from torchrl.trainers.algorithms.iql import IQLTrainer +from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer from torchrl.trainers.algorithms.ppo import PPOTrainer from torchrl.trainers.algorithms.sac import SACTrainer from torchrl.trainers.algorithms.td3 import TD3Trainer @@ -218,6 +219,147 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: return trainer +@dataclass +class OfflineToOnlineTrainerConfig(SACTrainerConfig): + """Hydra configuration for + :class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`. + + Every kwarg accepted by ``OfflineToOnlineTrainer.__init__`` is exposed as a + field here, with SAC network-construction helper fields inherited from + :class:`SACTrainerConfig`. + """ + + anneal_frames: int | None = None + + _target_: str = ( + "torchrl.trainers.algorithms.configs.trainers." + "_make_offline_to_online_trainer" + ) + + def __post_init__(self) -> None: + """Post-initialization hook for offline-to-online trainer configuration.""" + super().__post_init__() + + +def _make_offline_to_online_trainer(*args, **kwargs) -> OfflineToOnlineTrainer: + from torchrl.trainers.trainers import Logger + + collector = kwargs.pop("collector") + total_frames = kwargs.pop("total_frames") + if total_frames is None: + total_frames = collector.total_frames + frame_skip = kwargs.pop("frame_skip", 1) + optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1) + loss_module = kwargs.pop("loss_module") + optimizer = kwargs.pop("optimizer") + logger = kwargs.pop("logger") + clip_grad_norm = kwargs.pop("clip_grad_norm", True) + clip_norm = kwargs.pop("clip_norm") + progress_bar = kwargs.pop("progress_bar", True) + replay_buffer = kwargs.pop("replay_buffer") + save_trainer_interval = kwargs.pop("save_trainer_interval", 10000) + log_interval = kwargs.pop("log_interval", 10000) + save_trainer_file = kwargs.pop("save_trainer_file") + seed = kwargs.pop("seed") + actor_network = kwargs.pop("actor_network") + critic_network = kwargs.pop("critic_network") + kwargs.pop("create_env_fn") + target_net_updater = kwargs.pop("target_net_updater") + async_collection = kwargs.pop("async_collection", False) + if async_collection: + raise ValueError( + "OfflineToOnlineTrainer does not support async_collection." + ) + log_timings = kwargs.pop("log_timings", False) + auto_log_optim_steps = kwargs.pop("auto_log_optim_steps", True) + batch_size = kwargs.pop("batch_size", None) + anneal_frames = kwargs.pop("anneal_frames", None) + enable_logging = kwargs.pop("enable_logging", True) + log_rewards = kwargs.pop("log_rewards", True) + log_actions = kwargs.pop("log_actions", True) + log_observations = kwargs.pop("log_observations", False) + done_key = _normalize_hydra_key(kwargs.pop("done_key", "done")) + terminated_key = _normalize_hydra_key(kwargs.pop("terminated_key", "terminated")) + reward_key = _normalize_hydra_key(kwargs.pop("reward_key", "reward")) + episode_reward_key = _normalize_hydra_key( + kwargs.pop("episode_reward_key", "reward_sum") + ) + action_key = _normalize_hydra_key(kwargs.pop("action_key", "action")) + observation_key = _normalize_hydra_key(kwargs.pop("observation_key", "observation")) + hooks = kwargs.pop("hooks", None) + + # Instantiate networks first + if actor_network is not None and not isinstance(actor_network, torch.nn.Module): + actor_network = actor_network() + if critic_network is not None and not isinstance(critic_network, torch.nn.Module): + critic_network = critic_network() + + if not isinstance(collector, BaseCollector): + collector = collector() + + if not isinstance(loss_module, LossModule): + # then it's a partial config + loss_module = loss_module( + actor_network=actor_network, critic_network=critic_network + ) + if target_net_updater is not None and not isinstance( + target_net_updater, TargetNetUpdater + ): + # target_net_updater must be a partial taking the loss as input + target_net_updater = target_net_updater(loss_module) + if not isinstance(optimizer, torch.optim.Optimizer): + # then it's a partial config + optimizer = optimizer(params=loss_module.parameters()) + + # Quick instance checks + if not isinstance(collector, BaseCollector): + raise ValueError(f"collector must be a BaseCollector, got {type(collector)}") + if not isinstance(loss_module, LossModule): + raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}") + if not isinstance(optimizer, torch.optim.Optimizer): + raise ValueError( + f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}" + ) + if not isinstance(logger, Logger) and logger is not None: + raise ValueError(f"logger must be a Logger, got {type(logger)}") + + trainer = OfflineToOnlineTrainer( + collector=collector, + total_frames=total_frames, + frame_skip=frame_skip, + optim_steps_per_batch=optim_steps_per_batch, + loss_module=loss_module, + replay_buffer=replay_buffer, + anneal_frames=anneal_frames, + batch_size=batch_size, + optimizer=optimizer, + logger=logger, + clip_grad_norm=clip_grad_norm, + clip_norm=clip_norm, + progress_bar=progress_bar, + seed=seed, + save_trainer_interval=save_trainer_interval, + log_interval=log_interval, + save_trainer_file=save_trainer_file, + enable_logging=enable_logging, + log_rewards=log_rewards, + log_actions=log_actions, + log_observations=log_observations, + target_net_updater=target_net_updater, + async_collection=async_collection, + log_timings=log_timings, + auto_log_optim_steps=auto_log_optim_steps, + done_key=done_key, + terminated_key=terminated_key, + reward_key=reward_key, + episode_reward_key=episode_reward_key, + action_key=action_key, + observation_key=observation_key, + ) + _register_trainer_hooks(trainer, hooks) + return trainer + + @dataclass class PPOTrainerConfig(TrainerConfig): """Hydra configuration for :class:`~torchrl.trainers.algorithms.PPOTrainer`. diff --git a/torchrl/trainers/algorithms/offline_to_online.py b/torchrl/trainers/algorithms/offline_to_online.py index dfae7e5937c..79662b5428f 100644 --- a/torchrl/trainers/algorithms/offline_to_online.py +++ b/torchrl/trainers/algorithms/offline_to_online.py @@ -10,6 +10,8 @@ from collections.abc import Callable from tensordict import TensorDictBase +from tensordict.utils import NestedKey +from torch import optim from torchrl.collectors import BaseCollector from torchrl.data.replay_buffers.offline_to_online import OfflineToOnlineReplayBuffer @@ -89,10 +91,20 @@ def sample(self, batch: TensorDictBase) -> TensorDictBase: return sample.to(self.device) if self.device is not None else sample def state_dict(self) -> dict: - return {"online_buffer": self.replay_buffer.online_buffer.state_dict()} + return { + "online_buffer": self.replay_buffer.online_buffer.state_dict(), + "offline_fraction": self.replay_buffer._offline_fraction, + "base_offline_fraction": self.replay_buffer._base_offline_fraction, + } def load_state_dict(self, state_dict: dict) -> None: self.replay_buffer.online_buffer.load_state_dict(state_dict["online_buffer"]) + self.replay_buffer._offline_fraction = state_dict.get( + "offline_fraction", self.replay_buffer._offline_fraction + ) + self.replay_buffer._base_offline_fraction = state_dict.get( + "base_offline_fraction", self.replay_buffer._base_offline_fraction + ) def register(self, trainer, name: str = "replay_buffer") -> None: trainer.register_op("pre_epoch", self.extend) @@ -136,6 +148,10 @@ def register(self, trainer, name: str = "offline_to_online_anneal") -> None: class OfflineToOnlineTrainer(SACTrainer): """A SAC trainer for the offline-pretrain -> online-finetune transition. + See also + :class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig` + for the Hydra configuration counterpart. + Builds on :class:`~torchrl.trainers.algorithms.SACTrainer`, swapping the plain replay buffer for an :class:`~torchrl.data.OfflineToOnlineReplayBuffer`. Each collected batch is routed to the online buffer while optimization @@ -175,7 +191,7 @@ def __init__( replay_buffer: OfflineToOnlineReplayBuffer, anneal_frames: int | None = None, batch_size: int | None = None, - optimizer=None, + optimizer: optim.Optimizer | None = None, logger: Logger | None = None, clip_grad_norm: bool = True, clip_norm: float | None = None, @@ -185,13 +201,29 @@ def __init__( log_interval: int = 10000, save_trainer_file: str | pathlib.Path | None = None, enable_logging: bool = True, + log_rewards: bool = True, + log_actions: bool = True, + log_observations: bool = False, target_net_updater: TargetNetUpdater | None = None, + async_collection: bool = False, + log_timings: bool = False, + auto_log_optim_steps: bool = True, + done_key: NestedKey = "done", + terminated_key: NestedKey = "terminated", + reward_key: NestedKey = "reward", + episode_reward_key: NestedKey = "reward_sum", + action_key: NestedKey = "action", + observation_key: NestedKey = "observation", ) -> None: if not isinstance(replay_buffer, OfflineToOnlineReplayBuffer): raise TypeError( "OfflineToOnlineTrainer requires an OfflineToOnlineReplayBuffer, " f"got {type(replay_buffer).__name__}." ) + if async_collection: + raise ValueError( + "OfflineToOnlineTrainer does not support async_collection." + ) # Let SACTrainer wire up everything except the replay buffer (its # ReplayBufferTrainer assumes a sampler/priority API the offline-to-online @@ -213,8 +245,19 @@ def __init__( save_trainer_file=save_trainer_file, replay_buffer=None, enable_logging=enable_logging, + log_rewards=log_rewards, + log_actions=log_actions, + log_observations=log_observations, target_net_updater=target_net_updater, async_collection=False, + log_timings=log_timings, + auto_log_optim_steps=auto_log_optim_steps, + done_key=done_key, + terminated_key=terminated_key, + reward_key=reward_key, + episode_reward_key=episode_reward_key, + action_key=action_key, + observation_key=observation_key, ) self.replay_buffer = replay_buffer From 377437b92c75881cc5cdbb7835cc0324710dc5fc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jun 2026 15:36:52 -0700 Subject: [PATCH 3/4] [BugFix] Fix offline-to-online CI failures --- .../linux_libs/scripts_ataridqn/install.sh | 12 ++++++++---- .../linux_libs/scripts_gen-dgrl/install.sh | 12 ++++++++---- .../unittest/linux_libs/scripts_openx/install.sh | 12 ++++++++---- .../unittest/linux_libs/scripts_vd4rl/install.sh | 12 ++++++++---- test/test_offline_to_online.py | 15 ++++++++++++--- torchrl/trainers/algorithms/configs/trainers.py | 7 ++----- torchrl/trainers/algorithms/offline_to_online.py | 3 +-- 7 files changed, 47 insertions(+), 26 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_ataridqn/install.sh b/.github/unittest/linux_libs/scripts_ataridqn/install.sh index aa37674e7e9..a18c3f3f989 100755 --- a/.github/unittest/linux_libs/scripts_ataridqn/install.sh +++ b/.github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu128" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh index aa37674e7e9..a18c3f3f989 100755 --- a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh +++ b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh @@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu128" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh index bea627bcb36..09ca2902b4a 100755 --- a/.github/unittest/linux_libs/scripts_openx/install.sh +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu128" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U + pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu -U --no-deps else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U + pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 -U --no-deps fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index bea627bcb36..09ca2902b4a 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu128" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U + pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu -U --no-deps else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U + pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 -U --no-deps fi else printf "Failed to install pytorch" diff --git a/test/test_offline_to_online.py b/test/test_offline_to_online.py index 69791cfdd83..3394284a652 100644 --- a/test/test_offline_to_online.py +++ b/test/test_offline_to_online.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import importlib.util import inspect import pytest @@ -16,13 +17,17 @@ from torchrl.data.datasets.utils import load_dataset, register_dataset from torchrl.data.replay_buffers.offline_to_online import prefill_replay_buffer from torchrl.envs.libs.gym import _has_gym +from torchrl.testing.gym_helpers import PENDULUM_VERSIONED # Running a SAC loss requires a tensordict new enough to support # ``to_module(preserve_module_state=...)``; the offline-to-online wiring itself # does not. _LOSS_RUNNABLE = ( - "preserve_module_state" - in __import__("inspect").signature(TensorDict.to_module).parameters + "preserve_module_state" in inspect.signature(TensorDict.to_module).parameters +) +_CONFIGS_AVAILABLE = ( + importlib.util.find_spec("hydra") is not None + and importlib.util.find_spec("omegaconf") is not None ) @@ -583,6 +588,10 @@ def test_constructor_exposes_sac_key_and_logging_kwargs(self): ): assert name in parameters + @pytest.mark.skipif( + not _CONFIGS_AVAILABLE, + reason="Config system requires hydra-core and omegaconf", + ) def test_config_registered(self): from torchrl.trainers.algorithms.configs import OfflineToOnlineTrainerConfig @@ -637,7 +646,7 @@ def test_train_grows_online_and_anneals(self, tmp_path): from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer torch.manual_seed(0) - env = GymEnv("Pendulum-v1") + env = GymEnv(PENDULUM_VERSIONED()) obs_dim = env.observation_spec["observation"].shape[-1] action_dim = env.action_spec.shape[-1] diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index 765162b8746..ff939304ba9 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -221,8 +221,7 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: @dataclass class OfflineToOnlineTrainerConfig(SACTrainerConfig): - """Hydra configuration for - :class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`. + """Hydra configuration for :class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`. Every kwarg accepted by ``OfflineToOnlineTrainer.__init__`` is exposed as a field here, with SAC network-construction helper fields inherited from @@ -267,9 +266,7 @@ def _make_offline_to_online_trainer(*args, **kwargs) -> OfflineToOnlineTrainer: target_net_updater = kwargs.pop("target_net_updater") async_collection = kwargs.pop("async_collection", False) if async_collection: - raise ValueError( - "OfflineToOnlineTrainer does not support async_collection." - ) + raise ValueError("OfflineToOnlineTrainer does not support async_collection.") log_timings = kwargs.pop("log_timings", False) auto_log_optim_steps = kwargs.pop("auto_log_optim_steps", True) batch_size = kwargs.pop("batch_size", None) diff --git a/torchrl/trainers/algorithms/offline_to_online.py b/torchrl/trainers/algorithms/offline_to_online.py index 79662b5428f..0a5d1fb289f 100644 --- a/torchrl/trainers/algorithms/offline_to_online.py +++ b/torchrl/trainers/algorithms/offline_to_online.py @@ -148,8 +148,7 @@ def register(self, trainer, name: str = "offline_to_online_anneal") -> None: class OfflineToOnlineTrainer(SACTrainer): """A SAC trainer for the offline-pretrain -> online-finetune transition. - See also - :class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig` + See also :class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig` for the Hydra configuration counterpart. Builds on :class:`~torchrl.trainers.algorithms.SACTrainer`, swapping the From 6651edb6ffef4b871725ce9ab5988498b1c9efae Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Jun 2026 16:06:08 -0700 Subject: [PATCH 4/4] [CI] Install Pillow for torchvision libs jobs --- .github/unittest/linux_libs/scripts_ataridqn/install.sh | 2 ++ .github/unittest/linux_libs/scripts_gen-dgrl/install.sh | 2 ++ .github/unittest/linux_libs/scripts_openx/install.sh | 2 ++ .github/unittest/linux_libs/scripts_vd4rl/install.sh | 2 ++ 4 files changed, 8 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_ataridqn/install.sh b/.github/unittest/linux_libs/scripts_ataridqn/install.sh index a18c3f3f989..37a3e3c8dac 100755 --- a/.github/unittest/linux_libs/scripts_ataridqn/install.sh +++ b/.github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -48,6 +48,8 @@ else exit 1 fi +pip3 install pillow + # install tensordict if [[ "$RELEASE" == 0 ]]; then pip3 install git+https://github.com/pytorch/tensordict.git diff --git a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh index a18c3f3f989..37a3e3c8dac 100755 --- a/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh +++ b/.github/unittest/linux_libs/scripts_gen-dgrl/install.sh @@ -48,6 +48,8 @@ else exit 1 fi +pip3 install pillow + # install tensordict if [[ "$RELEASE" == 0 ]]; then pip3 install git+https://github.com/pytorch/tensordict.git diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh index 09ca2902b4a..876bb96b386 100755 --- a/.github/unittest/linux_libs/scripts_openx/install.sh +++ b/.github/unittest/linux_libs/scripts_openx/install.sh @@ -48,6 +48,8 @@ else exit 1 fi +pip3 install pillow + # install tensordict if [[ "$RELEASE" == 0 ]]; then pip3 install git+https://github.com/pytorch/tensordict.git diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index 09ca2902b4a..876bb96b386 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -48,6 +48,8 @@ else exit 1 fi +pip3 install pillow + # install tensordict if [[ "$RELEASE" == 0 ]]; then pip3 install git+https://github.com/pytorch/tensordict.git