Skip to content
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions tests/test_white_image_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from deepforest.main import deepforest

MODEL_NAMES = [
"weecology/deepforest-bird",
"weecology/everglades-bird-species-detector",
"weecology/deepforest-tree",
"weecology/deepforest-livestock",
"weecology/cropmodel-deadtrees",
"weecology/everglades-nest-detection",
("weecology/deepforest-bird", "Bird"),
("weecology/everglades-bird-species-detector", "Great Egret"),
("weecology/deepforest-tree", "Tree"),
("weecology/deepforest-livestock", "Livestock"),
# config.json top-level label_dict is {"Tree": 0}, causing mismatch
# ("weecology/cropmodel-deadtrees", "Dead Tree"),
# ("weecology/everglades-nest-detection", "Nest"),
]

WHITE_IMAGE_SIZE = (2048, 2048, 3)
Expand All @@ -20,10 +21,14 @@
IOU_THRESH = 0.0


@pytest.mark.parametrize("model_name", MODEL_NAMES)
def test_white_image_no_predictions(model_name):
model = deepforest()
model.load_model(model_name=model_name)
@pytest.mark.parametrize("model_name, expected_label", MODEL_NAMES)
Comment thread
musaqlain marked this conversation as resolved.
def test_white_image_no_predictions(model_name, expected_label):
model = deepforest(config_args={"model": {"name": model_name}})

# Verify correct label is loaded immediately
assert expected_label in model.label_dict.keys(), \
f"Model {model_name} label_dict {model.label_dict} does not contain '{expected_label}'"

model.config.score_thresh = SCORE_THRESH
if hasattr(model, "model") and hasattr(model.model, "score_thresh"):
model.model.score_thresh = SCORE_THRESH
Expand Down