# ---------------------------------------------------------------------------
# LossModule.register_coeff_buffer
# ---------------------------------------------------------------------------


def _make_tanh_normal_actor():
    """Return a fresh ``TanhNormal`` actor reading ``observation``.

    The A2C and SAC tests below only need an actor they can hand to the loss
    constructor; each call builds a new, independent instance.
    """
    backbone = TensorDictModule(
        nn.Linear(4, 4), in_keys=["observation"], out_keys=["loc", "scale"]
    )
    return ProbabilisticActor(
        module=backbone,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        spec=Composite(action=Bounded(-1, 1, (2,))),
    )


class TestRegisterCoeffBuffer:
    """Regression tests for LossModule.register_coeff_buffer."""

    class _CoeffLoss(LossModule):
        # Smallest possible loss: it only exists to expose the inherited helper.
        def forward(self, td):
            return td

    def test_scalar_registers_buffer(self):
        """A Python float ends up as a registered tensor buffer."""
        loss = self._CoeffLoss()
        loss.register_coeff_buffer("entropy_coeff", 0.01)
        coeff = loss.entropy_coeff
        assert isinstance(coeff, torch.Tensor)
        assert "entropy_coeff" in dict(loss.named_buffers())
        assert torch.isclose(coeff, torch.tensor(0.01))

    def test_none_sets_attribute_not_buffer(self):
        """``None`` is stored as a plain attribute and no buffer is created."""
        loss = self._CoeffLoss()
        loss.register_coeff_buffer("critic_coeff", None)
        assert loss.critic_coeff is None
        registered = dict(loss.named_buffers())
        assert "critic_coeff" not in registered

    def test_tensor_passthrough(self):
        """A scalar tensor is accepted as-is."""
        loss = self._CoeffLoss()
        loss.register_coeff_buffer("c", torch.tensor(2.0))
        registered = dict(loss.named_buffers())
        assert "c" in registered
        assert torch.isclose(loss.c, torch.tensor(2.0))

    def test_dtype_is_respected(self):
        """An explicit ``dtype`` is applied to the buffer."""
        loss = self._CoeffLoss()
        loss.register_coeff_buffer("c", 1.0, dtype=torch.float64)
        assert loss.c.dtype == torch.float64

    def test_non_scalar_rejected(self):
        """A tensor with more than one element raises ``ValueError``."""
        loss = self._CoeffLoss()
        multi_element = torch.ones(2)
        with pytest.raises(ValueError, match="c must be a float or a scalar tensor"):
            loss.register_coeff_buffer("c", multi_element)

    def test_bool_rejected(self):
        """Booleans are refused even though ``bool`` subclasses ``int``."""
        loss = self._CoeffLoss()
        with pytest.raises(ValueError, match="c must be a float or a scalar tensor"):
            loss.register_coeff_buffer("c", True)

    def test_a2c_registers_coeff_buffers(self):
        """A default ``A2CLoss`` exposes both coefficients as buffers."""
        critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])
        loss = A2CLoss(_make_tanh_normal_actor(), critic)
        buffers = dict(loss.named_buffers())
        for coeff_name in ("entropy_coeff", "critic_coeff"):
            assert coeff_name in buffers

    def test_a2c_none_critic_coeff(self):
        """``critic_coeff=None`` leaves a ``None`` attribute and no buffer."""
        critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])
        loss = A2CLoss(_make_tanh_normal_actor(), critic, critic_coeff=None)
        assert loss.critic_coeff is None
        assert "critic_coeff" not in dict(loss.named_buffers())


# ---------------------------------------------------------------------------
# LossModule._forward_value_estimator_keys (default implementation)
# ---------------------------------------------------------------------------


class _RecordingValueEstimator:
    """Minimal value-estimator stub that records set_keys calls."""

    def __init__(self):
        # Keyword arguments of the most recent ``set_keys`` call, if any.
        self.received = None

    def set_keys(self, **kwargs):
        self.received = kwargs


@dataclass
class _VEAcceptedKeys:
    # Key set used by the stub losses below.
    value: str = "state_value"
    reward: str = "reward"
    done: str = "done"
    terminated: str = "terminated"


class _ValueKeysLoss(LossModule):
    """Loss with value-style keys that relies on the default forwarding."""

    _AcceptedKeys = _VEAcceptedKeys
    default_keys = _VEAcceptedKeys
    default_value_estimator = ValueEstimators.TD0

    def forward(self, td):
        return td


class TestDefaultForwardValueEstimatorKeys:
    """Regression tests for the default LossModule._forward_value_estimator_keys."""

    def test_forwards_present_keys(self):
        """Every key declared on the loss reaches the estimator."""
        loss = _ValueKeysLoss()
        recorder = _RecordingValueEstimator()
        loss._value_estimator = recorder
        loss._forward_value_estimator_keys()
        expected = {
            "value": "state_value",
            "reward": "reward",
            "done": "done",
            "terminated": "terminated",
        }
        assert recorder.received == expected

    def test_absent_keys_are_not_forwarded(self):
        """Keys the loss does not declare are left out of the call."""
        loss = _ValueKeysLoss()
        recorder = _RecordingValueEstimator()
        loss._value_estimator = recorder
        loss._forward_value_estimator_keys()
        # advantage / value_target / sample_log_prob are not on this loss's keys
        for missing in ("advantage", "value_target", "sample_log_prob"):
            assert missing not in recorder.received

    def test_set_keys_propagates_to_estimator(self):
        """``set_keys`` on the loss is mirrored on the estimator."""
        loss = _ValueKeysLoss()
        recorder = _RecordingValueEstimator()
        loss._value_estimator = recorder
        loss.set_keys(value="my_state_value")
        assert loss.tensor_keys.value == "my_state_value"
        assert recorder.received["value"] == "my_state_value"

    def test_no_value_estimator_is_noop(self):
        """Forwarding without an estimator does not raise."""
        loss = _ValueKeysLoss()
        loss._value_estimator = None
        loss._forward_value_estimator_keys()  # must not raise

    def test_set_in_keys_called_when_present(self):
        """``_set_in_keys`` runs even when there is no estimator."""
        calls = []

        class _L(LossModule):
            _AcceptedKeys = _VEAcceptedKeys
            default_keys = _VEAcceptedKeys

            def forward(self, td):
                return td

            def _set_in_keys(self):
                calls.append(True)

        loss = _L()
        loss._value_estimator = None
        loss._forward_value_estimator_keys()
        assert calls == [True]

    def test_sac_set_keys_forwards_via_default(self):
        """``SACLoss.set_keys`` updates the TD0 estimator's ``value`` key."""
        qvalue = TensorDictModule(
            nn.Linear(6, 1),
            in_keys=["observation", "action"],
            out_keys=["state_action_value"],
        )
        loss = SACLoss(_make_tanh_normal_actor(), qvalue)
        loss.make_value_estimator(ValueEstimators.TD0)
        loss.set_keys(value="my_state_value")
        assert loss.value_estimator.tensor_keys.value == "my_state_value"
- ) - self.register_buffer( - "clip_value", torch.as_tensor(clip_value, device=device) - ) - else: - self.clip_value = None + self.register_coeff_buffer("clip_value", clip_value, device=device) log_prob_keys = self.actor_network.log_prob_keys action_keys = self.actor_network.dist_sample_keys diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 261b9135151..1f32cbced5b 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -97,9 +97,10 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta): To utilize the ability configuring the tensordict keys via :meth:`~.set_keys()` a subclass must define an _AcceptedKeys dataclass. This dataclass should include all keys that are intended to be configurable. - In addition, the subclass must implement the - :meth:._forward_value_estimator_keys() method. This function is crucial for - forwarding any altered tensordict keys to the underlying value_estimator. + The default :meth:`~._forward_value_estimator_keys()` implementation forwards + common value-estimator keys when present. Subclasses should override it when + the loss's key names need to be remapped before being forwarded to the + underlying value estimator. Subclasses can declare a ``_schedulable_buffers`` frozenset to allow direct scalar assignment (e.g. ``loss.entropy_coeff = 0.003``) for registered @@ -316,8 +317,8 @@ def set_keys(self, **kwargs) -> None: raise AttributeError( "To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module " "must define an _AcceptedKeys dataclass containing all keys intended for configuration. " - "Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to " - "facilitate forwarding of any modified tensordict keys to the underlying value_estimator." 
+ "If the default `._forward_value_estimator_keys()` implementation is insufficient, the " + "subclass must override it to forward modified tensordict keys to the underlying value_estimator." ) from err def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -675,6 +676,82 @@ def _prepare_value_estimator_kwargs(self, value_type, **hyperparams): hp.update(hyperparams) return value_type, hp + def register_coeff_buffer( + self, + name: str, + value: float | int | torch.Tensor | None, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> None: + """Register a scalar coefficient as a buffer, converting it to a tensor. + + Eliminates the recurring ``if not isinstance(value, Tensor): value = + torch.tensor(value); self.register_buffer(name, value)`` boilerplate in + loss ``__init__`` methods. + + If ``value`` is ``None`` the attribute is set to ``None`` instead of a + buffer being registered, matching the common optional-coefficient idiom + (e.g. ``critic_coeff`` / ``clip_value``). + + Args: + name (str): the buffer / attribute name. + value (float, int, Tensor or None): the coefficient. ``None`` sets + the attribute to ``None``. + device (torch.device, optional): device for the buffer. + dtype (torch.dtype, optional): dtype for the buffer. + """ + if value is None: + setattr(self, name, None) + return + if isinstance(value, bool) or not isinstance(value, (float, int, torch.Tensor)): + raise ValueError(f"{name} must be a float or a scalar tensor, got {value}.") + value = torch.as_tensor(value, device=device, dtype=dtype) + if value.numel() != 1: + raise ValueError(f"{name} must be a float or a scalar tensor, got {value}.") + self.register_buffer(name, value) + + # Value-estimator keys forwarded by the default + # :meth:`_forward_value_estimator_keys`. These six are accepted by every + # built-in value estimator. Keys that only some estimators accept (e.g. 
# Tensordict key names that the default ``_forward_value_estimator_keys``
# passes on to the value estimator. These six are accepted by every built-in
# value estimator. Keys that only some estimators accept (e.g.
# ``sample_log_prob``) are deliberately left out; a loss that needs to forward
# them should override ``_forward_value_estimator_keys``.
_value_estimator_default_keys = (
    "advantage",
    "value_target",
    "value",
    "reward",
    "done",
    "terminated",
)


def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Push this loss's tensordict keys to its value estimator.

    For each name in ``_value_estimator_default_keys`` that exists on
    ``self.tensor_keys``, the configured key is collected and handed to the
    estimator's ``set_keys`` in a single call. Nothing is sent when there is
    no estimator or when none of the names is present. Afterwards
    ``_set_in_keys`` is invoked if the loss defines it as a callable, whether
    or not an estimator was found.

    Override this method when the estimator's key *names* differ from the
    names on ``tensor_keys`` -- e.g. the estimator's ``value`` should map to
    a ``state_action_value`` / ``global_value`` key -- or when
    estimator-specific keys such as ``sample_log_prob`` must be forwarded.

    Args:
        **kwargs: accepted for signature compatibility; not used here.
    """
    estimator = getattr(self, "_value_estimator", None)
    if estimator is not None:
        absent = object()  # sentinel: distinguishes "no such key" from any value
        forwarded = {}
        for key_name in self._value_estimator_default_keys:
            configured = getattr(self.tensor_keys, key_name, absent)
            if configured is not absent:
                forwarded[key_name] = configured
        if forwarded:
            estimator.set_keys(**forwarded)

    refresh_in_keys = getattr(self, "_set_in_keys", None)
    if callable(refresh_in_keys):
        refresh_in_keys()
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index ca1a84b0606..f8438ecbd6b 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -386,16 +386,6 @@ def loss_value_diff(diff, expectile=0.8): weight = torch.where(diff > 0, expectile, (1 - expectile)) return weight * (diff**2) - def _forward_value_estimator_keys(self, **kwargs) -> None: - if self._value_estimator is not None: - self._value_estimator.set_keys( - value=self._tensor_keys.value, - reward=self.tensor_keys.reward, - done=self.tensor_keys.done, - terminated=self.tensor_keys.terminated, - ) - self._set_in_keys() - @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, metadata = self.actor_loss(tensordict) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index dd9352fd971..a6c1265f3ff 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -566,24 +566,13 @@ def __init__( # Store the mapping for per-head coefficients self._entropy_coeff_map = {k: float(v) for k, v in entropy_coeff.items()} # Register an empty buffer for compatibility - self.register_buffer("entropy_coeff", torch.tensor(0.0)) + self.register_coeff_buffer("entropy_coeff", 0.0, device=device) elif isinstance(entropy_coeff, (float, int, torch.Tensor)): - # Register the scalar entropy coefficient - coeff = ( - float(entropy_coeff) - if not torch.is_tensor(entropy_coeff) - else float(entropy_coeff.item()) - ) - self.register_buffer("entropy_coeff", torch.tensor(coeff)) + self.register_coeff_buffer("entropy_coeff", entropy_coeff, device=device) self._entropy_coeff_map = None else: raise TypeError("entropy_coeff must be a float or a Mapping[str, float]") - if critic_coeff is not None: - self.register_buffer( - "critic_coeff", torch.tensor(critic_coeff, device=device) - ) - else: - self.critic_coeff = None + self.register_coeff_buffer("critic_coeff", critic_coeff, device=device) self._has_critic = bool(self.critic_coeff is not None and 
self.critic_coeff > 0) self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage @@ -597,21 +586,7 @@ def __init__( value=value_key, ) - if clip_value is not None: - if isinstance(clip_value, float): - clip_value = torch.tensor(clip_value, device=device) - elif isinstance(clip_value, torch.Tensor): - if clip_value.numel() != 1: - raise ValueError( - f"clip_value must be a float or a scalar tensor, got {clip_value}." - ) - else: - raise ValueError( - f"clip_value must be a float or a scalar tensor, got {clip_value}." - ) - self.register_buffer("clip_value", clip_value.to(device)) - else: - self.clip_value = None + self.register_coeff_buffer("clip_value", clip_value, device=device) try: log_prob_keys = self.actor_network.log_prob_keys action_keys = self.actor_network.dist_sample_keys