diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py
index 106373411..07a797318 100644
--- a/rclpy/rclpy/executors.py
+++ b/rclpy/rclpy/executors.py
@@ -740,6 +740,10 @@ def _wait_for_ready_callbacks(
                 ready_tasks_count = len(self._ready_tasks)
                 for _ in range(ready_tasks_count):
                     task = self._ready_tasks.popleft()
+                    # Skip tasks that were cancelled or set done while awaiting a
+                    # future and got rescheduled when the future completed
+                    if task.cancelled() or task.done():
+                        continue
                     task_data = self._pending_tasks[task]
                     node = task_data.source_node
                     if node is None or node in nodes_to_use:
diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py
index 65fe2bbad..9792cb948 100644
--- a/rclpy/rclpy/task.py
+++ b/rclpy/rclpy/task.py
@@ -51,8 +51,8 @@ def __init__(self, *, executor: Optional['Executor'] = None) -> None:
         # An exception raised by the handler when called
         self._exception: Optional[Exception] = None
         self._exception_fetched = False
-        # callbacks to be scheduled after this task completes
-        self._callbacks: List[Callable[['Future[T]'], None]] = []
+        # callbacks or tasks to be scheduled after this task completes
+        self._callbacks: List[Union[Callable[['Future[T]'], None], 'Task[Any]']] = []
         # Lock for threadsafety
         self._lock = threading.Lock()
         # An executor to use when scheduling done callbacks
@@ -165,10 +165,18 @@ def _schedule_or_invoke_done_callbacks(self) -> None:
         if executor is not None:
             # Have the executor take care of the callbacks
             for callback in callbacks:
-                executor.create_task(callback, self)
+                if isinstance(callback, Task):
+                    executor._call_task_in_next_spin(callback)
+                else:
+                    executor.create_task(callback, self)
         else:
             # No executor, call right away
             for callback in callbacks:
+                if isinstance(callback, Task):
+                    warnings.warn(
+                        'Dropping task awaiting future: '
+                        'executor reference could not be resolved')
+                    continue
                 try:
                     callback(self)
                 except Exception as e:
@@ -210,6 +218,21 @@ def add_done_callback(self, callback: Callable[['Future[T]'], None]) -> None:
         if invoke:
             callback(self)
 
+    def _add_waiting_task(self, task: 'Task[Any]') -> None:
+        """Schedule a task to resume when this future completes."""
+        with self._lock:
+            if not self._pending():
+                assert self._executor is not None
+                executor = self._executor()
+                if executor is not None:
+                    executor._call_task_in_next_spin(task)
+                else:
+                    warnings.warn(
+                        'Dropping task awaiting future: '
+                        'executor reference could not be resolved')
+            else:
+                self._callbacks.append(task)
+
     def remove_done_callback(self, callback: Callable[['Future[T]'], None]) -> bool:
         """
         Remove a previously-added done callback.
@@ -352,9 +375,8 @@ def _add_resume_callback(self, future: Future[T], executor: 'Executor') -> None:
         elif future_executor is not executor:
             raise RuntimeError('A task can only await futures associated with the same executor')
 
-        # The future is associated with the same executor, so we can resume the task directly
-        # in the done callback
-        future.add_done_callback(lambda _: self.__call__())
+        # Register the task to resume when the future is done or cancelled
+        future._add_waiting_task(self)
 
     def _complete_task(self) -> None:
         """Cleanup after task finished."""
diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py
index 95dac0439..0fdeb0c92 100644
--- a/rclpy/test/test_executor.py
+++ b/rclpy/test/test_executor.py
@@ -460,6 +460,89 @@ async def coro2():
                 self.assertTrue(future1.done())
                 self.assertEqual('Sentinel Result 1', future1.result())
 
+    def test_coroutine_exception_after_await(self) -> None:
+        """Exception in a coroutine after awaiting a future must propagate."""
+        self.assertIsNotNone(self.node.handle)
+        # EventsExecutor excluded - segfaults on exception propagation (#1641)
+        for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
+            with self.subTest(cls=cls):
+                executor = cls(context=self.context)
+                executor.add_node(self.node)
+
+                first_fut = executor.create_future()
+                second_fut = executor.create_future()
+
+                async def coro_that_raises() -> None:
+                    first_fut.set_result(None)
+                    await second_fut
+                    raise RuntimeError('Expected error after await')
+
+                task = executor.create_task(coro_that_raises)
+
+                executor.spin_until_future_complete(first_fut, timeout_sec=5)
+                self.assertFalse(task.done())
+                # Resolve the inner future — triggers resume
+                second_fut.set_result(None)
+
+                with self.assertRaises(RuntimeError) as cm:
+                    executor.spin_until_future_complete(task, timeout_sec=5)
+                self.assertIn('Expected error after await', str(cm.exception))
+
+    def test_cancel_task_while_awaiting_future(self) -> None:
+        """Cancelling a task parked on a future must not crash the dispatch loop."""
+        self.assertIsNotNone(self.node.handle)
+        # EventsExecutor excluded - see #1641
+        for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
+            with self.subTest(cls=cls):
+                executor = cls(context=self.context)
+                executor.add_node(self.node)
+
+                first_fut = executor.create_future()
+                second_fut = executor.create_future()
+                third_fut = executor.create_future()
+
+                async def coro() -> None:
+                    first_fut.set_result(None)
+                    await second_fut
+                    third_fut.set_result(None)
+
+                task = executor.create_task(coro)
+
+                executor.spin_until_future_complete(first_fut, timeout_sec=5)
+                self.assertFalse(task.done())
+
+                task.cancel()
+                self.assertTrue(task.cancelled())
+
+                second_fut.set_result(None)
+
+                # first_fut is already done, so spin_until_future_complete() would
+                # return without spinning; spin once so the rescheduled task is
+                # actually popped by the dispatch loop.
+                executor.spin_once(timeout_sec=0)
+                self.assertFalse(third_fut.done())
+
+    def test_await_already_completed_future(self) -> None:
+        """Awaiting an already-completed future must resume and return its result."""
+        self.assertIsNotNone(self.node.handle)
+        # EventsExecutor excluded - see #1641
+        for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
+            with self.subTest(cls=cls):
+                executor = cls(context=self.context)
+                executor.add_node(self.node)
+
+                fut: Future[str] = executor.create_future()
+                fut.set_result('done')  # complete before the task runs
+
+                async def coro() -> str:
+                    return await fut  # type: ignore[return-value]
+
+                task = executor.create_task(coro)
+
+                executor.spin_until_future_complete(task, timeout_sec=5)
+                self.assertTrue(task.done())
+                self.assertEqual('done', task.result())
+
     def test_create_task_during_spin(self) -> None:
         self.assertIsNotNone(self.node.handle)
         for cls in [SingleThreadedExecutor, EventsExecutor]: