diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index d42ac11f0..2e0fb0289 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -83,12 +83,15 @@ train: fast_dev_run: False # preload images to GPU memory for fast training. This depends on GPU size and number of images. preload_images: False + # Skip per-image dimension validation when all images share the same size. + same_size_images: False validation: csv_file: root_dir: preload_images: False + same_size_images: False size: # For retinanet you may prefer val_classification, but the default val_loss diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index b593ecc3c..a7abf44aa 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -68,6 +68,7 @@ class TrainConfig: epochs: int = 1 fast_dev_run: bool = False preload_images: bool = False + same_size_images: bool = False augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"]) @@ -83,6 +84,7 @@ class ValidationConfig: csv_file: str | None = None root_dir: str | None = None preload_images: bool = False + same_size_images: bool = False size: int | None = None iou_threshold: float = 0.4 val_accuracy_interval: int = 20 diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 7ed188f45..8457c6deb 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -31,6 +31,7 @@ def __init__( augmentations=None, label_dict=None, preload_images=False, + same_size_images=False, ): """ Args: @@ -40,6 +41,8 @@ def __init__( label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}). augmentations (str | list | dict, optional): Augmentation configuration. preload_images (bool): If True, preload all images into memory. Defaults to False. + same_size_images (bool): If True, skip per-image validation by assuming all images share + the same dimensions as the first image. Defaults to False. """ self.annotations = utilities.read_file(csv_file, root_dir=root_dir) self.root_dir = root_dir @@ -62,6 +65,7 @@ def __init__( self.image_names = self.annotations.image_path.unique() self.preload_images = preload_images + self.same_size_images = same_size_images self._validate_labels() self._validate_coordinates() @@ -139,17 +143,28 @@ def _validate_coordinates(self) -> None: ValueError: If any bounding box coordinate occurs outside the image """ errors = [] - for image_path, group in self.annotations.groupby("image_path"): - img_path = os.path.join(self.root_dir, image_path) + + if self.same_size_images: + first_path = os.path.join(self.root_dir, self.image_names[0]) try: - with Image.open(img_path) as img: - width, height = img.size + with Image.open(first_path) as img: + shared_size = img.size except Exception as e: - errors.append(f"Failed to open image {img_path}: {e}") - continue + raise ValueError(f"Failed to open image {first_path}: {e}") from e + + for image_path, group in self.annotations.groupby("image_path"): + if self.same_size_images: + width, height = shared_size + else: + img_path = os.path.join(self.root_dir, image_path) + try: + with Image.open(img_path) as img: + width, height = img.size + except Exception as e: + errors.append(f"Failed to open image {img_path}: {e}") + continue for _idx, row in group.iterrows(): - # Extract bounding box geom = row["geometry"] xmin, ymin, xmax, ymax = geom.bounds @@ -159,7 +174,6 @@ def _validate_coordinates(self) -> None: if xmin == 0 and ymin == 0 and xmax == 0 and ymax == 0: continue - # Check if box is valid oob_issues = [] if not geom.equals(shapely.envelope(geom)): oob_issues.append(f"geom ({geom}) is not a valid bounding box") @@ -299,6 +313,7 @@ def __init__( augmentations=None, label_dict=None, preload_images=False, + same_size_images=False, density_sigma=4.0, output="centroid", ): @@ -321,6 +336,8 @@ def __init__( label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}). augmentations (str | list | dict, optional): Augmentation configuration. preload_images (bool): If True, preload all images into memory. Defaults to False. + same_size_images (bool): If True, skip per-image validation by assuming all images share + the same dimensions as the first image. Defaults to False. density_sigma (float): Standard deviation of the Gaussian kernel for density map generation. Defaults to 4.0. output (str): Output format, either "centroid" for point coordinates or "density" for Gaussian density maps. Defaults to "centroid". """ @@ -331,6 +348,7 @@ def __init__( augmentations=augmentations, label_dict=label_dict, preload_images=preload_images, + same_size_images=same_size_images, ) self.density_sigma = density_sigma @@ -348,38 +366,48 @@ def _validate_coordinates(self) -> None: ValueError: If any point occurs outside the image """ errors = [] - for _idx, row in self.annotations.iterrows(): - img_path = os.path.join(self.root_dir, row["image_path"]) + + if self.same_size_images: + first_path = os.path.join(self.root_dir, self.image_names[0]) try: - with Image.open(img_path) as img: - width, height = img.size + with Image.open(first_path) as img: + shared_size = img.size except Exception as e: - errors.append(f"Failed to open image {img_path}: {e}") - continue - - # Extract point coordinates (use centroid so boxes/polygons also work) - centroid = row["geometry"].centroid - x, y = centroid.x, centroid.y - - # All coordinates equal to zero is how we code empty frames. - if x == 0 and y == 0: - continue - - # Check if point is valid - oob_issues = [] - if x < 0: - oob_issues.append(f"x ({x}) < 0") - if x > width: - oob_issues.append(f"x ({x}) > image width ({width})") - if y < 0: - oob_issues.append(f"y ({y}) < 0") - if y > height: - oob_issues.append(f"y ({y}) > image height ({height})") - - if oob_issues: - errors.append( - f"Point, ({x}, {y}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}." - ) + raise ValueError(f"Failed to open image {first_path}: {e}") from e + + for image_path, group in self.annotations.groupby("image_path"): + if self.same_size_images: + width, height = shared_size + else: + img_path = os.path.join(self.root_dir, image_path) + try: + with Image.open(img_path) as img: + width, height = img.size + except Exception as e: + errors.append(f"Failed to open image {img_path}: {e}") + continue + + for _idx, row in group.iterrows(): + centroid = row["geometry"].centroid + x, y = centroid.x, centroid.y + + if x == 0 and y == 0: + continue + + oob_issues = [] + if x < 0: + oob_issues.append(f"x ({x}) < 0") + if x > width: + oob_issues.append(f"x ({x}) > image width ({width})") + if y < 0: + oob_issues.append(f"y ({y}) < 0") + if y > height: + oob_issues.append(f"y ({y}) > image height ({height})") + + if oob_issues: + errors.append( + f"Point, ({x}, {y}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}." + ) if errors: raise ValueError("\n".join(errors)) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index e121adf2a..bcde4606b 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -324,6 +324,7 @@ def load_dataset( transforms=None, augmentations=None, preload_images=False, + same_size_images=False, batch_size=1, ): """Create a dataset for inference or training. Csv file format is .csv @@ -339,6 +340,7 @@ def load_dataset( batch_size: batch size preload_images: if True, preload the images into memory augmentations: augmentation configuration (str, list, or dict) + same_size_images: if True, skip per-image validation by assuming all images share the same dimensions Returns: ds: a pytorch dataset """ @@ -351,6 +353,7 @@ def load_dataset( label_dict=self.label_dict, augmentations=augmentations, preload_images=preload_images, + same_size_images=same_size_images, ) elif self.model.task == "keypoint": ds = training.KeypointDataset( @@ -360,6 +363,7 @@ def load_dataset( label_dict=self.label_dict, augmentations=augmentations, preload_images=preload_images, + same_size_images=same_size_images, ) else: raise ValueError( @@ -395,6 +399,7 @@ def train_dataloader(self): root_dir=self.config.train.root_dir, augmentations=self.config.train.augmentations, preload_images=self.config.train.preload_images, + same_size_images=self.config.train.same_size_images, shuffle=True, transforms=self.transforms, batch_size=self.config.batch_size, @@ -422,6 +427,7 @@ def val_dataloader(self): augmentations=self.config.validation.augmentations, shuffle=False, preload_images=self.config.validation.preload_images, + same_size_images=self.config.validation.same_size_images, batch_size=self.config.batch_size, )