diff --git a/test/objectives/test_iql.py b/test/objectives/test_iql.py index d1ec2ae4f6e..a92c19c3323 100644 --- a/test/objectives/test_iql.py +++ b/test/objectives/test_iql.py @@ -18,7 +18,7 @@ LossModuleTestBase, ) -from tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, tensorclass, TensorDict from tensordict.nn import ( NormalParamExtractor, ProbabilisticTensorDictModule as ProbMod, @@ -281,6 +281,21 @@ def test_reset_parameters_recursive(self): ) self.reset_parameters_recursive_test(loss_fn) + @staticmethod + def _as_iql_tensorclass(td): + @tensorclass + class MyData: + observation: torch.Tensor + action: torch.Tensor + next: TensorDict # noqa: A003 + td_error: torch.Tensor | None = None + _log_prob: torch.Tensor | None = None + loc: torch.Tensor | None = None + scale: torch.Tensor | None = None + + return MyData(**td, batch_size=td.batch_size) + + @pytest.mark.parametrize("as_tensorclass", [False, True]) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0]) @@ -293,9 +308,12 @@ def test_iql( temperature, expectile, td_est, + as_tensorclass, ): torch.manual_seed(self.seed) td = self._create_mock_data_iql(device=device) + if as_tensorclass: + td = self._as_iql_tensorclass(td) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) @@ -920,6 +938,32 @@ def test_iql_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_value == loss_val_td["loss_value"] + def test_iql_tensorclass_parity(self): + torch.manual_seed(self.seed) + td = self._create_mock_data_iql() + + actor = self._create_mock_actor() + qvalue = self._create_mock_qvalue() + value = self._create_mock_value() + + loss_fn = IQLLoss( + actor_network=actor, qvalue_network=qvalue, value_network=value + ) + data = self._as_iql_tensorclass(td) + + with pytest.warns( + UserWarning, match="No target network updater" + ) if rl_warnings() else contextlib.nullcontext(): + torch.manual_seed(self.seed) + loss_val_tc = loss_fn(data) + torch.manual_seed(self.seed) + loss_val_td = loss_fn(td) + assert_allclose_td(loss_val_td, loss_val_tc) + torch.testing.assert_close( + data.get(loss_fn.tensor_keys.priority), + td.get(loss_fn.tensor_keys.priority), + ) + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) def test_iql_reduction(self, reduction): torch.manual_seed(self.seed) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index f5c91eec592..14d873e1ef8 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -18,6 +18,7 @@ from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_ERROR, + _make_writable, _pseudo_vmap, _reduce, _vmap_func, @@ -450,10 +451,12 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: ): dist = self.actor_network.get_dist(tensordict) - log_prob = dist.log_prob(tensordict[self.tensor_keys.action]) + log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action)) # Min Q value - td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + td_q = _make_writable( + tensordict.select(*self.qvalue_network.in_keys, strict=False) + ) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -463,8 +466,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: ) # state value with torch.no_grad(): - td_copy = tensordict.select( - *self.value_network.in_keys, strict=False + td_copy = _make_writable( + tensordict.select(*self.value_network.in_keys, strict=False) ).detach() with self.value_network_params.to_module( self.value_network, preserve_module_state=False @@ -494,11 +497,15 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: # Min Q value - td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + td_q = _make_writable( + tensordict.select(*self.qvalue_network.in_keys, strict=False) + ) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) # state value - td_copy = tensordict.select(*self.value_network.in_keys, strict=False) + td_copy = _make_writable( + tensordict.select(*self.value_network.in_keys, strict=False) + ) with self.value_network_params.to_module( self.value_network, preserve_module_state=False ): @@ -519,8 +526,8 @@ def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys - tensordict = tensordict.select( - "next", *obs_keys, self.tensor_keys.action, strict=False + tensordict = _make_writable( + tensordict.select("next", *obs_keys, self.tensor_keys.action, strict=False) ) target_value = self.value_estimator.value_estimate( @@ -858,7 +865,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]: ): dist = self.actor_network.get_dist(tensordict) - log_prob = dist.log_prob(tensordict[self.tensor_keys.action]) + log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action)) # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index efa82f57c04..4a913350cec 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -14,7 +14,13 @@ from typing import Any, TypeVar import torch -from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key +from tensordict import ( + is_tensorclass, + NestedKey, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.nn import TensorDictModule from torch import nn, Tensor from torch.nn import functional as F @@ -952,6 +958,18 @@ def _get_default_device(net): return getattr(torch, "get_default_device", lambda: torch.device("cpu"))() +def _make_writable(td: TensorDictBase) -> TensorDictBase: + """Returns a container that accepts new keys, for use as network scratch. + + Networks write their ``out_keys`` into the tensordict they run on. A + tensorclass has a fixed schema and rejects keys that were not declared as + fields, so it is converted to a plain :class:`~tensordict.TensorDict`. + Dynamic containers (``TensorDict``, lazy stacks) already accept new keys and + are returned unchanged to avoid a needless clone on the hot path. + """ + return td.to_tensordict() if is_tensorclass(td) else td + + def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: """Groups multiple optimizers into a single one.