Don't pass envpool envs where vectorenvs are needed #1096

@MischaPanch

Description

See the block comments in the test and in the Collector method below. Somewhere, a plain envpool env is passed where an instance of BaseVectorEnv is expected, so the interface is not followed.

This means we rely on the two interfaces coinciding more or less by accident. They already don't fully coincide: envpool envs return info as a single dict of arrays, whereas tianshou's VectorEnvs return an array of dicts.
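
To make the mismatch concrete, here is a minimal sketch of the two info layouts and a conversion between them. It uses only NumPy; the helper name `dict_of_arrays_to_array_of_dicts` and the info keys are made up for illustration and are not tianshou's or envpool's actual API. Nested dicts are handled recursively, on the assumption that info values may themselves be dicts (as the test below suggests).

```python
from typing import Any

import numpy as np


def _select(value: Any, i: int) -> Any:
    """Pick the entry for env i, recursing into nested dicts."""
    if isinstance(value, dict):
        return {k: _select(v, i) for k, v in value.items()}
    return value[i]


def dict_of_arrays_to_array_of_dicts(info: dict[str, Any], num_envs: int) -> np.ndarray:
    """Hypothetical helper: split one dict of per-env arrays into one dict per env."""
    out = np.empty(num_envs, dtype=object)
    for i in range(num_envs):
        out[i] = _select(info, i)
    return out


# Single dict with array entries (the layout described for envpool envs).
# The keys are illustrative only.
envpool_style_info = {
    "env_id": np.array([0, 1, 2, 3]),
    "elapsed_step": np.array([0, 0, 0, 0]),
    "players": {"env_id": np.array([0, 1, 2, 3])},
}

# Array of dicts (the layout described for tianshou's VectorEnvs).
per_env_info = dict_of_arrays_to_array_of_dicts(envpool_style_info, num_envs=4)
```

With the first layout, `info["env_id"][2]` addresses env 2; with the second, the same value is `info[2]["env_id"]`. Code written against one layout fails or silently misbehaves on the other.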

@Trinkle23897 this issue might be of interest to you

@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True),
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # This is not actually unreachable b/c envpool does not return info in the right format
    if isinstance(info, dict):  # type: ignore[unreachable]
        for _, v in info.items():  # type: ignore[unreachable]
            if not isinstance(v, dict):
                assert v.shape[0] == num_envs
    else:
        for _info in info:
            for _, v in _info.items():
                if not isinstance(v, dict):
                    assert v.shape[0] == num_envs

    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the environments and the initial obs, info, and hidden state of the collector."""
        gym_reset_kwargs = gym_reset_kwargs or {}
        self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
        # TODO: hack, wrap envpool envs such that they don't return a dict
        if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
            # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
            # with array entries instead of an array of dicts
            # We use Batch to turn it into an array of dicts
            self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]

        self._pre_collect_hidden_state_RH = None
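
One way the TODO above could be addressed is to convert the info once, in an adapter around the envpool env, so that the Collector no longer needs the `isinstance` check. The following is only a sketch under assumptions: `InfoPerEnvAdapter` and `FakeDictInfoEnv` are hypothetical names, the fake env stands in for an envpool env so the example runs without envpool, and a real fix would also have to cover `step()` and the rest of the BaseVectorEnv interface.

```python
from typing import Any

import numpy as np


def _to_per_env_infos(info: dict[str, Any], num_envs: int) -> np.ndarray:
    """Turn a dict of per-env arrays into an object array of per-env dicts."""

    def pick(value: Any, i: int) -> Any:
        if isinstance(value, dict):
            return {k: pick(v, i) for k, v in value.items()}
        return value[i]

    out = np.empty(num_envs, dtype=object)
    for i in range(num_envs):
        out[i] = pick(info, i)
    return out


class InfoPerEnvAdapter:
    """Hypothetical adapter: reset() always returns an array of info dicts."""

    def __init__(self, env: Any, num_envs: int) -> None:
        self.env = env
        self.num_envs = num_envs

    def reset(self, **kwargs: Any) -> tuple[np.ndarray, np.ndarray]:
        obs, info = self.env.reset(**kwargs)
        if isinstance(info, dict):
            # Dict-of-arrays layout: normalize it here instead of in the Collector.
            info = _to_per_env_infos(info, self.num_envs)
        return obs, info


class FakeDictInfoEnv:
    """Stand-in for an env whose reset() returns info as a dict of arrays."""

    def __init__(self, num_envs: int) -> None:
        self.num_envs = num_envs

    def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict[str, Any]]:
        obs = np.zeros((self.num_envs, 3))
        info = {"env_id": np.arange(self.num_envs)}
        return obs, info


env = InfoPerEnvAdapter(FakeDictInfoEnv(num_envs=4), num_envs=4)
obs, info = env.reset()
```

Callers of the adapter then see the array-of-dicts layout regardless of which kind of env sits underneath.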

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), good first issue (Good for newcomers), refactoring (No change to functionality)

    Projects

    Status

    To do
