Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion test/objectives/test_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
LossModuleTestBase,
)

from tensordict import assert_allclose_td, TensorDict
from tensordict import assert_allclose_td, tensorclass, TensorDict
from tensordict.nn import (
NormalParamExtractor,
ProbabilisticTensorDictModule as ProbMod,
Expand Down Expand Up @@ -281,6 +281,21 @@ def test_reset_parameters_recursive(self):
)
self.reset_parameters_recursive_test(loss_fn)

@staticmethod
def _as_iql_tensorclass(td):
    """Repackage the mock IQL data ``td`` as a tensorclass instance.

    A throwaway ``MyData`` tensorclass is declared on every call. The
    entries of ``td`` are unpacked into its fields and ``td.batch_size``
    is reused as the batch size of the returned object.
    """

    @tensorclass
    class MyData:
        # Required fields: these have no default value.
        observation: torch.Tensor
        action: torch.Tensor
        next: TensorDict  # noqa: A003
        # Optional fields: ``None`` unless ``td`` provides a value.
        td_error: torch.Tensor | None = None
        _log_prob: torch.Tensor | None = None
        loc: torch.Tensor | None = None
        scale: torch.Tensor | None = None

    batch_size = td.batch_size
    return MyData(**td, batch_size=batch_size)

@pytest.mark.parametrize("as_tensorclass", [False, True])
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
Expand All @@ -293,9 +308,12 @@ def test_iql(
temperature,
expectile,
td_est,
as_tensorclass,
):
torch.manual_seed(self.seed)
td = self._create_mock_data_iql(device=device)
if as_tensorclass:
td = self._as_iql_tensorclass(td)

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
Expand Down Expand Up @@ -920,6 +938,32 @@ def test_iql_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_value == loss_val_td["loss_value"]

def test_iql_tensorclass_parity(self):
    """Check that IQLLoss treats a tensorclass like the equivalent TensorDict.

    The same mock data is fed to one loss module twice, once wrapped in a
    tensorclass and once as a plain tensordict, re-seeding before each
    call. The loss outputs and the priority entry written to each input
    are then compared.
    """
    torch.manual_seed(self.seed)
    td = self._create_mock_data_iql()

    loss_fn = IQLLoss(
        actor_network=self._create_mock_actor(),
        qvalue_network=self._create_mock_qvalue(),
        value_network=self._create_mock_value(),
    )
    data = self._as_iql_tensorclass(td)

    # Only expect the "no target updater" warning when RL warnings are on.
    if rl_warnings():
        warning_ctx = pytest.warns(
            UserWarning, match="No target network updater"
        )
    else:
        warning_ctx = contextlib.nullcontext()

    with warning_ctx:
        # Re-seed before each forward pass so both calls draw the same
        # random numbers.
        torch.manual_seed(self.seed)
        loss_val_tc = loss_fn(data)
        torch.manual_seed(self.seed)
        loss_val_td = loss_fn(td)

    assert_allclose_td(loss_val_td, loss_val_tc)
    # The priority entry is read back from both inputs after the forward pass.
    priority_key = loss_fn.tensor_keys.priority
    torch.testing.assert_close(data.get(priority_key), td.get(priority_key))

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_iql_reduction(self, reduction):
torch.manual_seed(self.seed)
Expand Down
25 changes: 16 additions & 9 deletions torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_ERROR,
_make_writable,
_pseudo_vmap,
_reduce,
_vmap_func,
Expand Down Expand Up @@ -450,10 +451,12 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
):
dist = self.actor_network.get_dist(tensordict)

log_prob = dist.log_prob(tensordict[self.tensor_keys.action])
log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))

# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = _make_writable(
tensordict.select(*self.qvalue_network.in_keys, strict=False)
)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)

Expand All @@ -463,8 +466,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
)
# state value
with torch.no_grad():
td_copy = tensordict.select(
*self.value_network.in_keys, strict=False
td_copy = _make_writable(
tensordict.select(*self.value_network.in_keys, strict=False)
).detach()
with self.value_network_params.to_module(
self.value_network, preserve_module_state=False
Expand Down Expand Up @@ -494,11 +497,15 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:

def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = _make_writable(
tensordict.select(*self.qvalue_network.in_keys, strict=False)
)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
# state value
td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
td_copy = _make_writable(
tensordict.select(*self.value_network.in_keys, strict=False)
)
with self.value_network_params.to_module(
self.value_network, preserve_module_state=False
):
Expand All @@ -519,8 +526,8 @@ def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:

def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
obs_keys = self.actor_network.in_keys
tensordict = tensordict.select(
"next", *obs_keys, self.tensor_keys.action, strict=False
tensordict = _make_writable(
tensordict.select("next", *obs_keys, self.tensor_keys.action, strict=False)
)

target_value = self.value_estimator.value_estimate(
Expand Down Expand Up @@ -858,7 +865,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
):
dist = self.actor_network.get_dist(tensordict)

log_prob = dist.log_prob(tensordict[self.tensor_keys.action])
log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))

# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
Expand Down
20 changes: 19 additions & 1 deletion torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from typing import Any, TypeVar

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict import (
is_tensorclass,
NestedKey,
TensorDict,
TensorDictBase,
unravel_key,
)
from tensordict.nn import TensorDictModule
from torch import nn, Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -952,6 +958,18 @@ def _get_default_device(net):
return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()


def _make_writable(td: TensorDictBase) -> TensorDictBase:
    """Return ``td`` in a form that networks can write their outputs into.

    Networks store their ``out_keys`` in the container they are called on.
    A tensorclass only accepts the fields it declares, so it is converted
    to a plain :class:`~tensordict.TensorDict` via ``to_tensordict()``.
    Any other container is handed back untouched, so no conversion work is
    done for inputs that already accept new keys.

    Args:
        td: the container that a network is about to run on.

    Returns:
        ``td.to_tensordict()`` when ``td`` is a tensorclass, otherwise
        ``td`` itself.
    """
    if is_tensorclass(td):
        return td.to_tensordict()
    return td


def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
"""Groups multiple optimizers into a single one.

Expand Down
Loading