Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/kwneuro/dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,10 @@ def estimate_noddi(
mask: VolumeResource | None = None,
dpar: float = 1.7e-3,
n_kernel_dirs: int = 500,
regenerate_kernels: bool = True,
) -> Noddi:
"""Estimate NODDI model parameters from this DWI. See :meth:`kwneuro.noddi.Noddi.estimate_noddi` for details."""
return Noddi.estimate_noddi(self, mask, dpar, n_kernel_dirs) # type: ignore[no-any-return]
return Noddi.estimate_noddi(self, mask, dpar, n_kernel_dirs, regenerate_kernels) # type: ignore[no-any-return]


def subsample_dwi(dwi: Dwi, factor: int = 2) -> Dwi:
Expand Down
4 changes: 3 additions & 1 deletion src/kwneuro/noddi.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def estimate_noddi(
mask: VolumeResource | None = None,
dpar: float = 1.7e-3,
n_kernel_dirs: int = 500,
regenerate_kernels: bool = True,
) -> Noddi:
"""Estimate Noddi from a DWI.

Expand All @@ -96,6 +97,8 @@ def estimate_noddi(
1.7e-3 mm^2/s is used, which is suitable for white matter. For gray matter, a value of 1.3e-3 mm^2/s is recommended.
n_kernel_dirs: The number of directions to use when generating the AMICO NODDI kernels. This value represents the total
count of possible orientations for the response functions across the half-sphere. Default: 500.
regenerate_kernels: If True, delete and recompute AMICO kernels before fitting. Set False (default)
for batch jobs where kernels are shared across subjects; set True only when dpar or n_kernel_dirs changes.

Returns: A Noddi resource containing the estimated parameters.
"""
Expand Down Expand Up @@ -146,7 +149,6 @@ def estimate_noddi(
ae.set_model("NODDI")
ae.model.dPar = dpar

regenerate_kernels = True
ae.generate_kernels(regenerate=regenerate_kernels, ndirs=n_kernel_dirs)
ae.load_kernels()
ae.fit()
Expand Down
113 changes: 113 additions & 0 deletions src/kwneuro/reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,119 @@ def register_volumes(
return (warpedmovout, transform)


def register_volumes_multimetric(
fixed: dict[str, VolumeResource],
moving: dict[str, VolumeResource],
weights: dict[str, float] | None = None,
mask: VolumeResource | None = None,
moving_mask: VolumeResource | None = None,
) -> tuple[dict[str, InMemoryVolumeResource], TransformResource]:
"""
Registers multiple moving volumes to a fixed reference using multi-metric ANTs.

The primary modality (first key) drives an initial Affine stage; all modalities
then contribute to a SyNOnly deformable stage via ``multivariate_extras``.
Requires at least two modalities; use :func:`register_volumes` for
single-modality registration.

Args:
fixed: Mapping of modality name to fixed reference volume. The first key is
the primary modality.
moving: Mapping of modality name to moving volume. Must have the same keys
as ``fixed``.
weights: Per-modality weights for the deformable stage, normalised relative
to the primary modality. Defaults to equal weighting.
mask: Optional mask for the primary fixed image (Affine stage only).
moving_mask: Optional mask for the primary moving image (Affine stage only).

Returns:
A tuple of a warped volumes dict (modality name : InMemoryVolumeResource)
and a :class:`TransformResource` mapping moving space to fixed space.

"""
if set(fixed.keys()) != set(moving.keys()):
error_message = "fixed and moving must have the same modality keys."
raise ValueError(error_message)

modalities = list(fixed.keys())
if len(modalities) < 2:
error_message = (
"register_volumes_multimetric requires at least 2 modalities. "
"Use register_volumes for single-modality registration."
)
raise ValueError(error_message)
primary_mod = modalities[0]

# Load and validate all volumes
ants_fixed: dict[str, ants.ANTsImage] = {}
ants_moving: dict[str, ants.ANTsImage] = {}
for m in modalities:
fixed_loaded = fixed[m].load()
moving_loaded = moving[m].load()
if fixed_loaded.get_array().ndim > 3 or moving_loaded.get_array().ndim > 3:
error_message = "Input volume dimensions must be 2D or 3D."
raise ValueError(error_message)
ants_fixed[m] = fixed_loaded.to_ants_image()
ants_moving[m] = moving_loaded.to_ants_image()

ants_mask = mask.load().to_ants_image() if mask is not None else None
ants_moving_mask = (
moving_mask.load().to_ants_image() if moving_mask is not None else None
)

# Normalise weights relative to the primary modality
if weights is None:
weights = dict.fromkeys(modalities, 1.0)
norm_weights = {m: weights[m] / weights[primary_mod] for m in modalities}

# Step 1: Affine registration on primary modality
affine_result = ants.registration(
fixed=ants_fixed[primary_mod],
moving=ants_moving[primary_mod],
mask=ants_mask,
moving_mask=ants_moving_mask,
type_of_transform="Affine",
)
affine_path: str = affine_result["fwdtransforms"][-1]

# Step 2: SyNOnly deformable registration with multivariate_extras
multivariate_metrics = [
["mattes", ants_fixed[m], ants_moving[m], norm_weights[m], 1]
for m in modalities[1:]
]
deformable_result = ants.registration(
fixed=ants_fixed[primary_mod],
moving=ants_moving[primary_mod],
multivariate_extras=multivariate_metrics,
type_of_transform="SyNOnly",
initial_transform=affine_path,
)
warp_path: str = deformable_result["fwdtransforms"][0]

# Combine affine + warp into a TransformResource (same path layout as SyN)
transform = TransformResource(
_ants_fwd_paths=[warp_path, affine_path],
_ants_inv_paths=[affine_path] + deformable_result["invtransforms"],
)

# Warp all modalities using the combined transform
warped: dict[str, InMemoryVolumeResource] = {
primary_mod: InMemoryVolumeResource.from_ants_image(
deformable_result["warpedmovout"]
)
}
for m in modalities[1:]:
warped_ants = ants.apply_transforms(
fixed=ants_fixed[m],
moving=ants_moving[m],
transformlist=[warp_path, affine_path],
whichtoinvert=[False, False],
)
warped[m] = InMemoryVolumeResource.from_ants_image(warped_ants)

return warped, transform


@cacheable
def register_dwi_to_structural(
dwi: Dwi,
Expand Down
69 changes: 68 additions & 1 deletion tests/test_reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

from kwneuro import Cache
from kwneuro.dwi import Dwi
from kwneuro.reg import TransformResource, register_dwi_to_structural, register_volumes
from kwneuro.reg import (
TransformResource,
register_dwi_to_structural,
register_volumes,
register_volumes_multimetric,
)
from kwneuro.resource import (
InMemoryBvalResource,
InMemoryBvecResource,
Expand Down Expand Up @@ -288,3 +293,65 @@ def test_register_volumes_with_incorrect_dimensions(

with pytest.raises(ValueError, match="Input volume dimensions must be"):
register_volumes(fixed=dwi2.volume, moving=correct_dim)


def test_register_volumes_multimetric(dwi1: Dwi, dwi2: Dwi, tmp_path):
vol_fixed = dwi1.compute_mean_b0()
vol_moving = dwi2.compute_mean_b0()

warped, transform = register_volumes_multimetric(
fixed={"primary": vol_fixed, "secondary": vol_fixed},
moving={"primary": vol_moving, "secondary": vol_moving},
)

assert set(warped.keys()) == {"primary", "secondary"}
assert isinstance(warped["primary"], InMemoryVolumeResource)
assert isinstance(warped["secondary"], InMemoryVolumeResource)
assert warped["primary"].get_array().shape == vol_fixed.get_array().shape
assert warped["secondary"].get_array().shape == vol_fixed.get_array().shape

# fwd: [warp.nii.gz, affine.mat]; inv: [affine.mat, InverseWarp.nii.gz]
assert len(transform._ants_fwd_paths) == 2
assert len(transform._ants_inv_paths) == 2
assert transform.matrices is not None
assert transform.warp_fields is not None
assert len(transform.matrices) == 1
assert len(transform.warp_fields) == 1
assert isinstance(transform.matrices[0], ants.ANTsTransform)

# Check saving
transform.save(tmp_path)
for file in transform._ants_fwd_paths + transform._ants_inv_paths:
assert (tmp_path / Path(file).name).exists()

# Check application
applied = transform.apply(fixed=vol_fixed, moving=vol_moving)
assert applied.get_array().shape == vol_fixed.get_array().shape


def test_register_volumes_multimetric_mismatched_keys(dwi1: Dwi):
vol = dwi1.compute_mean_b0()
with pytest.raises(ValueError, match="same modality keys"):
register_volumes_multimetric(
fixed={"primary": vol, "secondary": vol},
moving={"primary": vol, "other": vol},
)


def test_register_volumes_multimetric_single_modality(dwi1: Dwi):
vol = dwi1.compute_mean_b0()
with pytest.raises(ValueError, match="at least 2 modalities"):
register_volumes_multimetric(
fixed={"primary": vol},
moving={"primary": vol},
)


def test_register_volumes_multimetric_wrong_dimensions(dwi1: Dwi):
scalar = dwi1.compute_mean_b0()
four_d = dwi1.volume # 4D volume
with pytest.raises(ValueError, match="Input volume dimensions must be"):
register_volumes_multimetric(
fixed={"primary": scalar, "secondary": four_d},
moving={"primary": scalar, "secondary": scalar},
)
Loading