diff --git a/src/deepforest/main.py b/src/deepforest/main.py index c30a9737a..7a73264c5 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytorch_lightning as pl +import rasterio import torch import torchmetrics from omegaconf import DictConfig, OmegaConf @@ -577,6 +578,7 @@ def predict_tile( iou_threshold=0.15, dataloader_strategy="single", crop_model=None, + project=False, ): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and @@ -593,9 +595,10 @@ def predict_tile( - "batch" loads the entire image into GPU memory and creates views of an image as batch, requires in the entire tile to fit into GPU memory. CPU parallelization is possible for loading images. - "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows. crop_model: a deepforest.model.CropModel object to predict on crops + project (bool): If True, return a geopandas.GeoDataFrame with geometry column projected to the image CRS. Returns: - pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple + pd.DataFrame, geopandas.GeoDataFrame, or tuple: Predictions dataframe, geopandas.GeoDataFrame, or (predictions, crops) tuple. """ self.model.eval() self.model.nms_thresh = self.config.nms_thresh @@ -746,6 +749,28 @@ def predict_tile( formatted_results = utilities.__pandas_to_geodataframe__(cropmodel_results) formatted_results.root_dir = root_dir + if project: + if paths[0] is None: + raise ValueError( + "project=True requires a file path, not an in-memory image array." + ) + + if root_dir is None: + root_dir = os.path.dirname(paths[0]) + + rgb_path = os.path.join(root_dir, os.path.basename(paths[0])) + with rasterio.open(rgb_path) as src: + if src.crs is None: + raise ValueError( + f"project=True requires a georeferenced image, " + f"but '{paths[0]}' has no CRS. Use a georeferenced " + f"raster (e.g., a GeoTIFF) or set project=False." + ) + + formatted_results = utilities.image_to_geo_coordinates( + formatted_results, root_dir=root_dir + ) + return formatted_results def training_step(self, batch, batch_idx):