From 18fe41b8318118c3d1d3a37bfd3f66fe526e8fc3 Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Sun, 22 Mar 2026 01:36:04 +0000 Subject: [PATCH] add keypoint dataset --- src/deepforest/augmentations.py | 12 +- src/deepforest/datasets/training.py | 389 +++++++++++++++--- src/deepforest/main.py | 43 +- src/deepforest/models/DeformableDetr.py | 2 + src/deepforest/models/retinanet.py | 2 + ...ing.py => test_datasets_training_boxes.py} | 0 tests/test_datasets_training_keypoint.py | 271 ++++++++++++ 7 files changed, 649 insertions(+), 70 deletions(-) rename tests/{test_datasets_training.py => test_datasets_training_boxes.py} (100%) create mode 100644 tests/test_datasets_training_keypoint.py diff --git a/src/deepforest/augmentations.py b/src/deepforest/augmentations.py index a6577068b..9ffe102ee 100644 --- a/src/deepforest/augmentations.py +++ b/src/deepforest/augmentations.py @@ -206,8 +206,9 @@ def get_available_augmentations() -> list[str]: def get_transform( augmentations: str | list[str] | dict[str, Any] | None = None, + data_keys: list[DataKey] | None = None, ) -> K.AugmentationSequential: - """Create Kornia transform for bounding boxes. + """Create Kornia transform pipeline. Args: augmentations: Augmentation configuration: @@ -215,6 +216,8 @@ def get_transform( - list: List of augmentation names - dict: Dict with names as keys and params as values - None: No augmentations + data_keys: Kornia DataKey list describing the inputs. + Defaults to [DataKey.IMAGE, DataKey.BBOX_XYXY]. Returns: Kornia AugmentationSequential @@ -235,6 +238,9 @@ def get_transform( ... "VerticalFlip" ... 
}) """ + if data_keys is None: + data_keys = [DataKey.IMAGE, DataKey.BBOX_XYXY] + transforms_list = [] if augmentations is not None: @@ -245,9 +251,7 @@ def get_transform( transforms_list.append(aug_transform) # Create a sequential container for all transforms - return K.AugmentationSequential( - *transforms_list, data_keys=[DataKey.IMAGE, DataKey.BBOX_XYXY] - ) + return K.AugmentationSequential(*transforms_list, data_keys=data_keys) def _parse_augmentations( diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 2e9cdfa52..7ed188f45 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -2,11 +2,15 @@ import math import os +from abc import abstractmethod +from typing import Any import kornia.augmentation as K import numpy as np import shapely import torch +import torchvision +from kornia.constants import DataKey from PIL import Image from torch.utils.data import Dataset from torchvision.datasets import ImageFolder @@ -15,23 +19,8 @@ from deepforest.augmentations import get_transform -class BoxDataset(Dataset): - """Dataset for object detection with bounding boxes. - - Args: - csv_file: Path to CSV file with annotations - root_dir: Directory containing images - transforms: Function applied to each sample - augment: Deprecated - use augmentations instead - augmentations: Augmentation configuration - label_dict: Mapping from string labels to class IDs - preload_images: Preload all images into memory - - Returns: - List of (image, target) pairs where target contains: - - "boxes": numpy array of shape (N, 4) - - "labels": numpy array of shape (N,) - """ +class TrainingDataset(Dataset): + _data_keys = [DataKey.IMAGE, DataKey.BBOX_XYXY] def __init__( self, @@ -51,11 +40,6 @@ def __init__( label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}). augmentations (str | list | dict, optional): Augmentation configuration. 
preload_images (bool): If True, preload all images into memory. Defaults to False. - - Returns: - list: A list of (image, target) pairs, where each target is a dict with: - - "boxes": numpy.ndarray of shape (N, 4) - - "labels": numpy.ndarray of shape (N,) """ self.annotations = utilities.read_file(csv_file, root_dir=root_dir) self.root_dir = root_dir @@ -66,7 +50,9 @@ def __init__( self.label_dict = label_dict if transforms is None: - self.transform = get_transform(augmentations=augmentations) + self.transform = get_transform( + augmentations=augmentations, data_keys=self._data_keys + ) else: if not isinstance(transforms, K.AugmentationSequential): raise ValueError( @@ -87,7 +73,7 @@ def __init__( for idx, _ in enumerate(self.image_names): self.image_dict[idx] = self.load_image(idx) - def _validate_labels(self): + def _validate_labels(self) -> None: """Validate that all labels in annotations exist in label_dict. Raises: @@ -102,7 +88,51 @@ def _validate_labels(self): f"Please ensure all labels in the annotations exist as keys in label_dict." ) - def _validate_coordinates(self): + @abstractmethod + def _validate_coordinates(self) -> None: + """Validate geometries in the annotation data. Must be overridden by + child classes to implement task-specific checks (e.g., boxes vs + points). + + Should raise ValueError with details if any invalid geometries + are found. 
+ """ + + def __len__(self) -> int: + """Dataset length is the number of unique images.""" + return len(self.image_names) + + def collate_fn(self, batch) -> tuple: + """Collate function for DataLoader.""" + images = [item[0] for item in batch] + targets = [item[1] for item in batch] + image_names = [item[2] for item in batch] + + return images, targets, image_names + + def load_image(self, idx) -> np.typing.NDArray[np.float32]: + """Load image from disk and convert to float32 numpy array in [0, + 1].""" + img_name = os.path.join(self.root_dir, self.image_names[idx]) + image = np.array(Image.open(img_name).convert("RGB")) / 255 + image = image.astype("float32") + return image + + @abstractmethod + def annotations_for_path(self, image_path, return_tensor=False) -> Any: + """Construct target dictionary for a given image path, optionally + convert to tensor.""" + + @abstractmethod + def __getitem__(self, index) -> tuple: + """Return a single item from the dataset.""" + pass + + +class BoxDataset(TrainingDataset): + """Dataset for object detection with bounding boxes.""" + + def _validate_coordinates(self) -> None: """Validate that all bounding box coordinates occur within the image. Raises: @@ -152,7 +182,7 @@ def _validate_coordinates(self): if errors: raise ValueError("\n".join(errors)) - def filter_boxes(self, boxes, labels, image_shape, min_size=1): + def filter_boxes(self, boxes, labels, image_shape, min_size=1) -> tuple: """Clamp boxes to image bounds and filter by minimum dimension. 
Args: @@ -179,24 +209,7 @@ def filter_boxes(self, boxes, labels, image_shape, min_size=1): return boxes[valid_mask], labels[valid_mask] - def __len__(self): - return len(self.image_names) - - def collate_fn(self, batch): - """Collate function for DataLoader.""" - images = [item[0] for item in batch] - targets = [item[1] for item in batch] - image_names = [item[2] for item in batch] - - return images, targets, image_names - - def load_image(self, idx): - img_name = os.path.join(self.root_dir, self.image_names[idx]) - image = np.array(Image.open(img_name).convert("RGB")) / 255 - image = image.astype("float32") - return image - - def annotations_for_path(self, image_path, return_tensor=False): + def annotations_for_path(self, image_path, return_tensor=False) -> dict: """Construct target dictionary for a given image path, optionally convert to tensor. @@ -234,7 +247,7 @@ def annotations_for_path(self, image_path, return_tensor=False): return targets - def __getitem__(self, idx): + def __getitem__(self, idx) -> tuple: # Read image if not in memory if self.preload_images: image = self.image_dict[idx] @@ -243,17 +256,12 @@ def __getitem__(self, idx): targets = self.annotations_for_path(self.image_names[idx]) - # If image has no annotations, don't augment + # If image has no annotations, add a dummy if np.sum(targets["boxes"]) == 0: - boxes = torch.zeros((0, 4), dtype=torch.float32) - labels = torch.zeros(0, dtype=torch.int64) - # channels last - image = np.rollaxis(image, 2, 0) - image = torch.from_numpy(image).float() + boxes = np.zeros((0, 4), dtype=np.float32) + labels = np.zeros(0, dtype=np.int64) targets = {"boxes": boxes, "labels": labels} - return image, targets, self.image_names[idx] - # Apply augmentations image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() boxes_tensor = torch.from_numpy(targets["boxes"]).unsqueeze(0).float() @@ -277,6 +285,281 @@ def __getitem__(self, idx): return image, targets, self.image_names[idx] +class 
KeypointDataset(TrainingDataset): + """Dataset for keypoint detection tasks.""" + + _data_keys = [DataKey.IMAGE, DataKey.KEYPOINTS] + + def __init__( + self, + csv_file, + root_dir, + *, + transforms=None, + augmentations=None, + label_dict=None, + preload_images=False, + density_sigma=4.0, + output="centroid", + ): + """This dataset class returns keypoint annotations in one of two common + formats. + + If the output parameter is set to "centroid", each target is a dictionary with "points" + and "labels" entries. The "points" entry is a tensor of shape (N, 2) containing the xy + coordinates of each point, and the "labels" entry is a tensor of shape (N,) + containing the class label for each point. + + If the output parameter is set to "density", each target is a dictionary with only a "labels" entry. + In this case, "labels" is a tensor of shape (num_classes, H, W) representing a class-wise density map. + The map is count-normalized so that each channel sums to the number of points of that class. + + Args: + csv_file (str): Path to the CSV file containing annotations. + root_dir (str): Directory containing all referenced images. + transforms (callable, optional): Function applied to each sample (e.g., image and target). Defaults to None. + label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}). + augmentations (str | list | dict, optional): Augmentation configuration. + preload_images (bool): If True, preload all images into memory. Defaults to False. + density_sigma (float): Standard deviation of the Gaussian kernel for density map generation. Defaults to 4.0. + output (str): Output format, either "centroid" for point coordinates or "density" for Gaussian density maps. Defaults to "centroid". 
+ """ + super().__init__( + csv_file=csv_file, + root_dir=root_dir, + transforms=transforms, + augmentations=augmentations, + label_dict=label_dict, + preload_images=preload_images, + ) + + self.density_sigma = density_sigma + + if output not in ["centroid", "density"]: + raise ValueError( + f"Invalid output type: {output}. Supported options are 'centroid' and 'density'." + ) + self.output = output + + def _validate_coordinates(self) -> None: + """Validate that all points occur within the image. + + Raises: + ValueError: If any point occurs outside the image + """ + errors = [] + for _idx, row in self.annotations.iterrows(): + img_path = os.path.join(self.root_dir, row["image_path"]) + try: + with Image.open(img_path) as img: + width, height = img.size + except Exception as e: + errors.append(f"Failed to open image {img_path}: {e}") + continue + + # Extract point coordinates (use centroid so boxes/polygons also work) + centroid = row["geometry"].centroid + x, y = centroid.x, centroid.y + + # All coordinates equal to zero is how we code empty frames. + if x == 0 and y == 0: + continue + + # Check if point is valid + oob_issues = [] + if x < 0: + oob_issues.append(f"x ({x}) < 0") + if x > width: + oob_issues.append(f"x ({x}) > image width ({width})") + if y < 0: + oob_issues.append(f"y ({y}) < 0") + if y > height: + oob_issues.append(f"y ({y}) > image height ({height})") + + if oob_issues: + errors.append( + f"Point, ({x}, {y}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}." + ) + + if errors: + raise ValueError("\n".join(errors)) + + def filter_points(self, points, labels, image_shape) -> tuple: + """Filter points to be within the image. + + Args: + points (torch.Tensor): Points of shape (N, 2) in xy format. + labels (torch.Tensor): Labels of shape (N,). + image_shape (tuple): Image shape as (C, H, W). 
+ + Returns: + tuple: A tuple of (filtered_points, filtered_labels) + """ + _, H, W = image_shape + + # Filter out of bounds + valid_mask = ( + (points[:, 0] >= 0) + & (points[:, 0] <= W) + & (points[:, 1] >= 0) + & (points[:, 1] <= H) + ) + + return points[valid_mask], labels[valid_mask] + + def annotations_for_path(self, image_path, return_tensor=False) -> dict: + """Construct target dictionary for a given image path, optionally + convert to tensor. + + Args: + image_path (str): Path to image, expected to be in dataframe + return_tensor (bool): If true, convert fields from numpy to tensor + + Returns: + target dictionary with points and labels entries + """ + image_annotations = self.annotations[self.annotations.image_path == image_path] + targets = {} + + if "geometry" in image_annotations.columns: + # Handle both shapely geometry objects and WKT strings + targets["points"] = np.array( + [ + x.centroid.coords[0] + if hasattr(x, "centroid") + else shapely.wkt.loads(x).centroid.coords[0] + for x in image_annotations.geometry + ] + ).astype("float32") + else: + targets["points"] = image_annotations[["x", "y"]].values.astype("float32") + + # Labels need to be encoded + targets["labels"] = image_annotations.label.apply( + lambda x: self.label_dict[x] + ).values.astype(np.int64) + + if return_tensor: + for k, v in targets.items(): + targets[k] = torch.from_numpy(v) + + return targets + + def gaussian_density(self, points, labels, shape) -> torch.Tensor: + """Convert points to a Gaussian density representation. + + Places a delta at each point location, applies a Gaussian blur with + sigma=density_sigma, then count-normalizes each class channel so that + channel.sum() == number of points of that class. + + Args: + points (torch.Tensor): Points of shape (N, 2) in xy format. + labels (torch.Tensor): Labels of shape (N,). + shape (tuple): Image shape as (C, H, W). 
+ + Returns: + torch.Tensor: Density map of shape (num_classes, H, W) + """ + if len(shape) == 3: + _, H, W = shape + elif len(shape) == 2: + H, W = shape + else: + raise ValueError( + f"image_shape must be length 2 (H, W) or 3 (C, H, W), got {shape}." + ) + + # torchvision gaussian_blur expects (C, H, W) + num_classes = len(self.label_dict) + density = torch.zeros((num_classes, H, W), dtype=torch.float32) + + if len(points) == 0: + return density + + # Place a delta function for each point. + for point, label in zip(points, labels, strict=True): + r_x, r_y = round(point[0].item()), round(point[1].item()) + x, y = int(r_x), int(r_y) + class_index = int(label.item()) + if 0 <= x < W and 0 <= y < H: + density[class_index, y, x] = 1.0 + + # Apply Gaussian blur; kernel size chosen to cover +-3*sigma without clipping tails. + sigma = self.density_sigma + kernel_size = int(6 * sigma + 1) + if kernel_size % 2 == 0: + kernel_size += 1 + + density = torchvision.transforms.functional.gaussian_blur( + density, kernel_size=kernel_size, sigma=sigma + ) + + # Count-normalize each class channel so channel.sum() == num_points_in_class. + for cls_idx in range(density.shape[0]): + s = density[cls_idx].sum() + if s > 0: + n_cls = (labels == cls_idx).sum().item() + density[cls_idx] = density[cls_idx] * (n_cls / s) + + return density + + def __getitem__(self, idx) -> tuple: + """Returns a transformed data sample from the dataset. + + Returns: + image (torch.Tensor): Image tensor of shape (C, H, W). + targets (dict): Dictionary containing either: + - 'points': Tensor of shape (N, 2) with point coordinates in xy format. + - 'labels': Either a tensor of shape (N,) with class labels for 'centroid' output, + or a tensor of shape (num_classes, H, W) with density maps for 'density' output. + image_name (str): The name of the image corresponding to the returned data. 
+ """ + # Read image if not in memory + if self.preload_images: + image = self.image_dict[idx] + else: + image = self.load_image(idx) + + targets = self.annotations_for_path(self.image_names[idx]) + + # Dummy annotations for empty image + if np.sum(targets["points"]) == 0: + targets = { + "points": np.zeros((0, 2), dtype=np.float32), + "labels": np.zeros(0, dtype=np.int64), + } + + # Apply augmentations + image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() + points_tensor = torch.from_numpy(targets["points"]).unsqueeze(0).float() + augmented_image, augmented_points = self.transform(image_tensor, points_tensor) + + # Convert to tensor + image = augmented_image.squeeze(0) + points = augmented_points.squeeze(0) + labels = torch.from_numpy(targets["labels"].astype(np.int64)) + + # Filter out-of-bounds points after augmentation + points, labels = self.filter_points(points, labels, image.shape) + + # Edge case if all labels were augmented away, keep the image + if len(points) == 0: + points = torch.zeros((0, 2), dtype=torch.float32) + labels = torch.zeros(0, dtype=torch.int64) + + if self.output == "density": + # Mask is NHW for N classes. + targets = {"labels": self.gaussian_density(points, labels, image.shape[1:])} + elif self.output == "centroid": + targets = {"points": points, "labels": labels} + else: + raise ValueError( + f"Invalid output type: {self.output}. Supported options are 'centroid' and 'density'." 
+ ) + + return image, targets, self.image_names[idx] + + # ---------- ImageFolder alignment utilities ---------- diff --git a/src/deepforest/main.py b/src/deepforest/main.py index b421f89aa..ef38a042a 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -94,11 +94,13 @@ def setup_metrics(self): if not self.config.validation.csv_file: return - # Metrics - self.iou_metric = IntersectionOverUnion( - class_metrics=True, iou_threshold=self.config.validation.iou_threshold - ) - self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval") + # Box Metrics + if self.model.task == "box": + self.iou_metric = IntersectionOverUnion( + class_metrics=True, iou_threshold=self.config.validation.iou_threshold + ) + + self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval") self.precision_recall_metric = RecallPrecision( iou_threshold=self.config.validation.iou_threshold, @@ -329,14 +331,29 @@ def load_dataset( ds: a pytorch dataset """ - ds = training.BoxDataset( - csv_file=csv_file, - root_dir=root_dir, - transforms=transforms, - label_dict=self.label_dict, - augmentations=augmentations, - preload_images=preload_images, - ) + if self.model.task == "box": + ds = training.BoxDataset( + csv_file=csv_file, + root_dir=root_dir, + transforms=transforms, + label_dict=self.label_dict, + augmentations=augmentations, + preload_images=preload_images, + ) + elif self.model.task == "keypoint": + ds = training.KeypointDataset( + csv_file=csv_file, + root_dir=root_dir, + transforms=transforms, + label_dict=self.label_dict, + augmentations=augmentations, + preload_images=preload_images, + ) + else: + raise ValueError( + f"Invalid task type: {self.model.task}, expected 'box' or 'keypoint'" + ) + if len(ds) == 0: raise ValueError( f"Dataset from {csv_file} is empty. Check CSV for valid entries and columns." 
diff --git a/src/deepforest/models/DeformableDetr.py b/src/deepforest/models/DeformableDetr.py index 45d3ec17b..52aa686a7 100644 --- a/src/deepforest/models/DeformableDetr.py +++ b/src/deepforest/models/DeformableDetr.py @@ -20,6 +20,8 @@ class DeformableDetrWrapper(nn.Module): """This class wraps a transformers DeformableDetrForObjectDetection model so that input pre- and post-processing happens transparently.""" + task: str = "box" + def __init__(self, config, name, revision, use_nms=False, **hf_args): """Initialize a DeformableDetrForObjectDetection model. diff --git a/src/deepforest/models/retinanet.py b/src/deepforest/models/retinanet.py index cfb114318..5e1389f53 100644 --- a/src/deepforest/models/retinanet.py +++ b/src/deepforest/models/retinanet.py @@ -12,6 +12,8 @@ class RetinaNetHub(RetinaNet, PyTorchModelHubMixin): """RetinaNet extension that allows the use of the HF Hub API.""" + task: str = "box" + def __init__( self, backbone_weights: str | None = None, diff --git a/tests/test_datasets_training.py b/tests/test_datasets_training_boxes.py similarity index 100% rename from tests/test_datasets_training.py rename to tests/test_datasets_training_boxes.py diff --git a/tests/test_datasets_training_keypoint.py b/tests/test_datasets_training_keypoint.py new file mode 100644 index 000000000..cd16fdecb --- /dev/null +++ b/tests/test_datasets_training_keypoint.py @@ -0,0 +1,271 @@ +"""Tests for KeypointDataset.""" + +import os + +import numpy as np +import pandas as pd +import pytest +import torch + +from deepforest import get_data +from deepforest.datasets.training import KeypointDataset + + +@pytest.fixture() +def keypoint_csv(): + return get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") + + +@pytest.fixture() +def keypoint_root_dir(): + return os.path.dirname( + get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") + ) + + +@pytest.fixture() +def box_csv(): + """Bounding box CSV to test centroid conversion.""" + return 
get_data("example.csv") + + +@pytest.fixture() +def box_root_dir(): + return os.path.dirname(get_data("OSBS_029.png")) + + +def test_keypoint_dataset_centroid(keypoint_csv, keypoint_root_dir): + """Basic construction, iteration, and output format.""" + ds = KeypointDataset( + csv_file=keypoint_csv, root_dir=keypoint_root_dir, label_dict={"Tree": 0}, output="centroid" + ) + raw = pd.read_csv(keypoint_csv) + + assert len(ds) == len(raw.image_path.unique()) + + for i in range(len(ds)): + image, targets, _path = ds[i] + + # Image: channels-first float tensor in [0, 1] + assert torch.is_tensor(image) + assert image.shape[0] == 3 + assert image.min() >= 0 + assert image.max() <= 1 + + # Targets: correct shapes, types, and dtypes + assert targets["points"].shape == (raw.shape[0], 2) + assert targets["points"].dtype == torch.float32 + + assert targets["labels"].shape == (raw.shape[0],) + assert targets["labels"].dtype == torch.int64 + +def test_keypoint_dataset_density(keypoint_csv, keypoint_root_dir): + """Density output mode should return a class-first tensor.""" + ds = KeypointDataset( + csv_file=keypoint_csv, + root_dir=keypoint_root_dir, + label_dict={"Tree": 0}, + output="density", + ) + + image, targets, _ = ds[0] + + assert torch.is_tensor(image) + assert "labels" in targets + assert "points" not in targets + assert targets["labels"].ndim == 3 + assert targets["labels"].shape[0] == 1 + assert targets["labels"].shape[1:] == image.shape[1:] + assert targets["labels"].dtype == torch.float32 + assert targets["labels"].max() > 0 + + +def test_keypoint_dataset_from_boxes(box_csv, box_root_dir): + """When given bounding box geometry, annotations_for_path should extract centroids.""" + ds = KeypointDataset( + csv_file=box_csv, root_dir=box_root_dir, label_dict={"Tree": 0} + ) + raw = pd.read_csv(box_csv) + + _image, targets, _path = next(iter(ds)) + points = targets["points"] + + assert points.shape == (raw.shape[0], 2) + + # Verify raw annotations without augmentation + 
targets = ds.annotations_for_path(ds.image_names[0]) + for i, (_, row) in enumerate(raw.iterrows()): + expected_cx = (row["xmin"] + row["xmax"]) / 2 + expected_cy = (row["ymin"] + row["ymax"]) / 2 + np.testing.assert_allclose( + targets["points"][i], [expected_cx, expected_cy], atol=0.01 + ) + + +def test_keypoint_dataset_hflip(keypoint_csv, keypoint_root_dir): + """Test that augmentation works by performing a horizontal flip augmentation, + checking it correctly flips x coordinates and leaves y unchanged.""" + ds_orig = KeypointDataset( + csv_file=keypoint_csv, root_dir=keypoint_root_dir, + ) + ds_flip = KeypointDataset( + csv_file=keypoint_csv, root_dir=keypoint_root_dir, + augmentations=[{"HorizontalFlip": {"p": 1.0}}], + ) + + _, targets_orig, _ = ds_orig[0] + _, targets_flip, _ = ds_flip[0] + W = ds_orig.load_image(0).shape[1] + + # Flipped x should be approximately (W - original_x) + torch.testing.assert_close( + targets_flip["points"][:, 0], + W - targets_orig["points"][:, 0], + atol=1.0, rtol=0, + ) + + +def test_keypoint_dataset_validate_coordinates_oob(tmp_path, keypoint_root_dir): + """Out-of-bounds points should raise ValueError.""" + image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" + + csv_path = str(tmp_path / "oob.csv") + df = pd.DataFrame( + { + "image_path": [image_name], + "x": [2000], # image is 1024x1024 + "y": [500], + "label": ["Tree"], + } + ) + df.to_csv(csv_path, index=False) + + with pytest.raises(ValueError, match="exceeds image dimensions"): + KeypointDataset( + csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + ) + + +def test_keypoint_dataset_validate_coordinates_negative(tmp_path, keypoint_root_dir): + """Negative coordinates should raise ValueError.""" + image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" + + csv_path = str(tmp_path / "neg.csv") + df = pd.DataFrame( + { + "image_path": [image_name], + "x": [-10], + "y": [500], + "label": ["Tree"], + } + ) + df.to_csv(csv_path, index=False) + + with 
pytest.raises(ValueError, match="exceeds image dimensions"): + KeypointDataset( + csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + ) + + +def test_keypoint_dataset_empty_annotations(tmp_path, keypoint_root_dir): + """Empty annotations (0,0) should produce empty targets.""" + image_name = "2019_BLAN_3_751000_4330000_image_crop.jpg" + + csv_path = str(tmp_path / "empty.csv") + df = pd.DataFrame( + { + "image_path": [image_name], + "x": [0], + "y": [0], + "label": ["Tree"], + } + ) + df.to_csv(csv_path, index=False) + + ds = KeypointDataset( + csv_file=csv_path, root_dir=keypoint_root_dir, label_dict={"Tree": 0} + ) + image, targets, path = ds[0] + assert targets["points"].shape == (0, 2) + assert targets["labels"].shape == (0,) + + +def test_keypoint_dataset_filter_points(): + """filter_points should remove out-of-bounds points.""" + ds_csv = get_data("2019_BLAN_3_751000_4330000_image_crop_keypoints.csv") + root_dir = os.path.dirname(ds_csv) + ds = KeypointDataset(csv_file=ds_csv, root_dir=root_dir, label_dict={"Tree": 0}) + + points = torch.tensor([[10.0, 20.0], [-5.0, 30.0], [50.0, 60.0], [200.0, 300.0]]) + labels = torch.tensor([0, 0, 0, 0]) + image_shape = (3, 100, 100) # H=100, W=100 + + filtered_points, filtered_labels = ds.filter_points(points, labels, image_shape) + + assert filtered_points.shape[0] == 2 # only (10,20) and (50,60) are in bounds + assert filtered_labels.shape[0] == 2 + torch.testing.assert_close( + filtered_points, torch.tensor([[10.0, 20.0], [50.0, 60.0]]) + ) + + +def test_keypoint_dataset_density_map(keypoint_csv, keypoint_root_dir): + """Density map should place class-specific peaks at point locations.""" + ds = KeypointDataset( + csv_file=keypoint_csv, + root_dir=keypoint_root_dir, + label_dict={"Tree": 0, "Shrub": 1}, + density_sigma=2, + output="density", + ) + + points = torch.tensor([[25.0, 40.0], [70.0, 15.0]], dtype=torch.float32) + labels = torch.tensor([0, 1], dtype=torch.int64) + density = 
ds.gaussian_density(points, labels, (3, 100, 100)) + + assert density.shape == (2, 100, 100) + assert density.dtype == torch.float32 + assert torch.argmax(density[0]).item() == (40 * 100 + 25) + assert torch.argmax(density[1]).item() == (15 * 100 + 70) + + +def test_keypoint_dataset_density_ignores_oob_points(keypoint_csv, keypoint_root_dir): + """Out-of-bounds points should not contribute to density map.""" + ds = KeypointDataset( + csv_file=keypoint_csv, + root_dir=keypoint_root_dir, + label_dict={"Tree": 0}, + density_sigma=2, + output="density", + ) + + points = torch.tensor([[-8.0, 40.0], [16.0, 16.0], [40.0, 160.0]], dtype=torch.float32) + labels = torch.tensor([0, 0, 0], dtype=torch.int64) + density = ds.gaussian_density(points, labels, (3, 64, 64)) + + # Only the point [16.0, 16.0] is in bounds for a 64x64 image + assert density.shape == (1, 64, 64) + # Peak should be near position (16, 16) when flattened + peak_pos = torch.argmax(density[0]).item() + peak_y = peak_pos // 64 + peak_x = peak_pos % 64 + # Allow small tolerance due to Gaussian blur spreading + assert abs(peak_x - 16.0) <= 1 + assert abs(peak_y - 16.0) <= 1 + + +def test_gaussian_density_count_normalization(keypoint_csv, keypoint_root_dir): + """Each class channel of the density map should sum to the number of points in that class.""" + ds = KeypointDataset( + csv_file=keypoint_csv, + root_dir=keypoint_root_dir, + label_dict={"Tree": 0, "Shrub": 1}, + output="density", + ) + + points = torch.tensor([[50.0, 50.0], [60.0, 60.0], [70.0, 70.0]], dtype=torch.float32) + labels = torch.tensor([0, 0, 1], dtype=torch.int64) # 2 Trees, 1 Shrub + density = ds.gaussian_density(points, labels, (3, 200, 200)) + + assert density[0].sum().item() == pytest.approx(2.0, abs=1e-3) + assert density[1].sum().item() == pytest.approx(1.0, abs=1e-3)