diff --git a/docs/_static/metadata_prior_example.png b/docs/_static/metadata_prior_example.png new file mode 100644 index 000000000..f6eca4fd6 Binary files /dev/null and b/docs/_static/metadata_prior_example.png differ diff --git a/docs/user_guide/03_cropmodels.md b/docs/user_guide/03_cropmodels.md index be17c2f2a..ff1bc088c 100644 --- a/docs/user_guide/03_cropmodels.md +++ b/docs/user_guide/03_cropmodels.md @@ -18,10 +18,147 @@ While that approach is certainly valid, there are a few key benefits to using Cr - **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates. - **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines. +(spatial-temporal-metadata)= +## Spatial-Temporal Metadata + +In biodiversity monitoring, species distributions vary by location and season. A bird common in Florida may be rare in Alaska, and migratory species shift seasonally. The CropModel supports an optional spatial-temporal metadata embedding that provides location and date context alongside image features to improve classification. + +The metadata signal is intentionally "gentle" — it contributes only ~1.5% of the feature vector (32 dimensions vs. 2048 image features). This means the model still classifies primarily from visual appearance but can use location/season as a soft prior. When metadata is not provided at inference time, the model gracefully degrades to image-only classification. + +### How It Works + +When `use_metadata=True`, the CropModel: + +1. Encodes `(lat, lon, day_of_year)` using sinusoidal features (smooth, periodic representation) +2. Projects the 6 sinusoidal features through a small MLP to a 32-dim embedding +3. Concatenates this with the 2048-dim ResNet image features +4. Classifies from the combined 2080-dim vector + +### Inference with Metadata + +Pass a `metadata` dict to `predict_tile`: + +```python +from deepforest import main +from deepforest.model import CropModel + +m = main.deepforest() +m.create_trainer() + +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk(train_dir="path/to/train", val_dir="path/to/val", + metadata_csv="metadata.csv") +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) + +result = m.predict_tile( + path="image.tif", + crop_model=crop_model, + metadata={"lat": 35.2, "lon": -120.4, "date": "2024-06-15"} +) +``` + +All detected crops in the tile share the same metadata. If `metadata` is omitted, the model falls back to image-only classification. + +### Training with Metadata + +Training requires a CSV sidecar file that maps each crop image filename to its spatial-temporal metadata: + +```text +filename,lat,lon,date +bird_001.png,35.2,-120.4,2024-06-15 +bird_002.png,35.2,-120.4,2024-06-15 +mammal_001.png,40.1,-105.3,2024-07-20 +``` + +- `filename` matches the image basename inside the ImageFolder class directories +- `date` is an ISO format string, converted to day-of-year internally +- One CSV covers both train and val sets (filenames are unique) + +The existing ImageFolder directory structure is unchanged: + +``` +train/ + Bird/ + bird_001.png + bird_002.png + Mammal/ + mammal_001.png +``` + +Pass the CSV when loading data: + +```python +from deepforest.model import CropModel + +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk( + train_dir="path/to/train", + val_dir="path/to/val", + metadata_csv="metadata.csv" +) +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) +``` + +### Configuration + +The metadata embedding is controlled by three config parameters: + +```python +crop_model = CropModel(config_args={ + "use_metadata": True, # Enable metadata fusion (default: False) + "metadata_dim": 32, # Embedding dimension (default: 32) + "metadata_dropout": 0.5, # Dropout on metadata path (default: 0.5) +}) +``` + +Or in `config.yaml`: + +```yaml +cropmodel: + use_metadata: True + metadata_dim: 32 + metadata_dropout: 0.5 +``` + +### Visualizing Metadata Priors + +After training a metadata-enabled CropModel, it can be useful to inspect the +spatial-temporal branch by itself. The +{download}`metadata prior visualization script ` +loads a checkpoint, evaluates a lat/lon grid for one or more dates, and writes: + +- A CSV with metadata-only logits, probabilities, and relative scores +- PNG maps for selected species and dates +- GeoTIFF rasters for GIS workflows + +For example: + +```bash +uv run python docs/user_guide/examples/visualize_metadata_priors.py \ + --checkpoint path/to/metadata_cropmodel.ckpt \ + --species "Morus bassanus" \ + --dates 2024-04-15 \ + --bounds -98 18 -55 48 \ + --cell-degrees 1.0 \ + --output-dir outputs/metadata_prior_maps +``` + +The map below shows a relative metadata prior for Northern Gannet +(`Morus bassanus`) on April 15, 2024. It reflects the learned metadata branch, +not image evidence. Basemap tiles are optional; install `contextily` to include +them or pass `--no-basemap` to plot only the score raster. + +```{image} ../_static/metadata_prior_example.png +:alt: Metadata prior map for Morus bassanus over the western Atlantic +:width: 650px +``` + ## Considerations - **Efficiency**: Using a CropModel will be slower, as for each detection, the sensor data needs to be cropped and passed to the detector. This is less efficient than using a combined classification/detection system like multi-class detection models. While modern GPUs mitigate this to some extent, it is still something to be mindful of. -- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. +- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. See the {ref}`spatial-temporal-metadata` section for an optional way to incorporate location and season information. ## Single Crop Model diff --git a/docs/user_guide/09_configuration_file.md b/docs/user_guide/09_configuration_file.md index bda0de7ba..c655b5f77 100644 --- a/docs/user_guide/09_configuration_file.md +++ b/docs/user_guide/09_configuration_file.md @@ -319,3 +319,15 @@ crop_model = CropModel() # Or use custom resize dimensions crop_model = CropModel(config_args={"resize": [300, 300]}) ``` + +### use_metadata + +Boolean flag to enable spatial-temporal metadata fusion. When `True`, the model accepts `(lat, lon, date)` alongside image crops and learns a small embedding that is concatenated with image features. Default is `False`. See {ref}`spatial-temporal-metadata` for usage details. + +### metadata_dim + +Dimension of the metadata embedding vector. A smaller value makes the metadata signal more gentle relative to the 2048-dim image features. Default is `32`. + +### metadata_dropout + +Dropout rate applied to the metadata embedding path. Higher values reduce the model's reliance on location/date information. Default is `0.5`. diff --git a/docs/user_guide/examples/visualize_metadata_priors.py b/docs/user_guide/examples/visualize_metadata_priors.py new file mode 100644 index 000000000..14a6d108d --- /dev/null +++ b/docs/user_guide/examples/visualize_metadata_priors.py @@ -0,0 +1,327 @@ +"""Map metadata-only class priors from a metadata-enabled CropModel checkpoint. + +This script visualizes what the spatial-temporal embedding branch contributes +to each class, independent of image content. It evaluates a coarse lat/lon grid +for one or more dates, then writes CSV score rasters and PNG maps. +""" + +from __future__ import annotations + +import argparse +import datetime as dt +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from rasterio.transform import from_origin + +from deepforest.model import CropModel + +try: + import contextily as ctx +except ImportError: # pragma: no cover - contextily is an optional visual enhancement. + ctx = None + + +SPECIES_ALIASES = { + "Northern Gannet": "Morus bassanus", + "Common Eider": "Somateria mollissima", +} + +DEFAULT_SPECIES = ["Morus bassanus", "Somateria mollissima"] +DEFAULT_DATES = ["2024-01-15", "2024-04-15", "2024-07-15", "2024-10-15"] +DEFAULT_BOUNDS = (-98.0, 18.0, -55.0, 48.0) # Gulf of Mexico + western Atlantic + + +def day_of_year(date: str) -> float: + return float(dt.datetime.strptime(date, "%Y-%m-%d").timetuple().tm_yday) + + +def resolve_species(species: list[str]) -> list[str]: + return [SPECIES_ALIASES.get(name, name) for name in species] + + +def make_grid( + bounds: tuple[float, float, float, float], cell_degrees: float +) -> pd.DataFrame: + min_lon, min_lat, max_lon, max_lat = bounds + lons = np.arange(min_lon + cell_degrees / 2, max_lon, cell_degrees) + lats = np.arange(min_lat + cell_degrees / 2, max_lat, cell_degrees) + lon_grid, lat_grid = np.meshgrid(lons, lats) + return pd.DataFrame( + { + "lon": lon_grid.ravel(), + "lat": lat_grid.ravel(), + } + ) + + +def load_metadata_model(checkpoint: str, device: str) -> CropModel: + model = CropModel.load_from_checkpoint(checkpoint, map_location=device) + model.eval() + model.to(device) + if ( + getattr(model, "metadata_encoder", None) is None + or getattr(model, "classifier", None) is None + ): + raise ValueError( + "Checkpoint is not metadata-enabled. Expected CropModel.metadata_encoder " + "and CropModel.classifier." + ) + return model + + +def metadata_prior_scores( + model: CropModel, + grid: pd.DataFrame, + date: str, + device: str, +) -> pd.DataFrame: + """Compute metadata-only logits and probabilities for every grid cell/class.""" + metadata = torch.tensor( + np.column_stack( + [ + grid["lat"].to_numpy(), + grid["lon"].to_numpy(), + np.full(len(grid), day_of_year(date)), + ] + ), + dtype=torch.float32, + device=device, + ) + + with torch.no_grad(): + meta_features = model.metadata_encoder(metadata) + meta_dim = meta_features.shape[1] + classifier = model.classifier + meta_weights = classifier.weight[:, -meta_dim:] + logits = meta_features @ meta_weights.T + if classifier.bias is not None: + logits = logits + classifier.bias + probabilities = torch.softmax(logits, dim=1) + + labels = model.numeric_to_label_dict + rows = [] + logits_np = logits.cpu().numpy() + probs_np = probabilities.cpu().numpy() + for class_idx, label in labels.items(): + class_scores = pd.DataFrame( + { + "date": date, + "class_idx": class_idx, + "species": label, + "lat": grid["lat"].to_numpy(), + "lon": grid["lon"].to_numpy(), + "metadata_logit": logits_np[:, class_idx], + "metadata_probability": probs_np[:, class_idx], + } + ) + rows.append(class_scores) + return pd.concat(rows, ignore_index=True) + + +def select_species_scores(scores: pd.DataFrame, species: list[str]) -> pd.DataFrame: + available = set(scores["species"].unique()) + missing = [name for name in species if name not in available] + if missing: + examples = sorted(available)[:20] + raise ValueError( + f"Species not found in checkpoint label_dict: {missing}. " + f"First available labels: {examples}" + ) + + selected = scores[scores["species"].isin(species)].copy() + selected["relative_score"] = selected.groupby(["date", "species"])[ + "metadata_logit" + ].transform( + lambda x: (x - x.min()) / (x.max() - x.min()) if x.max() > x.min() else 0.0 + ) + return selected + + +def _safe_name(value: str) -> str: + return value.lower().replace(" ", "_").replace("/", "_") + + +def plot_species_map( + scores: pd.DataFrame, + species: str, + date: str, + bounds: tuple[float, float, float, float], + output_path: Path, + plot_column: str, + cell_degrees: float, + cmap: str, + use_basemap: bool, +) -> None: + subset = scores[(scores["species"] == species) & (scores["date"] == date)] + pivot = subset.pivot(index="lat", columns="lon", values=plot_column).sort_index() + min_lon, min_lat, max_lon, max_lat = bounds + + fig, ax = plt.subplots(figsize=(12, 8)) + ax.set_xlim(min_lon, max_lon) + ax.set_ylim(min_lat, max_lat) + ax.set_aspect("equal") + + if use_basemap and ctx is not None: + try: + ctx.add_basemap( + ax, + crs="EPSG:4326", + source=ctx.providers.Esri.OceanBasemap, + attribution_size=5, + zorder=0, + ) + except Exception as exc: + print(f"Could not add basemap tiles: {exc}") + + image = ax.imshow( + pivot.to_numpy(), + extent=[ + pivot.columns.min() - cell_degrees / 2, + pivot.columns.max() + cell_degrees / 2, + pivot.index.min() - cell_degrees / 2, + pivot.index.max() + cell_degrees / 2, + ], + origin="lower", + cmap=cmap, + alpha=0.75, + zorder=2, + vmin=0 if plot_column == "relative_score" else None, + vmax=1 if plot_column == "relative_score" else None, + ) + fig.colorbar(image, ax=ax, label=plot_column.replace("_", " ")) + ax.set_title(f"{species} metadata prior, {date}") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.grid(color="white", linewidth=0.3, alpha=0.4) + fig.savefig(output_path, dpi=250, bbox_inches="tight") + plt.close(fig) + + +def write_species_geotiff( + scores: pd.DataFrame, + species: str, + date: str, + output_path: Path, + plot_column: str, + cell_degrees: float, +) -> None: + subset = scores[(scores["species"] == species) & (scores["date"] == date)] + pivot = subset.pivot(index="lat", columns="lon", values=plot_column).sort_index() + array = np.flipud(pivot.to_numpy()).astype("float32") + transform = from_origin( + pivot.columns.min() - cell_degrees / 2, + pivot.index.max() + cell_degrees / 2, + cell_degrees, + cell_degrees, + ) + with rasterio.open( + output_path, + "w", + driver="GTiff", + height=array.shape[0], + width=array.shape[1], + count=1, + dtype="float32", + crs="EPSG:4326", + transform=transform, + ) as dst: + dst.write(array, 1) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Visualize metadata-only species priors from a CropModel checkpoint." + ) + parser.add_argument( + "--checkpoint", required=True, help="Metadata-enabled CropModel checkpoint." + ) + parser.add_argument( + "--species", + nargs="+", + default=DEFAULT_SPECIES, + help="Scientific names to map. Common aliases supported: Northern Gannet, Common Eider.", + ) + parser.add_argument( + "--dates", nargs="+", default=DEFAULT_DATES, help="YYYY-MM-DD dates to map." + ) + parser.add_argument( + "--bounds", + nargs=4, + type=float, + default=DEFAULT_BOUNDS, + metavar=("MIN_LON", "MIN_LAT", "MAX_LON", "MAX_LAT"), + ) + parser.add_argument( + "--cell-degrees", type=float, default=1.0, help="Grid cell size in degrees." + ) + parser.add_argument( + "--output-dir", type=Path, default=Path("outputs/metadata_prior_maps") + ) + parser.add_argument( + "--plot-column", + default="relative_score", + choices=["relative_score", "metadata_probability", "metadata_logit"], + help="Score column used for PNG coloring. CSV always contains all score columns.", + ) + parser.add_argument("--cmap", default="viridis") + parser.add_argument("--device", default="cpu") + parser.add_argument("--no-basemap", action="store_true") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + species = resolve_species(args.species) + grid = make_grid(tuple(args.bounds), args.cell_degrees) + model = load_metadata_model(args.checkpoint, args.device) + + all_scores = [] + for date in args.dates: + scores = metadata_prior_scores( + model=model, grid=grid, date=date, device=args.device + ) + selected = select_species_scores(scores, species) + all_scores.append(selected) + + for species_name in species: + output_stem = args.output_dir / f"{_safe_name(species_name)}_{date}" + plot_species_map( + scores=selected, + species=species_name, + date=date, + bounds=tuple(args.bounds), + output_path=output_stem.with_suffix(".png"), + plot_column=args.plot_column, + cell_degrees=args.cell_degrees, + cmap=args.cmap, + use_basemap=not args.no_basemap, + ) + write_species_geotiff( + scores=selected, + species=species_name, + date=date, + output_path=output_stem.with_suffix(".tif"), + plot_column=args.plot_column, + cell_degrees=args.cell_degrees, + ) + print(f"Wrote {output_stem.with_suffix('.png')}") + print(f"Wrote {output_stem.with_suffix('.tif')}") + + combined = pd.concat(all_scores, ignore_index=True) + csv_path = args.output_dir / "metadata_prior_scores.csv" + combined.to_csv(csv_path, index=False) + print(f"Wrote {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 63ef1a137..4059a2c52 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -142,6 +142,11 @@ cropmodel: normalize: # Number of pixels to expand bbox crop windows for better prediction context. expand: 0 + # Spatial-temporal metadata fusion (optional). + # When True, the model accepts (lat, lon, date) alongside image crops. + use_metadata: False + metadata_dim: 32 + metadata_dropout: 0.5 point: score_integration_radius: 5 diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index e9bcd133d..b39a20548 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -127,6 +127,9 @@ class CropModelConfig: resize_interpolation: str = "bilinear" normalize: Any = None expand: int = 0 + use_metadata: bool = False + metadata_dim: int = 32 + metadata_dropout: float = 0.5 @dataclass diff --git a/src/deepforest/datasets/cropmodel.py b/src/deepforest/datasets/cropmodel.py index 971098e3b..e80e545fa 100644 --- a/src/deepforest/datasets/cropmodel.py +++ b/src/deepforest/datasets/cropmodel.py @@ -8,6 +8,7 @@ import numpy as np import rasterio as rio +import torch from rasterio.windows import Window from torch.utils.data import Dataset from torchvision import transforms @@ -82,6 +83,7 @@ def __init__( resize_interpolation: str = "bilinear", normalize=None, expand: int = 0, + metadata=None, ): self.df = df @@ -100,6 +102,10 @@ def __init__( raise ValueError("expand must be >= 0") self.expand = int(expand) + # Optional spatial-temporal metadata per crop. + # Dict mapping crop index to (lat, lon, day_of_year). + self.metadata = metadata + unique_image = self.df["image_path"].unique() assert len(unique_image) == 1, ( "There should be only one unique image for this class object" @@ -149,4 +155,9 @@ def __getitem__(self, idx): else: image = box + if self.metadata is not None: + lat, lon, doy = self.metadata[idx] + meta_tensor = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, meta_tensor + return image diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index 6226ebd14..e2456b7be 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -1,5 +1,6 @@ """Dataset model for object detection tasks.""" +import datetime import math import os from abc import abstractmethod @@ -8,6 +9,7 @@ import cv2 import kornia.augmentation as K import numpy as np +import pandas as pd import shapely import torch import torchvision @@ -877,3 +879,65 @@ def _classes_in(root): ) return train_ds, val_ds + + +class MetadataImageFolder(Dataset): + """Wrapper that adds spatial-temporal metadata to an ImageFolder dataset. + + Expects a CSV sidecar file with columns: filename, lat, lon, date. + The date column should be an ISO format string (e.g., "2024-06-15") + and will be converted to day_of_year internally. + + Args: + image_folder: A FixedClassImageFolder (or ImageFolder) dataset. + metadata_csv: Path to CSV with columns filename, lat, lon, date. + + Returns per sample: + (image, label, metadata_tensor) where metadata_tensor is shape (3,) + containing [lat, lon, day_of_year]. + """ + + def __init__(self, image_folder, metadata_csv): + self.image_folder = image_folder + metadata_df = pd.read_csv(metadata_csv) + self._meta_lookup = {} + for _, row in metadata_df.iterrows(): + date = datetime.datetime.strptime(str(row["date"]), "%Y-%m-%d") + doy = float(date.timetuple().tm_yday) + self._meta_lookup[row["filename"]] = ( + float(row["lat"]), + float(row["lon"]), + doy, + ) + + def __len__(self): + return len(self.image_folder) + + def __getitem__(self, idx): + image, label = self.image_folder[idx] + filepath = self.image_folder.samples[idx][0] + filename = os.path.basename(filepath) + + if filename in self._meta_lookup: + lat, lon, doy = self._meta_lookup[filename] + else: + lat, lon, doy = 0.0, 0.0, 1.0 + + metadata = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, label, metadata + + @property + def targets(self): + return self.image_folder.targets + + @property + def class_to_idx(self): + return self.image_folder.class_to_idx + + @property + def samples(self): + return self.image_folder.samples + + @property + def imgs(self): + return self.image_folder.imgs diff --git a/src/deepforest/main.py b/src/deepforest/main.py index c30a9737a..9b6d5195d 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,4 +1,5 @@ # entry point for deepforest model +import datetime import importlib import os import warnings @@ -577,6 +578,7 @@ def predict_tile( iou_threshold=0.15, dataloader_strategy="single", crop_model=None, + metadata=None, ): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and @@ -593,6 +595,10 @@ def predict_tile( - "batch" loads the entire image into GPU memory and creates views of an image as batch, requires in the entire tile to fit into GPU memory. CPU parallelization is possible for loading images. - "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows. crop_model: a deepforest.model.CropModel object to predict on crops + metadata: Optional dict with keys "lat", "lon", "date" for + spatial-temporal context. "date" should be an ISO format + string (e.g., "2024-06-15"). Used by CropModel when + use_metadata=True in config. Returns: pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple @@ -727,6 +733,20 @@ def predict_tile( root_dir = None if crop_model is not None: + # Build per-crop metadata from image-level metadata dict + if metadata is not None: + date_str = metadata.get("date", None) + if date_str is not None: + doy = float( + datetime.datetime.strptime(str(date_str), "%Y-%m-%d") + .timetuple() + .tm_yday + ) + else: + doy = 1.0 + lat = float(metadata.get("lat", 0.0)) + lon = float(metadata.get("lon", 0.0)) + cropmodel_results = [] for path in paths: image_result = mosaic_results[ @@ -735,8 +755,19 @@ def predict_tile( if image_result.empty: continue image_result.root_dir = os.path.dirname(path) + + # Create per-crop metadata dict if metadata was provided + per_crop_metadata = None + if metadata is not None: + per_crop_metadata = dict.fromkeys( + range(len(image_result)), (lat, lon, doy) + ) + cropmodel_result = predict._crop_models_wrapper_( - crop_model, self.trainer, image_result + crop_model, + self.trainer, + image_result, + metadata=per_crop_metadata, ) cropmodel_results.append(cropmodel_result) cropmodel_results = pd.concat(cropmodel_results) diff --git a/src/deepforest/model.py b/src/deepforest/model.py index 97ebaf91c..60a1e442c 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -1,5 +1,6 @@ # Model - common class import json +import math import os from collections.abc import Mapping @@ -104,6 +105,23 @@ def create_crop_backbone( return m +def create_crop_feature_backbone( + architecture: str = "resnet50", + pretrained: bool = True, +) -> tuple[torch.nn.Module, int]: + """Create a crop backbone that returns feature vectors.""" + if architecture not in _CROP_BACKBONES: + raise ValueError( + f"Unknown CropModel architecture '{architecture}'. " + f"Choose from {sorted(_CROP_BACKBONES)}." + ) + factory, default_weights = _CROP_BACKBONES[architecture] + m = factory(weights=default_weights if pretrained else None) + feature_dim = m.fc.in_features + m.fc = torch.nn.Identity() + return m, feature_dim + + def simple_resnet_50(num_classes: int = 2) -> torch.nn.Module: """Create a simple ResNet-50 model for classification. @@ -119,6 +137,66 @@ def simple_resnet_50(num_classes: int = 2) -> torch.nn.Module: return create_crop_backbone("resnet50", num_classes=num_classes) +def resnet50_backbone(): + """Create a ResNet-50 backbone that outputs 2048-dim feature vectors. + + Returns: + tuple: (backbone, feature_dim) where backbone is the model and + feature_dim is the output dimension (2048). + """ + return create_crop_feature_backbone("resnet50") + + +class SpatialTemporalEncoder(torch.nn.Module): + """Encode (lat, lon, day_of_year) into a fixed-size embedding. + + Uses sinusoidal features for smooth, periodic representation of + geographic coordinates and seasonality, followed by a small MLP. + + Args: + embed_dim: Output embedding dimension. Default 32. + dropout: Dropout rate on the embedding. Default 0.5. + + Input: + metadata: tensor of shape (batch, 3) with [lat, lon, day_of_year]. + lat in [-90, 90], lon in [-180, 180], day_of_year in [1, 366]. + + Output: + tensor of shape (batch, embed_dim). + """ + + def __init__(self, embed_dim: int = 32, dropout: float = 0.5): + super().__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(6, embed_dim), + torch.nn.ReLU(), + torch.nn.Dropout(dropout), + ) + + def forward(self, metadata): + lat = metadata[:, 0:1] + lon = metadata[:, 1:2] + doy = metadata[:, 2:3] + + lat_norm = lat / 90.0 + lon_norm = lon / 180.0 + doy_norm = (doy - 1) / 365.0 + + features = torch.cat( + [ + torch.sin(math.pi * lat_norm), + torch.cos(math.pi * lat_norm), + torch.sin(math.pi * lon_norm), + torch.cos(math.pi * lon_norm), + torch.sin(2 * math.pi * doy_norm), + torch.cos(2 * math.pi * doy_norm), + ], + dim=1, + ) + + return self.mlp(features) + + class CropModel(LightningModule, PyTorchModelHubMixin): """A PyTorch Lightning module for classifying image crops from object detection models. @@ -150,6 +228,9 @@ def __init__( super().__init__() self.model = model + self.backbone = None + self.metadata_encoder = None + self.classifier = None # Set the argument as the self.config, this way when reloading the checkpoint, self.config exists and is not overwritten. self.config = config if self.config is None: @@ -210,21 +291,44 @@ def create_model(self, num_classes: int, architecture: str | None = None): } ) - self.model = create_crop_backbone( - architecture=architecture, - num_classes=num_classes, - ) + use_metadata = self.config["cropmodel"].get("use_metadata", False) + + if use_metadata: + metadata_dim = self.config["cropmodel"].get("metadata_dim", 32) + metadata_dropout = self.config["cropmodel"].get("metadata_dropout", 0.5) + + backbone, feature_dim = create_crop_feature_backbone( + architecture=architecture, + ) + self.backbone = backbone + self.metadata_encoder = SpatialTemporalEncoder( + embed_dim=metadata_dim, dropout=metadata_dropout + ) + self.classifier = torch.nn.Linear(feature_dim + metadata_dim, num_classes) + self.model = None + else: + self.backbone = None + self.metadata_encoder = None + self.classifier = None + self.model = create_crop_backbone( + architecture=architecture, + num_classes=num_classes, + ) def create_trainer(self, **kwargs): """Create a pytorch lightning trainer object.""" self.trainer = Trainer(**kwargs) - def load_from_disk(self, train_dir, val_dir): + def load_from_disk(self, train_dir, val_dir, metadata_csv=None): """Load the training and validation datasets from disk. Args: train_dir (str): The directory containing the training dataset. val_dir (str): The directory containing the validation dataset. + metadata_csv (str, optional): Path to a CSV file mapping image + filenames to spatial-temporal metadata. The CSV should have + columns: filename, lat, lon, date. Required when + use_metadata=True in config. Defaults to None. Returns: None @@ -235,6 +339,15 @@ def load_from_disk(self, train_dir, val_dir): transform_train=self.get_transform(augmentations=["HorizontalFlip"]), transform_val=self.get_transform(augmentations=None), ) + + if metadata_csv is not None and self.config["cropmodel"].get( + "use_metadata", False + ): + from deepforest.datasets.training import MetadataImageFolder + + self.train_ds = MetadataImageFolder(self.train_ds, metadata_csv) + self.val_ds = MetadataImageFolder(self.val_ds, metadata_csv) + self.label_dict = self.train_ds.class_to_idx # Create a reverse mapping from numeric indices to class labels @@ -242,7 +355,7 @@ def load_from_disk(self, train_dir, val_dir): self.num_classes = len(self.label_dict) - if self.model is None: + if self.model is None and self.backbone is None: self.create_model(self.num_classes) def get_transform(self, augmentations): @@ -403,14 +516,24 @@ def normalize(self): mean=list(norm_cfg["mean"]), std=list(norm_cfg["std"]) ) - def forward(self, x): - if self.model is None: + def forward(self, x, metadata=None): + if self.backbone is not None: + image_features = self.backbone(x) + if metadata is not None: + meta_features = self.metadata_encoder(metadata) + else: + meta_dim = self.classifier.in_features - image_features.shape[1] + meta_features = torch.zeros( + x.shape[0], meta_dim, device=x.device, dtype=x.dtype + ) + combined = torch.cat([image_features, meta_features], dim=1) + return self.classifier(combined) + elif self.model is not None: + return self.model(x) + else: raise AttributeError( "CropModel is not initialized. Provide 'num_classes' or load from a checkpoint." ) - output = self.model(x) - - return output def train_dataloader(self): """Train data loader.""" @@ -465,20 +588,27 @@ def val_dataloader(self): return val_loader def training_step(self, batch, batch_idx): - x, y = batch - outputs = self.forward(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self.forward(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("train_loss", loss) return loss def predict_step(self, batch, batch_idx): - # Check if batch is a tuple for validation_dataloader - if isinstance(batch, list): - x, y = batch + # Inference: batch may be (images, metadata), (images, labels, metadata), or a single images tensor. + if isinstance(batch, (list, tuple)) and len(batch) == 3: + images, _labels, metadata = batch + elif isinstance(batch, (list, tuple)) and len(batch) == 2: + images, metadata = batch else: - x = batch - outputs = self.forward(x) + images = batch + metadata = None + outputs = self.forward(images, metadata=metadata) yhat = F.softmax(outputs, 1) return yhat @@ -492,8 +622,12 @@ def postprocess_predictions(self, predictions): return label, score def validation_step(self, batch, batch_idx): - x, y = batch - outputs = self(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("val_loss", loss) diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 7396a023f..8703864b8 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -293,6 +293,7 @@ def _predict_crop_model_( augmentations=None, model_index=0, is_single_model=False, + metadata=None, ): """Predicts crop model on a raster file. @@ -340,6 +341,7 @@ def _predict_crop_model_( resize_interpolation=resize_interpolation, normalize=normalize, expand=expand, + metadata=metadata, ) # Create dataloader @@ -375,7 +377,7 @@ def _predict_crop_model_( def _crop_models_wrapper_( - crop_models, trainer, results, transform=None, augmentations=None + crop_models, trainer, results, transform=None, augmentations=None, metadata=None ): if crop_models is not None and not isinstance(crop_models, list): crop_models = [crop_models] @@ -398,6 +400,7 @@ def _crop_models_wrapper_( transform=transform, augmentations=augmentations, is_single_model=is_single_model, + metadata=metadata, ) crop_results.append(crop_result) diff --git a/tests/test_metadata_cropmodel.py b/tests/test_metadata_cropmodel.py new file mode 100644 index 000000000..fdbb2642e --- /dev/null +++ b/tests/test_metadata_cropmodel.py @@ -0,0 +1,178 @@ +"""Tests for spatial-temporal metadata embeddings in CropModel.""" + +import os + +import numpy as np +import pandas as pd +import pytest +import torch +from PIL import Image +from torchvision.datasets import ImageFolder + +from deepforest import get_data +from deepforest.datasets.cropmodel import BoundingBoxDataset +from deepforest.datasets.training import MetadataImageFolder +from deepforest.model import CropModel, SpatialTemporalEncoder + + +def test_spatial_temporal_encoder_output_shape(): + enc = SpatialTemporalEncoder(embed_dim=32, dropout=0.0) + meta = torch.tensor([[35.0, -120.0, 145.0], [0.0, 0.0, 1.0]]) + out = enc(meta) + assert out.shape == (2, 32) + + +def test_crop_model_metadata_forward(): + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 32}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + meta = torch.tensor([[35.0, -120.0, 145.0]] * 4) + out = cm.forward(x, metadata=meta) + assert out.shape == (4, 5) + + +def test_crop_model_metadata_none_graceful_degradation(): + """When use_metadata=True but metadata=None, model should still predict.""" + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x, metadata=None) + assert out.shape == (4, 5) + + +def test_crop_model_no_metadata_backward_compat(): + cm = CropModel() + cm.create_model(num_classes=2) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x) + assert out.shape == (4, 2) + assert cm.backbone is None + assert cm.metadata_encoder is None + assert cm.classifier is None + + +def test_training_step_with_metadata(): + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=3) + x = torch.rand(4, 3, 224, 224) + y = torch.tensor([0, 1, 2, 0]) + meta = torch.rand(4, 3) + batch = (x, y, meta) + loss = cm.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + + +def test_metadata_image_folder(tmp_path): + """Test MetadataImageFolder wrapping an ImageFolder.""" + for cls in ["A", "B"]: + cls_dir = tmp_path / cls + cls_dir.mkdir() + for i in range(3): + img = Image.fromarray(np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8)) + img.save(cls_dir / f"{cls}_{i}.png") + + rows = [] + for cls in ["A", "B"]: + for i in range(3): + rows.append({ + "filename": f"{cls}_{i}.png", + "lat": 35.0 + i, + "lon": -120.0 + i, + "date": "2024-06-15", + }) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame(rows).to_csv(metadata_csv, index=False) + + base_ds = ImageFolder(str(tmp_path)) + meta_ds = MetadataImageFolder(base_ds, str(metadata_csv)) + + assert len(meta_ds) == 6 + image, label, metadata = meta_ds[0] + assert isinstance(image, (torch.Tensor, np.ndarray, Image.Image)) + assert isinstance(label, int) + assert metadata.shape == (3,) + + found = False + for i in range(len(meta_ds)): + _, _, meta = meta_ds[i] + if meta[2].item() == 167.0: + found = True + break + assert found, "day_of_year should be 167 for 2024-06-15" + + +@pytest.fixture() +def bbox_df(): + df = pd.read_csv(get_data("testfile_multi.csv")) + single_image = df.image_path.unique()[0] + return df[df.image_path == single_image].reset_index(drop=True) + + +def test_bounding_box_dataset_with_metadata(bbox_df): + root_dir = os.path.dirname(get_data("SOAP_061.png")) + n = len(bbox_df) + metadata = dict.fromkeys(range(n), (35.0, -120.0, 145.0)) + ds = BoundingBoxDataset(bbox_df, root_dir=root_dir, metadata=metadata) + item = ds[0] + assert isinstance(item, tuple) + assert len(item) == 2 + assert item[0].shape[0] == 3 + assert item[1].shape == (3,) + assert item[1][0] == 35.0 + assert item[1][1] == -120.0 + assert item[1][2] == 145.0 + + +def test_checkpoint_save_load_metadata(tmp_path): + """Test that metadata models can be saved and loaded from checkpoint.""" + for cls in ["A", "B"]: + cls_dir = tmp_path / "data" / cls + cls_dir.mkdir(parents=True) + for i in range(3): + img = Image.fromarray( + np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + ) + img.save(cls_dir / f"{cls}_{i}.png") + + rows = [] + for cls in ["A", "B"]: + for i in range(3): + rows.append({ + "filename": f"{cls}_{i}.png", + "lat": 40.0, + "lon": -100.0, + "date": "2024-01-15", + }) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame(rows).to_csv(metadata_csv, index=False) + + data_dir = str(tmp_path / "data") + + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 16}) + cm.create_trainer( + fast_dev_run=False, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + default_root_dir=str(tmp_path / "logs"), + ) + cm.load_from_disk( + train_dir=data_dir, val_dir=data_dir, metadata_csv=str(metadata_csv) + ) + cm.create_model(num_classes=len(cm.label_dict)) + cm.trainer.fit(cm) + + checkpoint_path = str(tmp_path / "test.ckpt") + cm.trainer.save_checkpoint(checkpoint_path) + + loaded = CropModel.load_from_checkpoint(checkpoint_path) + assert loaded.backbone is not None + assert loaded.metadata_encoder is not None + assert loaded.classifier is not None + assert loaded.label_dict == cm.label_dict + + x = torch.rand(2, 3, 224, 224) + meta = torch.tensor([[40.0, -100.0, 15.0]] * 2) + out = loaded(x, metadata=meta) + assert out.shape == (2, len(cm.label_dict))