
[Feature] Tensorclass support for IQLLoss#3864

Open
aehebald wants to merge 1 commit into pytorch:main from aehebald:tensorclass-iql-loss



@aehebald


Description

Makes IQLLoss/DiscreteIQLLoss accept tensorclass inputs, not just TensorDict.
First loss converted as a template for the rest of #1062.

  • Read keys with .get() instead of td[key] ([] is positional indexing on a tensorclass).
  • Add _make_writable() in objectives/utils.py; route network scratch selections through it
    (a tensorclass rejects undeclared out_keys, so convert to TensorDict; dynamic containers pass through).
  • Parametrize test_iql over a tensorclass input + add a tensorclass/TensorDict parity test.
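The first two bullets can be illustrated with a minimal sketch. It uses plain-Python stand-ins rather than the real tensordict classes: FixedBatch, DynamicBatch and make_writable are hypothetical names chosen for illustration, and make_writable is not the PR's actual _make_writable() implementation, only a guess at its shape based on the description above.

```python
from dataclasses import dataclass, fields


@dataclass
class FixedBatch:
    """Stand-in for a tensorclass: the set of fields is fixed by the class."""
    observation: list
    action: list

    def get(self, key):
        # Key-based read; this is the access pattern the loss switches to.
        return getattr(self, key)

    def __getitem__(self, index):
        # On a fixed-schema batch, [] indexes along the batch dimension.
        return FixedBatch(*(getattr(self, f.name)[index] for f in fields(self)))


class DynamicBatch(dict):
    """Stand-in for a TensorDict: new keys can be added at any time."""


def make_writable(batch):
    """Hypothetical helper: return a container that accepts new keys."""
    if isinstance(batch, DynamicBatch):
        return batch  # dynamic containers pass through unchanged
    # Fixed-schema input: copy the declared fields into a dynamic container.
    return DynamicBatch({f.name: getattr(batch, f.name) for f in fields(batch)})


fixed = FixedBatch(observation=[0.1, 0.2], action=[1, 0])
scratch = make_writable(fixed)
scratch["state_value"] = [0.5, 0.7]  # an out_key the schema never declared
```

Here fixed.get("observation") reads by key, fixed[0:1] slices the batch, and the undeclared state_value entry lands in scratch while fixed itself is left untouched.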

Motivation and Context

#1062 asks for tensorclass support across all losses. This does IQL first so the
approach can be reviewed before applying it to the rest.

Part of #1062

  • I have raised an issue to propose this change

Types of changes

  • New feature (non-breaking change which adds core functionality)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot

pytorch-bot Bot commented Jun 13, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3864

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures, 1 Unrelated Failure

As of commit 089fef6 with merge base ad8ea7f (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 13, 2026
@github-actions

github-actions Bot commented Jun 13, 2026

Contributor

Benchmark Results: PR 089fef66 vs main ad8ea7fb

Benchmark run: https://github.com/pytorch/rl/actions/runs/27471173780

Higher ops/sec is better. Tables are sorted by largest absolute change.

CPU

Compared 192 benchmarks. Regressions over 5%: 10. Improvements over 5%: 26.

Benchmark main ops PR ops Change
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 50.83 191.52 +276.78%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 195.99 39.34 -79.93%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 2,535 3,467 +36.77%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 2,690 3,659 +36.03%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 2,613 3,502 +34.05%
benchmarks/test_objectives_benchmarks.py::test_values[vec_generalized_advantage_estimate-True-True] 79.00 55.04 -30.33%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 2,484 3,030 +21.97%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1,809 2,201 +21.62%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[True-backward] 208.89 249.98 +19.67%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 650.75 776.48 +19.32%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 695.87 829.50 +19.20%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1,037 841.00 -18.91%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 3,112 3,652 +17.33%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[4-same] 22.93 26.65 +16.25%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 1,861 2,147 +15.38%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[True-backward] 862.74 984.08 +14.06%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[16-same] 21.50 18.62 -13.40%
benchmarks/test_objectives_benchmarks.py::test_values[vec_td1_return_estimate-False-False] 62.85 54.71 -12.95%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[True-backward] 365.59 412.07 +12.71%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 498.90 560.39 +12.33%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[True-backward] 121.39 136.11 +12.13%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[reduce-overhead-None] 1,661 1,854 +11.62%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 2,042 1,819 -10.93%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-False-True] 34,942 38,275 +9.54%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-True-False-True] 39,539 42,950 +8.63%
benchmarks/test_objectives_benchmarks.py::test_values[vec_td_lambda_return_estimate-True-False] 59.12 54.66 -7.54%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-False-True] 35,727 38,367 +7.39%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 2,119 1,964 -7.33%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[safetensors] 22,652 24,186 +6.77%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[False-None] 321.53 343.22 +6.75%
benchmarks/test_non_tensor_env_benchmark.py::test_non_tensor_env_rollout_speed[1000-single-True] 1.3748 1.2838 -6.62%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 3,016 2,823 -6.38%
benchmarks/test_envs_benchmark.py::test_simple 1.7076 1.8154 +6.31%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[True-None] 266.71 282.51 +5.92%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 912.33 964.31 +5.70%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3,034 3,196 +5.36%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[100-img_shape1-atari] 699.39 733.24 +4.84%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[True-backward] 109.00 114.18 +4.75%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-True-True] 20,443 21,358 +4.48%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[True-backward] 114.39 119.48 +4.45%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-scan-False-0-lstm] 1.9716 2.0572 +4.34%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 3,321 3,463 +4.28%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-cudnn-True-0-gru] 1.4972 1.4343 -4.20%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-scan-False-0-gru] 2.9070 3.0253 +4.07%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 565.88 544.75 -3.73%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[True-None] 283.10 293.24 +3.58%
benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] 49.44 47.67 -3.58%
benchmarks/test_envs_benchmark.py::test_transformed 0.8819 0.9132 +3.55%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[200-img_shape3-large_batch] 785.27 758.19 -3.45%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[100-img_shape1-atari] 5,157 5,334 +3.43%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_lazystack_then_write[100-img_shape1-atari] 643.62 665.15 +3.34%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-False-False] 63,375 65,465 +3.30%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[reduce-overhead-None] 463.94 477.86 +3.00%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-True-True-False] 42,908 44,194 +3.00%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-True-True] 22,174 22,837 +2.99%
benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 24.07 24.79 +2.98%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-True-False-False] 50,476 51,944 +2.91%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[True-None] 1,757 1,807 +2.88%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-True-False-False] 58,308 59,972 +2.85%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-True-True-True] 23,828 24,505 +2.84%
benchmarks/test_non_tensor_env_benchmark.py::test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-True] 0.2066 0.2123 +2.76%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[16-constant] 2,535 2,466 -2.73%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_stack_then_write[100-img_shape1-atari] 277.31 269.74 -2.73%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-False-False-True] 30,877 31,697 +2.66%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_lazystack_then_write[50-img_shape0-small] 3,529 3,437 -2.59%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[200-img_shape3-large_batch] 332.64 341.13 +2.55%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_lazystack_then_write[200-img_shape3-large_batch] 307.19 314.90 +2.51%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[reduce-overhead-None] 280.39 287.27 +2.45%
benchmarks/test_envs_benchmark.py::test_serial 0.5727 0.5864 +2.41%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 655.46 640.41 -2.30%
benchmarks/test_objectives_benchmarks.py::test_values[td0_return_estimate-False-False] 7,339 7,505 +2.27%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-True-True] 18,418 18,819 +2.17%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-scan-True-0-lstm] 3.0926 3.1591 +2.15%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-scan-True-0-gru] 4.1510 4.2385 +2.11%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-False-False] 63,736 65,060 +2.08%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_without_rb[200-img_shape1-large_batch] 14.89 15.20 +2.07%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-False-True-True] 19,926 20,335 +2.05%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-True-True-True] 19,030 19,413 +2.01%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[False-None] 87.85 89.61 +2.00%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_stack_then_write[200-img_shape3-large_batch] 138.60 141.36 +1.99%
benchmarks/test_objectives_benchmarks.py::test_iql_speed[reduce-overhead-None] 115.21 117.50 +1.98%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[False-backward] 61.73 62.95 +1.97%
benchmarks/test_replaybuffer_benchmark.py::TestPrioritizedReplayBufferBenchmark::test_sampler_sample_scale[1000000-cpu] 96.33 98.22 +1.96%
benchmarks/test_envs_benchmark.py::test_parallel 0.9790 0.9599 -1.94%
benchmarks/test_non_tensor_env_benchmark.py::test_non_tensor_env_rollout_speed[1000-parallel-buffers-True] 0.5457 0.5352 -1.93%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-cudnn-True-0-lstm] 0.9687 0.9502 -1.92%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-False-True] 29,146 29,692 +1.87%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[False-None] 159.75 162.71 +1.85%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[torch.save] 7,040 7,170 +1.85%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[50-img_shape0-small] 4,426 4,507 +1.82%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-True-True-False] 29,653 30,193 +1.82%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 826.95 812.13 -1.79%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-cudnn-False-0-gru] 1.3779 1.3536 -1.77%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-True-False] 35,289 35,881 +1.68%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[generalized_advantage_estimate-False-1-512] 104.43 106.16 +1.66%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[False-backward] 87.78 89.22 +1.64%
benchmarks/test_objectives_benchmarks.py::test_cql_speed[True-backward] 58.85 57.89 -1.64%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[True-None] 261.55 265.80 +1.63%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 167.87 170.58 +1.62%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-True-True-False] 35,969 35,390 -1.61%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-True-False-True] 31,147 31,637 +1.57%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[reduce-overhead-None] 263.83 267.85 +1.53%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-True-False] 27,680 28,088 +1.48%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_stack_then_write[100-img_shape2-large_img] 174.04 176.60 +1.47%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[untyped_storage] 7.8168 7.9299 +1.45%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 162.15 164.48 +1.43%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 548.74 556.52 +1.42%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[100-img_shape2-large_img] 572.07 564.01 -1.41%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[reduce-overhead-None] 285.76 289.79 +1.41%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[4-constant] 4,204 4,145 -1.41%
benchmarks/test_collectors_benchmark.py::test_async 17.51 17.75 +1.37%
benchmarks/test_objectives_benchmarks.py::test_redq_speed[False-backward] 57.50 56.78 -1.26%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-False-False] 45,522 46,095 +1.26%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-False-False-True] 34,871 35,307 +1.25%
benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 53.06 52.40 -1.25%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 167.47 169.50 +1.21%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[False-backward] 84.20 83.18 -1.20%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[False-None] 122.50 123.96 +1.20%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_without_rb[100-img_shape0-atari] 29.49 29.84 +1.18%
benchmarks/test_objectives_benchmarks.py::test_td3_speed[reduce-overhead-None] 567.94 574.40 +1.14%
... ... ... Showing 120 of 192 comparisons, sorted by absolute change.

GPU

Compared 202 benchmarks. Regressions over 5%: 14. Improvements over 5%: 11.

Benchmark main ops PR ops Change
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 44.65 198.43 +344.38%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 185.10 49.98 -73.00%
benchmarks/test_objectives_benchmarks.py::test_iql_speed[reduce-overhead-None] 77.08 101.94 +32.24%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 2,735 3,576 +30.75%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3,609 2,742 -24.03%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 3,368 2,601 -22.78%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3,575 2,935 -17.92%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1,865 2,176 +16.65%
benchmarks/test_collectors_benchmark.py::test_single 5.8570 6.7896 +15.92%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[100-img_shape2-large_img] 439.63 378.28 -13.95%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 2,009 2,259 +12.47%
benchmarks/test_collectors_benchmark.py::test_single_with_rb_pixels 5.2123 4.6546 -10.70%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 483.84 534.70 +10.51%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[100-img_shape1-atari] 3,708 4,047 +9.14%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 2,699 2,937 +8.81%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 491.92 451.52 -8.21%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[200-img_shape3-large_batch] 730.43 671.52 -8.06%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 529.00 491.05 -7.18%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 795.92 740.00 -7.03%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-False-True] 39,192 36,520 -6.82%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_lazystack_then_write[100-img_shape2-large_img] 399.88 373.44 -6.61%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1,971 2,092 +6.15%
benchmarks/test_objectives_benchmarks.py::test_td3_speed[True-backward] 398.68 374.60 -6.04%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[True-backward] 977.41 919.45 -5.93%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 2,452 2,593 +5.75%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 166.80 174.86 +4.83%
benchmarks/test_non_tensor_env_benchmark.py::test_non_tensor_env_rollout_speed[1000-single-True] 1.3101 1.3710 +4.65%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[4-same] 6.6585 6.9666 +4.63%
benchmarks/test_objectives_benchmarks.py::test_reinforce_speed[True-backward] 349.93 366.06 +4.61%
benchmarks/test_replaybuffer_benchmark.py::TestPrioritizedReplayBufferBenchmark::test_sample_mixed_devices[1000000-cuda_storage_cuda_samp... 1,455 1,515 +4.11%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1,019 978.94 -3.97%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 2,550 2,450 -3.93%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-True-False-True] 42,804 41,133 -3.90%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_stack_then_write[50-img_shape0-small] 843.88 876.55 +3.87%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 725.96 698.82 -3.74%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[True-None] 741.49 715.23 -3.54%
benchmarks/test_objectives_benchmarks.py::test_reinforce_speed[True-None] 753.11 779.31 +3.48%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[safetensors] 23,325 22,542 -3.35%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-False-False-True] 31,354 30,302 -3.35%
benchmarks/test_envs_benchmark.py::test_simple 1.2451 1.2033 -3.35%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-False-True] 38,192 36,921 -3.33%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-True-True-True] 23,700 22,913 -3.32%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[50-img_shape0-small] 6,101 5,908 -3.16%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 162.12 167.24 +3.16%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[False-backward] 146.85 151.40 +3.10%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_lazystack_then_write[100-img_shape1-atari] 632.44 613.10 -3.06%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[200-img_shape3-large_batch] 321.01 311.30 -3.03%
benchmarks/test_rnn_reset_backends_benchmark.py::test_rnn_rollout_with_intermediate_resets[b256-t128-i32-h512-scan-True-0-lstm] 75.89 78.16 +2.98%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[True-None] 815.79 840.02 +2.97%
benchmarks/test_envs_benchmark.py::test_transformed 0.6961 0.7162 +2.89%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-True-True-True] 19,503 18,941 -2.88%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 1,246 1,280 +2.76%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_contiguous[100-img_shape2-large_img] 537.67 522.96 -2.74%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_without_rb_cuda[200-img_shape1-large_batch] 8.8848 8.6461 -2.69%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_storage_write_lazystack[100-img_shape1-atari] 701.61 682.77 -2.69%
benchmarks/test_storage_write_benchmark.py::TestStorageWriteBenchmark::test_collector_stack_then_write[100-img_shape2-large_img] 167.38 171.81 +2.65%
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 167.10 171.37 +2.56%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[False-backward] 237.91 231.90 -2.53%
benchmarks/test_objectives_benchmarks.py::test_iql_speed[False-backward] 68.89 70.61 +2.51%
benchmarks/test_envs_benchmark.py::test_cat_frames_functional[16-same] 5.5487 5.6872 +2.50%
benchmarks/test_collectors_benchmark.py::test_async 10.87 11.14 +2.48%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[torch.save] 7,092 7,260 +2.37%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 845.72 826.10 -2.32%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-True-True] 22,781 22,271 -2.24%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[pickle] 12,096 12,360 +2.18%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 166.62 170.23 +2.17%
benchmarks/test_objectives_benchmarks.py::test_cql_speed[True-backward] 222.24 227.05 +2.16%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[False-backward] 448.10 457.78 +2.16%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[True-backward] 324.19 317.29 -2.13%
benchmarks/test_objectives_benchmarks.py::test_td3_speed[False-backward] 81.99 83.69 +2.08%
benchmarks/test_envs_benchmark.py::test_serial 0.4192 0.4277 +2.04%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[False-None] 110.80 113.05 +2.03%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-False-True-False] 32,731 32,085 -1.97%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[False-None] 345.83 339.13 -1.94%
benchmarks/test_objectives_benchmarks.py::test_dqn_speed[True-None] 1,928 1,891 -1.91%
benchmarks/test_objectives_benchmarks.py::test_sac_speed[reduce-overhead-None] 101.26 103.18 +1.89%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 1,303 1,328 +1.86%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-True-True] 21,046 20,662 -1.82%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[reduce-overhead-None] 788.79 803.12 +1.82%
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 192.48 195.94 +1.80%
benchmarks/test_objectives_benchmarks.py::test_reinforce_speed[False-backward] 278.24 273.26 -1.79%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-True-False-False] 58,470 57,425 -1.79%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_with_rb[200-img_shape1-large_batch] 13.35 13.11 -1.78%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[untyped_storage] 8.3642 8.5120 +1.77%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_with_rb_cuda[100-img_shape0-atari] 17.06 16.76 -1.73%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 1,851 1,883 +1.71%
benchmarks/test_objectives_benchmarks.py::test_ppo_speed[False-None] 225.93 229.78 +1.70%
benchmarks/test_objectives_benchmarks.py::test_redq_deprec_speed[False-backward] 73.13 71.89 -1.70%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-True-True] 19,044 18,724 -1.68%
benchmarks/test_objectives_benchmarks.py::test_a2c_speed[reduce-overhead-None] 847.19 859.95 +1.51%
benchmarks/test_non_tensor_env_benchmark.py::test_non_tensor_env_rollout_speed[1000-serial-no-buffers-True] 0.5894 0.5982 +1.49%
benchmarks/test_collectors_benchmark.py::test_single_with_rb 5.8786 5.9647 +1.46%
benchmarks/test_objectives_benchmarks.py::test_iql_speed[True-backward] 247.39 250.90 +1.42%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-True-False-True-False] 39,552 38,998 -1.40%
benchmarks/test_replaybuffer_benchmark.py::test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 52.56 53.28 +1.37%
benchmarks/test_objectives_benchmarks.py::test_td3_speed[reduce-overhead-None] 44.06 44.65 +1.33%
benchmarks/test_objectives_benchmarks.py::test_cql_speed[True-None] 369.07 373.94 +1.32%
benchmarks/test_collectors_benchmark.py::test_sync 10.43 10.55 +1.24%
benchmarks/test_collectors_benchmark.py::test_async_pixels 10.68 10.82 +1.24%
benchmarks/test_replaybuffer_benchmark.py::TestPrioritizedReplayBufferBenchmark::test_sample_mixed_devices[1000000-cuda_storage_cpu_sampler] 88.45 89.54 +1.23%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_with_rb_cuda[200-img_shape1-large_batch] 8.4961 8.3918 -1.23%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[reduce-overhead-None] 828.58 838.59 +1.21%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-True-False-True] 32,788 32,402 -1.18%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-True-True-True] 20,882 20,637 -1.18%
benchmarks/test_objectives_benchmarks.py::test_values[generalized_advantage_estimate-True-True] 49.14 48.56 -1.17%
benchmarks/test_objectives_benchmarks.py::test_values[vec_td_lambda_return_estimate-True-False] 853.50 863.20 +1.14%
benchmarks/test_storage_write_benchmark.py::TestCollectorIntegrationBenchmark::test_collector_without_rb_cuda[100-img_shape0-atari] 17.73 17.53 -1.11%
benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed[numpy] 376,401 380,548 +1.10%
benchmarks/test_collectors_benchmark.py::test_sync_pixels 10.30 10.41 +1.09%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 164.74 162.95 -1.09%
benchmarks/test_objectives_benchmarks.py::test_gae_speed[generalized_advantage_estimate-False-1-512] 48.94 49.46 +1.05%
benchmarks/test_objectives_benchmarks.py::test_values[td1_return_estimate-False-False] 20.05 20.25 +1.04%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-True-True-False] 35,340 35,704 +1.03%
benchmarks/test_objectives_benchmarks.py::test_iql_speed[True-None] 514.13 509.00 -1.00%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-True-False-True-True] 20,225 20,424 +0.98%
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed[True-backward] 449.01 453.26 +0.94%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[False-False-False-False-False] 45,791 46,221 +0.94%
benchmarks/test_objectives_benchmarks.py::test_cql_speed[False-backward] 40.67 40.29 -0.93%
benchmarks/test_envs_benchmark.py::test_step_mdp_speed[True-False-False-True-True] 19,980 19,795 -0.92%
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 163.05 164.56 +0.92%
... ... ... Showing 120 of 202 comparisons, sorted by absolute change.

@aehebald aehebald force-pushed the tensorclass-iql-loss branch from eefcbc8 to 089fef6 on June 13, 2026 15:37

@vmoens vmoens left a comment

Collaborator

Hey! Thanks for contributing this.

I'd like to hold off here. To accept a tensorclass we'd need the user to pre-register internal keys like td_error and _log_prob on their class (or change the way the loss writes them). Since these write-backs are how essentially every off-policy loss behaves, this really shapes the whole #1062 approach rather than just IQL, so I'd rather settle the pattern first.

A few options: (1) skip these writes when the input is a tensorclass, (2) error out clearly if the fields aren't pre-declared, or (3) return them as metadata. I lean toward (1) — though td_error is the tricky one, since prioritized replay reads it back from the batch, so that case needs a bit more thought.
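Options (1) and (2) could look roughly like the sketch below. It is only an illustration under stated assumptions: FixedBatch is a plain-Python stand-in for a tensorclass (using __slots__ to model declared fields), and write_back with its strict flag is a hypothetical helper, not an existing TorchRL function.

```python
class FixedBatch:
    """Stand-in for a tensorclass: only declared fields may be set."""
    __slots__ = ("observation", "td_error")

    def __init__(self, observation):
        self.observation = observation


def write_back(batch, key, value, *, strict=False):
    """Hypothetical helper: write an internal key if the batch can hold it.

    strict=False models option (1): skip the write for undeclared fields.
    strict=True models option (2): fail with a clear message instead.
    Returns True when the value was written.
    """
    if isinstance(batch, dict):
        batch[key] = value  # dynamic container: always writable
        return True
    declared = getattr(type(batch), "__slots__", ())
    if key in declared:
        setattr(batch, key, value)  # the user pre-declared this field
        return True
    if strict:
        raise KeyError(
            f"{type(batch).__name__} must declare a {key!r} field "
            "to receive it from the loss"
        )
    return False


batch = FixedBatch(observation=[1.0])
wrote_td_error = write_back(batch, "td_error", [0.3])   # declared field
wrote_log_prob = write_back(batch, "_log_prob", [0.1])  # undeclared, skipped
```

In this sketch a user who pre-declares td_error still gets it written in place, which is the case prioritized replay depends on, while an undeclared _log_prob is dropped silently or rejected loudly depending on strict.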

@aehebald

Author

@vmoens

Thanks! Dug into the code and here's what I have in mind. The framing I landed on is that the loss could just avoid writing into its input container at all. Reads are fine on a tensorclass, it's only the writes that hit the fixed schema. There are three of them and I think each has a reasonable home:

The network outputs (loc/scale from the actor, state_value/state_action_value from the value/qvalue nets) only need somewhere writable to land. So instead of converting the input, I'd have the loss run each network in its own scratch TensorDict built from the inputs it reads. The input is only ever read from.
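A rough sketch of the scratch-container idea, with plain-Python stand-ins: ReadOnlyBatch plays the role of the tensorclass input, value_net is a toy function standing in for a network that writes its out_keys, and run_in_scratch is a hypothetical helper name. None of these are TorchRL APIs.

```python
class ReadOnlyBatch:
    """Stand-in for a tensorclass input: it can be read via get(), never written."""

    def __init__(self, **data):
        self._data = dict(data)

    def get(self, key):
        return self._data[key]

    def keys(self):
        return self._data.keys()


def value_net(td):
    # Toy "network": reads observation and writes state_value next to it.
    td["state_value"] = [2.0 * x for x in td["observation"]]


def run_in_scratch(network, batch, in_keys):
    """Run `network` on a fresh dict holding only the inputs it reads."""
    scratch = {key: batch.get(key) for key in in_keys}
    network(scratch)  # out_keys land in scratch, not in the input batch
    return scratch


batch = ReadOnlyBatch(observation=[1.0, 2.0], action=[0, 1])
out = run_in_scratch(value_net, batch, ["observation"])
```

The input keeps exactly the keys it started with, and only the keys the network reads are copied into the scratch container.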

_log_prob is purely internal, since forward only reads it back to log entropy, so I'd return it through the metadata dict actor_loss already returns (the same channel qvalue_loss uses for td_error) instead of writing it into the input.
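As a toy illustration of passing _log_prob through the returned metadata instead of the input, the sketch below uses plain lists and simplified arithmetic. The function bodies are assumptions made for the example and do not reproduce the actual IQL loss computation; only the (loss, metadata) return shape follows the description above.

```python
def actor_loss(log_prob, weights):
    """Toy actor loss: returns (loss, metadata) and writes nothing to the input."""
    loss = -sum(w * lp for w, lp in zip(weights, log_prob)) / len(log_prob)
    return loss, {"log_prob": log_prob}


def forward(log_prob, weights):
    """Toy forward: reads log_prob back from the metadata to log entropy."""
    loss, metadata = actor_loss(log_prob, weights)
    entropy = -sum(metadata["log_prob"]) / len(metadata["log_prob"])
    return {"loss_actor": loss, "entropy": entropy}


result = forward(log_prob=[-1.0, -3.0], weights=[0.5, 1.5])
```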

td_error is the one you flagged, since prioritized replay reads it back. Because update_priority(index, priority) takes the index and the priority as separate args, the priority doesn't need to sit on the batch. So I'd surface it in the loss output, keep writing it in place when the input accepts new keys (TensorDict behavior unchanged), and skip that write for a tensorclass. A tensorclass user then pulls the index off the sample and the priority off the loss output. The one thing that wouldn't work for a tensorclass is the update_tensordict_priority(batch) convenience method, but update_priority covers that case.
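The resulting user-side flow might look like the sketch below. ToyPrioritizedBuffer, Sample and toy_loss are hypothetical stand-ins written for this example; the only detail taken from the discussion is that update_priority receives the index and the priority as separate arguments.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    """Stand-in for a tensorclass sample drawn from a prioritized buffer."""
    index: list
    reward: list
    value: list


class ToyPrioritizedBuffer:
    """Stand-in exposing only an update_priority(index, priority) method."""

    def __init__(self, size):
        self.priorities = [1.0] * size

    def update_priority(self, index, priority):
        for i, p in zip(index, priority):
            self.priorities[i] = p


def toy_loss(batch):
    """Toy loss: returns td_error in its output instead of writing it to batch."""
    td_error = [abs(r - v) for r, v in zip(batch.reward, batch.value)]
    return {"loss": sum(td_error) / len(td_error), "td_error": td_error}


buffer = ToyPrioritizedBuffer(size=4)
sample = Sample(index=[3, 0], reward=[1.0, 0.0], value=[0.5, 0.25])
out = toy_loss(sample)
# Index comes from the sample, priority from the loss output.
buffer.update_priority(sample.index, out["td_error"])
```

The batch is never mutated here, so the same flow works whether or not the sample type declares a td_error field.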

How does this sound to you?


Labels

CLA Signed, Feature, Objectives


2 participants