Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,57 @@ def gwcs_2d_lt_ln():

return (wcs.WCS(forward_transform=cel_model, output_frame=sky_frame, input_frame=input_frame))


@pytest.fixture
def gwcs_2d_t_f_linear():
"""
2D gWCS for a dynamic spectrum: uniform time (array axis 1 / X) and linear
frequency (array axis 0 / Y).
"""
time_model = models.Scale(14.0)
freq_model = models.Scale(1e6)

time_frame = cf.TemporalFrame(axes_order=(0,), unit=u.s,
reference_frame=Time("2024-03-23T00:03:23"))
freq_frame = cf.SpectralFrame(axes_order=(1,), unit=u.Hz, axes_names=('frequency',))

transform = time_model & freq_model
frame = cf.CompositeFrame([time_frame, freq_frame])
detector_frame = cf.CoordinateFrame(name="detector", naxes=2,
axes_order=(0, 1),
axes_type=("pixel", "pixel"),
unit=(u.pix, u.pix))
return wcs.WCS(forward_transform=transform, output_frame=frame,
input_frame=detector_frame)


@pytest.fixture
def gwcs_2d_t_f_log():
"""
2D gWCS for a dynamic spectrum: irregularly-spaced time (array axis 1 / X)
and log-spaced frequency (array axis 0 / Y) via Tabular1D lookup tables.
"""
times_s = np.array([0.0, 14.0, 27.4, 41.1, 55.2, 67.8, 82.3, 95.9, 109.1, 122.5])
freqs_hz = np.logspace(np.log10(3.992e6), np.log10(978.572e6), 16)

time_model = models.Tabular1D(points=np.arange(10), lookup_table=times_s,
method='linear', bounds_error=False)
freq_model = models.Tabular1D(points=np.arange(16), lookup_table=freqs_hz,
method='linear', bounds_error=False)

time_frame = cf.TemporalFrame(axes_order=(0,), unit=u.s,
reference_frame=Time("2024-03-23T00:03:23"))
freq_frame = cf.SpectralFrame(axes_order=(1,), unit=u.Hz, axes_names=('frequency',))

transform = time_model & freq_model
frame = cf.CompositeFrame([time_frame, freq_frame])
detector_frame = cf.CoordinateFrame(name="detector", naxes=2,
axes_order=(0, 1),
axes_type=("pixel", "pixel"),
unit=(u.pix, u.pix))
return wcs.WCS(forward_transform=transform, output_frame=frame,
input_frame=detector_frame)

@pytest.fixture
def wcs_4d_t_l_lt_ln():
header = {
Expand Down Expand Up @@ -564,6 +615,20 @@ def extra_coords_sharing_axis():
# NOTE: If you add more fixtures please add to the all_ndcubes fixture
################################################################################

@pytest.fixture
def ndcube_gwcs_2d_t_f_linear(gwcs_2d_t_f_linear):
shape = (16, 10) # (n_freq, n_time): freq on Y axis, time on X axis
gwcs_2d_t_f_linear.array_shape = shape
return NDCube(data_nd(shape), wcs=gwcs_2d_t_f_linear)


@pytest.fixture
def ndcube_gwcs_2d_t_f_log(gwcs_2d_t_f_log):
shape = (16, 10) # (n_freq, n_time): freq on Y axis, time on X axis
gwcs_2d_t_f_log.array_shape = shape
return NDCube(data_nd(shape), wcs=gwcs_2d_t_f_log)


@pytest.fixture
def ndcube_gwcs_4d_ln_lt_l_t(gwcs_4d_t_l_lt_ln):
shape = (5, 8, 10, 12)
Expand Down Expand Up @@ -1074,6 +1139,8 @@ def ndcube_1d_l(wcs_1d_l):


@pytest.fixture(params=[
"ndcube_gwcs_2d_t_f_linear",
"ndcube_gwcs_2d_t_f_log",
"ndcube_gwcs_4d_ln_lt_l_t",
"ndcube_gwcs_4d_ln_lt_l_t_unit",
"ndcube_gwcs_3d_ln_lt_l",
Expand Down
135 changes: 135 additions & 0 deletions ndcube/tests/test_ndcube_dynspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
Tests to simulate dynamic spectrum WCSes (frequency x time).
"""
import pytest
from numpy.testing import assert_allclose

import astropy.units as u

from ndcube.wcs.wrappers import ResampledLowLevelWCS


def _world_at(cube, time_pixel, freq_pixel):
return cube.wcs.low_level_wcs.pixel_to_world_values(time_pixel, freq_pixel)


@pytest.mark.parametrize("ndc", [
"ndcube_gwcs_2d_t_f_linear",
"ndcube_gwcs_2d_t_f_log",
], indirect=True)
def test_dynspec_array_axis_physical_types(ndc):
types = ndc.array_axis_physical_types
assert "em.freq" in types[0]
assert "time" in types[1]


def test_linear_dynspec_pixel_to_world(ndcube_gwcs_2d_t_f_linear):
time, freq = ndcube_gwcs_2d_t_f_linear.wcs.low_level_wcs.pixel_to_world_values(3, 2)
assert_allclose(time, 42.0)
assert_allclose(freq, 2e6)


def test_linear_dynspec_world_to_pixel(ndcube_gwcs_2d_t_f_linear):
pix_t, pix_f = ndcube_gwcs_2d_t_f_linear.wcs.low_level_wcs.world_to_pixel_values(28.0, 4e6)
assert_allclose(pix_t, 2.0)
assert_allclose(pix_f, 4.0)


@pytest.mark.parametrize(("bin_shape", "expected_shape", "expected_time", "expected_freq"), [
((2, 1), (8, 10), 0.0, 0.5e6),
((1, 2), (16, 5), 7.0, 0.0),
])
def test_linear_dynspec_rebin_wcs(ndcube_gwcs_2d_t_f_linear, bin_shape,
expected_shape, expected_time, expected_freq):
rebinned = ndcube_gwcs_2d_t_f_linear.rebin(bin_shape)
time0, freq0 = rebinned.wcs.low_level_wcs.pixel_to_world_values(0, 0)

assert rebinned.shape == expected_shape
assert isinstance(rebinned.wcs.low_level_wcs, ResampledLowLevelWCS)
assert_allclose(time0, expected_time)
assert_allclose(freq0, expected_freq)


@pytest.mark.parametrize(("lower_corner", "upper_corner", "expected_shape"), [
([None, 3e6 * u.Hz], [None, 7e6 * u.Hz], (5, 10)),
([14 * u.s, None], [56 * u.s, None], (16, 4)),
])
def test_linear_dynspec_crop_by_values_shape(ndcube_gwcs_2d_t_f_linear,
lower_corner, upper_corner,
expected_shape):
cropped = ndcube_gwcs_2d_t_f_linear.crop_by_values(lower_corner, upper_corner)
assert cropped.shape == expected_shape


def test_log_dynspec_world_axis_units(ndcube_gwcs_2d_t_f_log):
assert ndcube_gwcs_2d_t_f_log.wcs.world_axis_units == ("s", "Hz")


@pytest.mark.parametrize(("time_pixel", "freq_pixel", "expected_time", "expected_freq"), [
(0, 0, 0.0, 3.992e6),
(9, 15, 122.5, 978.572e6),
])
def test_log_dynspec_pixel_to_world_endpoints(ndcube_gwcs_2d_t_f_log,
time_pixel, freq_pixel,
expected_time, expected_freq):
time, freq = ndcube_gwcs_2d_t_f_log.wcs.low_level_wcs.pixel_to_world_values(
time_pixel, freq_pixel)
assert_allclose(time, expected_time)
assert_allclose(freq, expected_freq, rtol=1e-6)


def test_log_dynspec_world_to_pixel_roundtrip(ndcube_gwcs_2d_t_f_log):
time, freq = _world_at(ndcube_gwcs_2d_t_f_log, 3, 7)
pix_t, pix_f = ndcube_gwcs_2d_t_f_log.wcs.low_level_wcs.world_to_pixel_values(
time, freq)
assert_allclose(pix_t, 3.0, atol=1e-10)
assert_allclose(pix_f, 7.0, atol=1e-10)


@pytest.mark.parametrize(("bin_shape", "expected_shape", "axis"), [
((2, 1), (8, 10), "freq"),
((1, 2), (16, 5), "time"),
])
def test_log_dynspec_rebin_wcs_midpoint(ndcube_gwcs_2d_t_f_log, bin_shape,
expected_shape, axis):
rebinned = ndcube_gwcs_2d_t_f_log.rebin(bin_shape)
time0, freq0 = rebinned.wcs.low_level_wcs.pixel_to_world_values(0, 0)

assert rebinned.shape == expected_shape
assert isinstance(rebinned.wcs.low_level_wcs, ResampledLowLevelWCS)
if axis == "freq":
_, freq_left = _world_at(ndcube_gwcs_2d_t_f_log, 0, 0)
_, freq_right = _world_at(ndcube_gwcs_2d_t_f_log, 0, 1)
assert_allclose(freq0, (freq_left + freq_right) / 2, rtol=1e-6)
else:
time_left, _ = _world_at(ndcube_gwcs_2d_t_f_log, 0, 0)
time_right, _ = _world_at(ndcube_gwcs_2d_t_f_log, 1, 0)
assert_allclose(time0, (time_left + time_right) / 2, rtol=1e-6)


@pytest.mark.parametrize(("lower_corner", "upper_corner", "expected_shape",
"axis", "bounds"), [
([None, 10e6 * u.Hz], [None, 100e6 * u.Hz], (8, 10), "freq", (10e6, 100e6)),
([20 * u.s, None], [80 * u.s, None], (16, 6), "time", (20.0, 80.0)),
])
def test_log_dynspec_crop_by_values_single_axis(ndcube_gwcs_2d_t_f_log,
lower_corner, upper_corner,
expected_shape, axis, bounds):
cropped = ndcube_gwcs_2d_t_f_log.crop_by_values(lower_corner, upper_corner)
assert cropped.shape == expected_shape

if axis == "freq":
values = [cropped.wcs.low_level_wcs.pixel_to_world_values(0, i)[1]
for i in range(cropped.shape[0])]
else:
values = [cropped.wcs.low_level_wcs.pixel_to_world_values(i, 0)[0]
for i in range(cropped.shape[1])]

assert values[0] <= bounds[0]
assert values[-1] >= bounds[1]


def test_log_dynspec_crop_by_freq_and_time(ndcube_gwcs_2d_t_f_log):
cropped = ndcube_gwcs_2d_t_f_log.crop_by_values(
[20 * u.s, 10e6 * u.Hz], [80 * u.s, 100e6 * u.Hz])
assert cropped.shape == (8, 6)
22 changes: 22 additions & 0 deletions ndcube/tests/test_ndcube_slice_and_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,25 @@ def test_crop_all_points_beyond_cube_extent_error(points):

with pytest.raises(ValueError, match="are outside the range of the NDCube being cropped"):
cube.crop(*points, keepdims=True)


def test_crop_by_values_quantity_table_coordinate():
# Regression: QuantityTableCoordinate-based WCS raised
# "High Level objects are not supported with the native API" because
# world_to_pixel_values received Quantity objects instead of plain values.
freqs_hz = np.logspace(np.log10(4e6), np.log10(200e6), 16)
times_s = np.linspace(0, 140, 10)
wcs2d = astropy.wcs.WCS(naxis=2)
wcs2d.wcs.ctype = ["PIXEL", "PIXEL"]
wcs2d.wcs.crpix = [1, 1]
wcs2d.wcs.cdelt = [1, 1]
wcs2d.wcs.crval = [0, 0]
data = np.arange(16 * 10).reshape(16, 10)
cube = NDCube(data, wcs=wcs2d)
cube.extra_coords.add("frequency", (0,), freqs_hz * u.Hz)
cube.extra_coords.add("time", (1,), times_s * u.s)

cropped = cube.crop_by_values([10e6 * u.Hz, 20 * u.s], [100e6 * u.Hz, 80 * u.s],
wcs=cube.extra_coords)
assert cropped.shape == (10, 5)
np.testing.assert_array_equal(cropped.data, data[3:13, 1:6])
16 changes: 14 additions & 2 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,20 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims, original_sh
# Derive the pixel indices of the input point and place each index
# in the list corresponding to its axis.
# Use the to_pixel methods to preserve fractional indices for future rounding.
point_pixel_indices = (sliced_wcs.world_to_pixel_values(*sliced_point) if crop_by_values
else HighLevelWCSWrapper(sliced_wcs).world_to_pixel(*sliced_point))
if crop_by_values:
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this is quite true but it did break a unit test.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a real regression in the oldest deps, gwcs looks like it used to handle this internally?!

# world_to_pixel_values is APE14 low-level API and expects plain
# floats, not Quantity objects. So we need to strip units here; the values are
# already in the correct units because _get_crop_by_values_item called
# .to(wcs.world_axis_units[j]) before reaching this point.
#
# Passing Quantity objects raises TypeError in gWCS when the WCS's
# declared high-level type is itself Quantity (e.g., a WCS built from
# QuantityTableCoordinate), because gWCS cannot distinguish such inputs
# from an accidental high-level API call.
stripped_point = [p.value if hasattr(p, "value") else p for p in sliced_point]
point_pixel_indices = sliced_wcs.world_to_pixel_values(*stripped_point)
else:
point_pixel_indices = HighLevelWCSWrapper(sliced_wcs).world_to_pixel(*sliced_point)
# For each pixel axis associated with this point, place the pixel coords for
# that pixel axis into the corresponding list within combined_points_pixel_idx.
if sliced_wcs.pixel_n_dim == 1:
Expand Down
Loading