Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, rankseg]
```

which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `rankseg` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
43 changes: 43 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import (
OptionalImportError,
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
optional_import,
)
from monai.utils.type_conversion import convert_to_dst_type

rankseg_fn, has_rankseg = optional_import("rankseg.functional", name="rankseg")

__all__ = [
"Activations",
"AsDiscrete",
Expand Down Expand Up @@ -142,6 +146,7 @@ class AsDiscrete(Transform):
Convert the input tensor/array into discrete values, possible operations are:

- `argmax`.
- `rankseg`.
- threshold input value to binary values.
- convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes).
- round the value to the closest integer.
Expand All @@ -155,6 +160,14 @@ class AsDiscrete(Transform):
Defaults to ``None``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
RankSEG is applied to a channel-first probability map for one image; ``dim`` identifies the
class/channel dimension and is moved to the front before decoding. For the common MONAI
post-processing input shape ``(C, *spatial)``, use the default ``dim=0``.
The output is a label map. With the default ``keepdim=True``, the output shape is ``(1, *spatial)``;
with ``keepdim=False``, it is ``(*spatial)``. The ``dim`` and ``keepdim`` shape handling is aligned
with ``argmax``. This option is incompatible with ``argmax=True``.
Defaults to ``False``.
kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`.
currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
These default to ``0``, ``True``, ``torch.float`` respectively.
Expand All @@ -173,6 +186,12 @@ class AsDiscrete(Transform):
>>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
# [[[0.0, 0.0]], [[1.0, 1.0]]]

RankSEG decoding requires the optional ``rankseg`` package:

>>> transform = AsDiscrete(rankseg=True)
>>> print(transform(np.array([[[0.3, 0.6]], [[0.7, 0.4]]])))
# [[[1.0, 1.0]]]

"""

backend = [TransformBackends.TORCH]
Expand All @@ -183,9 +202,13 @@ def __init__(
to_onehot: int | None = None,
threshold: float | None = None,
rounding: str | None = None,
rankseg: bool = False,
**kwargs,
) -> None:
if argmax and rankseg:
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
self.argmax = argmax
self.rankseg = rankseg
if isinstance(to_onehot, bool): # for backward compatibility
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
self.to_onehot = to_onehot
Expand All @@ -200,6 +223,7 @@ def __call__(
to_onehot: int | None = None,
threshold: float | None = None,
rounding: str | None = None,
rankseg: bool | None = None,
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -211,6 +235,10 @@ def __call__(
Defaults to ``self.to_onehot``.
threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
Defaults to ``self.threshold``.
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
Applies RankSEG to a channel-first probability map by default and uses the same ``dim`` and
``keepdim`` shape handling as ``argmax``. This option is incompatible with ``argmax=True``.
Defaults to ``self.rankseg``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

Expand All @@ -220,9 +248,24 @@ def __call__(
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor)
argmax = self.argmax if argmax is None else argmax
rankseg = self.rankseg if rankseg is None else rankseg

if argmax and rankseg:
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")

if argmax:
img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

if rankseg:
if not has_rankseg:
raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.")
# Adjust shape to meet RankSEG's [B, C, *spatial] input requirement.
channel_dim = self.kwargs.get("dim", 0) % img_t.ndim
keepdim = self.kwargs.get("keepdim", True)
img_t = rankseg_fn(img_t.movedim(channel_dim, 0).unsqueeze(0)).squeeze(0)
if keepdim:
img_t = img_t.unsqueeze(channel_dim)

to_onehot = self.to_onehot if to_onehot is None else to_onehot
if to_onehot is not None:
if not isinstance(to_onehot, int):
Expand Down
16 changes: 13 additions & 3 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
to_onehot: Sequence[int | None] | int | None = None,
threshold: Sequence[float | None] | float | None = None,
rounding: Sequence[str | None] | str | None = None,
rankseg: Sequence[bool] | bool = False,
allow_missing_keys: bool = False,
**kwargs,
) -> None:
Expand All @@ -182,6 +183,10 @@ def __init__(
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"]. it also can be a sequence of str or None,
each element corresponds to a key in ``keys``.
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
RankSEG expects channel-first probability maps for one image. It also can be a sequence of bool,
each element corresponds to a key in ``keys``. Uses the same ``dim`` and ``keepdim`` shape handling
as ``argmax``. This option is incompatible with ``argmax=True``.
allow_missing_keys: don't raise exception if key is missing.
kwargs: additional parameters to ``AsDiscrete``.
``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
Expand All @@ -190,6 +195,9 @@ def __init__(
"""
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.rankseg = ensure_tuple_rep(rankseg, len(self.keys))
if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg)):
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
self.to_onehot = []
for flag in ensure_tuple_rep(to_onehot, len(self.keys)):
if isinstance(flag, bool):
Expand All @@ -208,10 +216,12 @@ def __init__(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, argmax, to_onehot, threshold, rounding in self.key_iterator(
d, self.argmax, self.to_onehot, self.threshold, self.rounding
for key, argmax, to_onehot, threshold, rounding, rankseg in self.key_iterator(
d, self.argmax, self.to_onehot, self.threshold, self.rounding, self.rankseg
):
d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding)
d[key] = self.converter(
d[key], argmax=argmax, to_onehot=to_onehot, threshold=threshold, rounding=rounding, rankseg=rankseg
)
return d


Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ lpips==0.1.4
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
rankseg>=0.0.5
git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
onnx_graphsurgeon
polygraphy
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ all =
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
rankseg>=0.0.5
nibabel =
nibabel
ninja =
Expand Down Expand Up @@ -179,6 +180,8 @@ huggingface_hub =
huggingface_hub
pyamg =
pyamg>=5.0.0, <5.3.0
rankseg =
rankseg>=0.0.5
# segment-anything =
# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

Expand Down
34 changes: 34 additions & 0 deletions tests/transforms/test_as_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from __future__ import annotations

import unittest
from unittest import mock

from parameterized import parameterized

from monai.transforms import AsDiscrete
from monai.transforms.post import array as post_array
from monai.utils import OptionalImportError
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
Expand Down Expand Up @@ -63,6 +66,25 @@
[{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)]
)

TEST_CASES.append(
[
{"rankseg": False, "argmax": True},
p([[[0.3, 0.6]], [[0.7, 0.4]]]),
p([[[1.0, 0.0]]]),
(1, 1, 2),
]
)

if post_array.has_rankseg:
TEST_CASES.append(
[
{"rankseg": True},
p([[[0.3, 0.6]], [[0.7, 0.4]]]),
p([[[1.0, 1.0]]]),
(1, 1, 2),
]
)


class TestAsDiscrete(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand All @@ -76,6 +98,18 @@ def test_additional(self):
out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]]))
assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False)

def test_rankseg_argmax_incompatible(self):
with self.assertRaises(ValueError):
AsDiscrete(argmax=True, rankseg=True)

with self.assertRaises(ValueError):
AsDiscrete(argmax=True)([[[0.3, 0.6]], [[0.7, 0.4]]], rankseg=True)

def test_rankseg_missing_dependency(self):
with mock.patch("monai.transforms.post.array.has_rankseg", False):
with self.assertRaises(OptionalImportError):
AsDiscrete(rankseg=True)([[[0.3, 0.6]], [[0.7, 0.4]]])


if __name__ == "__main__":
unittest.main()
48 changes: 48 additions & 0 deletions tests/transforms/test_as_discreted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from __future__ import annotations

import unittest
from unittest import mock

from parameterized import parameterized

from monai.transforms import AsDiscreted
from monai.transforms.post import array as post_array
from monai.utils import OptionalImportError
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
Expand Down Expand Up @@ -66,6 +69,40 @@
]
)

TEST_CASES.append(
[
{"keys": "pred", "rankseg": False, "argmax": True},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])},
{"pred": p([[[1.0, 0.0]]])},
(1, 1, 2),
]
)

if post_array.has_rankseg:
TEST_CASES.append(
[
{"keys": "pred", "rankseg": True},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])},
{"pred": p([[[1.0, 1.0]]])},
(1, 1, 2),
]
)

TEST_CASES.append(
[
{"keys": ["pred", "label"], "rankseg": [True, False]},
{
"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]),
"label": p([[[0.0, 1.0]]]),
},
{
"pred": p([[[1.0, 1.0]]]),
"label": p([[[0.0, 1.0]]]),
},
(1, 1, 2),
]
)


class TestAsDiscreted(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand All @@ -77,6 +114,17 @@ def test_value_shape(self, input_param, test_input, output, expected_shape):
assert_allclose(result["label"], output["label"], rtol=1e-3, type_test="tensor")
self.assertTupleEqual(result["label"].shape, expected_shape)

def test_rankseg_argmax_incompatible(self):
with self.assertRaises(ValueError):
AsDiscreted(keys="pred", argmax=True, rankseg=True)(
{"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}
)

def test_rankseg_missing_dependency(self):
with mock.patch("monai.transforms.post.array.has_rankseg", False):
with self.assertRaises(OptionalImportError):
AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]})


if __name__ == "__main__":
unittest.main()
Loading