Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def _integer_population_size(a):
return None


def _choice_indices(population_size, size, num_samples, replace, p, device, *, shuffle=True):
def _choice_indices(
population_size, size, num_samples, replace, p, device, *, shuffle=True
):
if population_size <= 0:
if num_samples == 0:
return _torch.empty(size or (0,), dtype=_torch.long, device=device)
Expand Down
17 changes: 10 additions & 7 deletions src/pyrecest/evaluation/generate_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,16 @@ def generate_measurements(groundtruth, simulation_config):
measurements[t] = curr_dist.sample(n_meas)
elif isinstance(meas_noise, GaussianDistribution):
noise_samples = meas_noise.sample(n_meas)
measurements[t] = tile(
groundtruth[t],
(
n_meas,
1,
),
) + noise_samples
measurements[t] = (
tile(
groundtruth[t],
(
n_meas,
1,
),
)
+ noise_samples
)

return measurements

Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/evaluation/pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,5 +369,7 @@ def _coerce_finite_threshold(value: Any, column: str) -> float:
f"Constraint threshold for {column!r} must be a finite scalar."
) from exc
if not np.isfinite(threshold):
raise ValueError(f"Constraint threshold for {column!r} must be a finite scalar.")
raise ValueError(
f"Constraint threshold for {column!r} must be a finite scalar."
)
return threshold
4 changes: 3 additions & 1 deletion src/pyrecest/filters/association_hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ def filter(
),
)
for hypothesis in sorted_group[: self.k]:
accepted_keys.add((_track_index(hypothesis), _measurement_index(hypothesis)))
accepted_keys.add(
(_track_index(hypothesis), _measurement_index(hypothesis))
)

result = []
for hypothesis in hypotheses:
Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
and capability-oriented so filters can opt into the pieces they need.
"""

from ._validated_motion_models import nearly_coordinated_turn_model
from .adapters import (
LinearMeasurementArguments,
LinearTransitionArguments,
Expand Down Expand Up @@ -66,7 +67,6 @@
white_noise_jerk_covariance,
white_noise_snap_covariance,
)
from ._validated_motion_models import nearly_coordinated_turn_model
from .sensor_models import (
bearing_only_measurement,
bearing_only_model,
Expand Down
8 changes: 5 additions & 3 deletions src/pyrecest/models/_validated_motion_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ def nearly_coordinated_turn_model(
dt,
"dt",
)
turn_rate_variance = _motion_models._as_nonnegative_float( # pylint: disable=protected-access
turn_rate_variance,
"turn_rate_variance",
turn_rate_variance = (
_motion_models._as_nonnegative_float( # pylint: disable=protected-access
turn_rate_variance,
"turn_rate_variance",
)
)
return _nearly_coordinated_turn_model_impl(
dt=dt,
Expand Down
4 changes: 1 addition & 3 deletions tests/backend/test_numpy_random_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def test_choice_without_replacement_shuffle_false_preserves_order():
matrix = np.array([[10, 20, 30], [40, 50, 60]])

random.seed(0)
samples = random.choice(
values, size=values.shape[0], replace=False, shuffle=False
)
samples = random.choice(values, size=values.shape[0], replace=False, shuffle=False)
column_samples = random.choice(
matrix,
size=matrix.shape[1],
Expand Down
4 changes: 1 addition & 3 deletions tests/backend/test_pytorch_random_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def test_choice_without_replacement_shuffle_false_preserves_order():
matrix = torch.tensor([[10, 20, 30], [40, 50, 60]])

random.seed(0)
samples = random.choice(
values, size=values.shape[0], replace=False, shuffle=False
)
samples = random.choice(values, size=values.shape[0], replace=False, shuffle=False)
column_samples = random.choice(
matrix,
size=matrix.shape[1],
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class DistributionWithCallableDimAndMean:
mu = array([0.0, 1.0])

def dim(self):
raise AssertionError("dim() must not be called when methods are disabled")
raise AssertionError(
"dim() must not be called when methods are disabled"
)

self.assertEqual(
infer_state_dim_from_distribution(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_multisession_assignment_observation_costs_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import unittest

from pyrecest.backend import __backend_name__, array # pylint: disable=no-name-in-module
from pyrecest.backend import ( # pylint: disable=no-name-in-module
__backend_name__,
array,
)
from pyrecest.utils import solve_multisession_assignment_with_observation_costs


Expand Down
2 changes: 1 addition & 1 deletion tests/test_track_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import unittest

import numpy as np

from pyrecest.utils.track_completion import (
CompletionCandidate,
enumerate_fragment_completion_paths,
Expand Down Expand Up @@ -69,6 +68,7 @@ def test_rejects_negative_candidate_observations(self) -> None:

for invalid_candidate in invalid_candidates:
with self.subTest(invalid_candidate=invalid_candidate):

def provider(session: int, observation: int, target_session: int):
del session, observation, target_session
return [invalid_candidate]
Expand Down