diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 029a4960e..4cd6aba83 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -82,7 +82,7 @@ class TrainConfig: fast_dev_run: bool = False preload_images: bool = False validate_coordinates: bool = True - augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"]) + augmentations: list[Any] | None = field(default_factory=lambda: ["HorizontalFlip"]) @dataclass @@ -104,7 +104,7 @@ class ValidationConfig: iou_threshold: float = 0.4 val_accuracy_interval: int = 20 lr_plateau_target: str = "val_loss" - augmentations: list[str] | None = field(default_factory=lambda: []) + augmentations: list[Any] | None = field(default_factory=lambda: []) @dataclass diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 57e391234..ff9a2d861 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -41,6 +41,34 @@ def test_load_dataset_without_augmentations(): image, target, path = next(iter(train_ds)) assert len(train_ds.dataset.transform) == 0 +def test_augmentation_schema_validation(): + """ + Test that the schema accepts a mixed list of dictionaries and strings for augmentations, + and that they are correctly applied to the dataset pipeline. + """ + augmentations = [ + {"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}}, + "HorizontalFlip" + ] + + m = main.deepforest(config_args={"train": {"augmentations": augmentations}}) + + # Verify Schema stored it correctly + assert m.config.train.augmentations == augmentations + + csv_file = get_data("example.csv") + root_dir = os.path.dirname(csv_file) + + train_ds = m.load_dataset(csv_file, root_dir=root_dir, augmentations=augmentations) + + transforms = train_ds.dataset.transform + + has_resized_crop = any(isinstance(t, K.RandomResizedCrop) for t in transforms) + has_hflip = any(isinstance(t, K.RandomHorizontalFlip) for t in transforms) + + assert has_resized_crop + assert has_hflip + """ Augmentation parsing tests: """