Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/deepforest/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ train:
fast_dev_run: False
# preload images to GPU memory for fast training. This depends on GPU size and number of images.
preload_images: False
# Skip per-image dimension validation when all images share the same size.
same_size_images: False


validation:
csv_file:
root_dir:
preload_images: False
same_size_images: False
size:

# For retinanet you may prefer val_classification, but the default val_loss
Expand Down
2 changes: 2 additions & 0 deletions src/deepforest/conf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class TrainConfig:
epochs: int = 1
fast_dev_run: bool = False
preload_images: bool = False
same_size_images: bool = False
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])


Expand All @@ -83,6 +84,7 @@ class ValidationConfig:
csv_file: str | None = None
root_dir: str | None = None
preload_images: bool = False
same_size_images: bool = False
size: int | None = None
iou_threshold: float = 0.4
val_accuracy_interval: int = 20
Expand Down
104 changes: 66 additions & 38 deletions src/deepforest/datasets/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
augmentations=None,
label_dict=None,
preload_images=False,
same_size_images=False,
):
"""
Args:
Expand All @@ -40,6 +41,8 @@ def __init__(
label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}).
augmentations (str | list | dict, optional): Augmentation configuration.
preload_images (bool): If True, preload all images into memory. Defaults to False.
same_size_images (bool): If True, skip per-image validation by assuming all images share
the same dimensions as the first image. Defaults to False.
"""
self.annotations = utilities.read_file(csv_file, root_dir=root_dir)
self.root_dir = root_dir
Expand All @@ -62,6 +65,7 @@ def __init__(

self.image_names = self.annotations.image_path.unique()
self.preload_images = preload_images
self.same_size_images = same_size_images

self._validate_labels()
self._validate_coordinates()
Expand Down Expand Up @@ -139,17 +143,28 @@ def _validate_coordinates(self) -> None:
ValueError: If any bounding box coordinate occurs outside the image
"""
errors = []
for image_path, group in self.annotations.groupby("image_path"):
img_path = os.path.join(self.root_dir, image_path)

if self.same_size_images:
first_path = os.path.join(self.root_dir, self.image_names[0])
try:
with Image.open(img_path) as img:
width, height = img.size
with Image.open(first_path) as img:
shared_size = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue
raise ValueError(f"Failed to open image {first_path}: {e}") from e

for image_path, group in self.annotations.groupby("image_path"):
if self.same_size_images:
width, height = shared_size
else:
img_path = os.path.join(self.root_dir, image_path)
try:
with Image.open(img_path) as img:
width, height = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

for _idx, row in group.iterrows():
# Extract bounding box
geom = row["geometry"]
xmin, ymin, xmax, ymax = geom.bounds

Expand All @@ -159,7 +174,6 @@ def _validate_coordinates(self) -> None:
if xmin == 0 and ymin == 0 and xmax == 0 and ymax == 0:
continue

# Check if box is valid
oob_issues = []
if not geom.equals(shapely.envelope(geom)):
oob_issues.append(f"geom ({geom}) is not a valid bounding box")
Expand Down Expand Up @@ -299,6 +313,7 @@ def __init__(
augmentations=None,
label_dict=None,
preload_images=False,
same_size_images=False,
density_sigma=4.0,
output="centroid",
):
Expand All @@ -321,6 +336,8 @@ def __init__(
label_dict (dict[str, int]): Mapping from string labels in the CSV to integer class IDs (e.g., {"Tree": 0}).
augmentations (str | list | dict, optional): Augmentation configuration.
preload_images (bool): If True, preload all images into memory. Defaults to False.
same_size_images (bool): If True, skip per-image validation by assuming all images share
the same dimensions as the first image. Defaults to False.
density_sigma (float): Standard deviation of the Gaussian kernel for density map generation. Defaults to 4.0.
output (str): Output format, either "centroid" for point coordinates or "density" for Gaussian density maps. Defaults to "centroid".
"""
Expand All @@ -331,6 +348,7 @@ def __init__(
augmentations=augmentations,
label_dict=label_dict,
preload_images=preload_images,
same_size_images=same_size_images,
)

self.density_sigma = density_sigma
Expand All @@ -348,38 +366,48 @@ def _validate_coordinates(self) -> None:
ValueError: If any point occurs outside the image
"""
errors = []
for _idx, row in self.annotations.iterrows():
img_path = os.path.join(self.root_dir, row["image_path"])

if self.same_size_images:
first_path = os.path.join(self.root_dir, self.image_names[0])
try:
with Image.open(img_path) as img:
width, height = img.size
with Image.open(first_path) as img:
shared_size = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

# Extract point coordinates (use centroid so boxes/polygons also work)
centroid = row["geometry"].centroid
x, y = centroid.x, centroid.y

# All coordinates equal to zero is how we code empty frames.
if x == 0 and y == 0:
continue

# Check if point is valid
oob_issues = []
if x < 0:
oob_issues.append(f"x ({x}) < 0")
if x > width:
oob_issues.append(f"x ({x}) > image width ({width})")
if y < 0:
oob_issues.append(f"y ({y}) < 0")
if y > height:
oob_issues.append(f"y ({y}) > image height ({height})")

if oob_issues:
errors.append(
f"Point, ({x}, {y}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}."
)
raise ValueError(f"Failed to open image {first_path}: {e}") from e

for image_path, group in self.annotations.groupby("image_path"):
if self.same_size_images:
width, height = shared_size
else:
img_path = os.path.join(self.root_dir, image_path)
try:
with Image.open(img_path) as img:
width, height = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

for _idx, row in group.iterrows():
centroid = row["geometry"].centroid
x, y = centroid.x, centroid.y

if x == 0 and y == 0:
continue

oob_issues = []
if x < 0:
oob_issues.append(f"x ({x}) < 0")
if x > width:
oob_issues.append(f"x ({x}) > image width ({width})")
if y < 0:
oob_issues.append(f"y ({y}) < 0")
if y > height:
oob_issues.append(f"y ({y}) > image height ({height})")

if oob_issues:
errors.append(
f"Point, ({x}, {y}) exceeds image dimensions, ({width}, {height}). Issues: {', '.join(oob_issues)}."
)

if errors:
raise ValueError("\n".join(errors))
Expand Down
6 changes: 6 additions & 0 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def load_dataset(
transforms=None,
augmentations=None,
preload_images=False,
same_size_images=False,
batch_size=1,
):
"""Create a dataset for inference or training. Csv file format is .csv
Expand All @@ -339,6 +340,7 @@ def load_dataset(
batch_size: batch size
preload_images: if True, preload the images into memory
augmentations: augmentation configuration (str, list, or dict)
same_size_images: if True, skip per-image validation by assuming all images share the same dimensions
Returns:
ds: a pytorch dataset
"""
Expand All @@ -351,6 +353,7 @@ def load_dataset(
label_dict=self.label_dict,
augmentations=augmentations,
preload_images=preload_images,
same_size_images=same_size_images,
)
elif self.model.task == "keypoint":
ds = training.KeypointDataset(
Expand All @@ -360,6 +363,7 @@ def load_dataset(
label_dict=self.label_dict,
augmentations=augmentations,
preload_images=preload_images,
same_size_images=same_size_images,
)
else:
raise ValueError(
Expand Down Expand Up @@ -395,6 +399,7 @@ def train_dataloader(self):
root_dir=self.config.train.root_dir,
augmentations=self.config.train.augmentations,
preload_images=self.config.train.preload_images,
same_size_images=self.config.train.same_size_images,
shuffle=True,
transforms=self.transforms,
batch_size=self.config.batch_size,
Expand Down Expand Up @@ -422,6 +427,7 @@ def val_dataloader(self):
augmentations=self.config.validation.augmentations,
shuffle=False,
preload_images=self.config.validation.preload_images,
same_size_images=self.config.validation.same_size_images,
batch_size=self.config.batch_size,
)

Expand Down
Loading