Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from deepforest import main, get_data

def test_benchmark_release():
    """Check that a pinned model revision reproduces its recorded metrics.

    Loads "weecology/deepforest-tree" at a fixed revision, evaluates it on
    the file returned by ``get_data("OSBS_029.csv")`` with an IoU threshold
    of 0.4, and compares box precision and recall against reference values
    using an absolute tolerance of 0.01.
    """
    # Revision identifier handed to load_model; presumably a commit hash of
    # the model repository -- confirm against the release notes.
    pinned_revision = "cc21436bc5d572dde8ff5f93c1e71a32f563cace"
    # Both metrics are compared with the same absolute tolerance.
    tolerance = 0.01

    model = main.deepforest()
    model.load_model("weecology/deepforest-tree", revision=pinned_revision)

    annotations_path = get_data("OSBS_029.csv")
    metrics = model.evaluate(annotations_path, iou_threshold=0.4)

    assert metrics["box_precision"] == pytest.approx(0.8, abs=tolerance)
    assert metrics["box_recall"] == pytest.approx(0.7213, abs=tolerance)
9 changes: 4 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,11 @@ def test_evaluate(m):
df = pd.read_csv(csv_file)
results = m.evaluate(csv_file)

# Metrics are sane
assert np.round(results["box_precision"], 2) > 0.5
assert np.round(results["box_recall"], 2) > 0.5
# Check that precision and recall don't regress below reasonable baselines
assert results["box_precision"] > 0.7
assert results["box_recall"] > 0.5

# Class names are correct
# Structure and Label checks
assert len(results["results"].predicted_label.dropna().unique()) == 1
assert results["results"].predicted_label.dropna().unique()[0] == "Tree"
assert results["predictions"].shape[0] > 0
Expand All @@ -578,7 +578,6 @@ def test_evaluate(m):
# Check we have match results for every ground truth box
assert results["results"].shape[0] == df.shape[0]


def test_train_callbacks(m):
csv_file = get_data("example.csv")
root_dir = os.path.dirname(csv_file)
Expand Down
Loading