Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions .github/unittest/linux_libs/scripts_chess/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,22 @@ fi
# submodules
git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with cu128"
printf "Installing PyTorch and torchvision with cu128"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
else
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps
else
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps
fi
else
printf "Failed to install pytorch"
Expand Down
14 changes: 9 additions & 5 deletions .github/unittest/linux_libs/scripts_jumanji/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,22 @@ fi
# submodules
git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with cu128"
printf "Installing PyTorch and torchvision with cu128"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
else
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps
else
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps
fi
else
printf "Failed to install pytorch"
Expand Down
2 changes: 2 additions & 0 deletions .github/unittest/linux_libs/scripts_libero/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ libero_dir="${root_dir}/libero-src"
rm -rf "${libero_dir}"
git clone --depth 1 https://github.com/Lifelong-Robot-Learning/LIBERO.git "${libero_dir}"

# robosuite 1.4.0 calls the pre-3.10 mj_fullM signature.
uv_pip_install \
"bddl==1.0.1" \
easydict \
"gym==0.25.2" \
h5py \
imageio \
matplotlib \
"mujoco<3.10.0" \
"numpy<2" \
opencv-python \
"robosuite==1.4.0" \
Expand Down
16 changes: 9 additions & 7 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.utils import _classproperty

_has_torchvision = importlib.util.find_spec("torchvision") is not None


class _ChessMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
Expand Down Expand Up @@ -333,7 +335,7 @@ def __init__(
raise ImportError(
"Please install cairosvg to use this environment with pixel rendering."
)
if importlib.util.find_spec("torchvision") is None:
if not _has_torchvision:
raise ImportError(
"Please install torchvision to use this environment with pixel rendering."
)
Expand Down Expand Up @@ -466,19 +468,19 @@ def _torchvision(cls):
@classmethod
def _get_tensor_image(cls, board):
try:
from PIL import Image

svg = board._repr_svg_()
# Convert SVG to PNG using cairosvg
png_data = io.BytesIO()
cls._cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
png_data.seek(0)
# Open the PNG image using Pillow
img = Image.open(png_data)
img = cls._torchvision.transforms.functional.pil_to_tensor(img)
# Decode the PNG bytes directly into a CHW tensor.
img = cls._torchvision.io.decode_image(
torch.frombuffer(png_data.getbuffer(), dtype=torch.uint8),
mode=cls._torchvision.io.ImageReadMode.RGB,
)
except ImportError:
raise ImportError(
"Chess rendering requires cairosvg, PIL and torchvision to be installed."
"Chess rendering requires cairosvg and torchvision to be installed."
)
return img

Expand Down
33 changes: 25 additions & 8 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchrl.envs.utils import _classproperty

_has_jumanji = importlib.util.find_spec("jumanji") is not None
_has_torchvision = importlib.util.find_spec("torchvision") is not None

from torchrl.data.tensor_specs import (
Bounded,
Expand Down Expand Up @@ -352,6 +353,17 @@ def lib(self):
raise ImportError("jumanji version must be >= 1.0.0")
return jumanji

_torchvision_lib = None

@_classproperty
def _torchvision(cls):
tv = cls._torchvision_lib
if tv is None:
import torchvision

tv = cls._torchvision_lib = torchvision
return tv

def __init__(
self,
env: jumanji.env.Environment = None, # noqa: F821
Expand Down Expand Up @@ -579,14 +591,17 @@ def render(
import jax.numpy as jnp
import jumanji

if not _has_torchvision:
raise ImportError(
"Rendering with Jumanji requires torchvision to be installed."
)

try:
import matplotlib
import matplotlib.pyplot as plt
import PIL
import torchvision.transforms.v2.functional
except ImportError as err:
raise ImportError(
"Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed."
"Rendering with Jumanji requires matplotlib to be installed."
) from err

if matplotlib_backend is not None:
Expand Down Expand Up @@ -615,15 +630,17 @@ def render(
self._env.render(state, **kwargs)
plt.savefig(buf, format="png")
buf.seek(0)
# Load the image into a PIL object.
img = PIL.Image.open(buf)
img_array = torchvision.transforms.v2.functional.pil_to_tensor(img)
# Decode the PNG bytes directly into a CHW tensor.
img_array = self._torchvision.io.decode_image(
torch.frombuffer(buf.getbuffer(), dtype=torch.uint8),
mode=self._torchvision.io.ImageReadMode.RGB,
)
if not isinteractive:
plt.ioff()
plt.close()
if not as_numpy:
return img_array[:3]
return img_array[:3].numpy().copy()
return img_array
return img_array.numpy().copy()
finally:
jumanji.environments.is_notebook = is_notebook

Expand Down
Loading