Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rclpy/rclpy/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,10 @@ def _wait_for_ready_callbacks(
ready_tasks_count = len(self._ready_tasks)
for _ in range(ready_tasks_count):
task = self._ready_tasks.popleft()
# Skip tasks that were cancelled or set done while awaiting a
# future and got rescheduled when the future completed
if task.cancelled() or task.done():
continue
task_data = self._pending_tasks[task]
node = task_data.source_node
if node is None or node in nodes_to_use:
Expand Down
34 changes: 28 additions & 6 deletions rclpy/rclpy/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(self, *, executor: Optional['Executor'] = None) -> None:
# An exception raised by the handler when called
self._exception: Optional[Exception] = None
self._exception_fetched = False
# callbacks to be scheduled after this task completes
self._callbacks: List[Callable[['Future[T]'], None]] = []
# callbacks or tasks to be scheduled after this task completes
self._callbacks: List[Union[Callable[['Future[T]'], None], 'Task[Any]']] = []
# Lock for threadsafety
self._lock = threading.Lock()
# An executor to use when scheduling done callbacks
Expand Down Expand Up @@ -165,10 +165,18 @@ def _schedule_or_invoke_done_callbacks(self) -> None:
if executor is not None:
# Have the executor take care of the callbacks
for callback in callbacks:
executor.create_task(callback, self)
if isinstance(callback, Task):
executor._call_task_in_next_spin(callback)
else:
executor.create_task(callback, self)
else:
# No executor, call right away
for callback in callbacks:
if isinstance(callback, Task):
warnings.warn(
'Dropping task awaiting future: '
'executor reference could not be resolved')
continue
try:
callback(self)
except Exception as e:
Expand Down Expand Up @@ -210,6 +218,21 @@ def add_done_callback(self, callback: Callable[['Future[T]'], None]) -> None:
if invoke:
callback(self)

def _add_waiting_task(self, task: 'Task[Any]') -> None:
"""Schedule a task to resume when this future completes."""
with self._lock:
if not self._pending():
assert self._executor is not None
executor = self._executor()
if executor is not None:
executor._call_task_in_next_spin(task)
else:
warnings.warn(
'Dropping task awaiting future: '
'executor reference could not be resolved')
else:
self._callbacks.append(task)

def remove_done_callback(self, callback: Callable[['Future[T]'], None]) -> bool:
"""
Remove a previously-added done callback.
Expand Down Expand Up @@ -352,9 +375,8 @@ def _add_resume_callback(self, future: Future[T], executor: 'Executor') -> None:
elif future_executor is not executor:
raise RuntimeError('A task can only await futures associated with the same executor')

# The future is associated with the same executor, so we can resume the task directly
# in the done callback
future.add_done_callback(lambda _: self.__call__())
# Register the task to resume when the future is done or cancelled
future._add_waiting_task(self)

def _complete_task(self) -> None:
"""Cleanup after task finished."""
Expand Down
80 changes: 80 additions & 0 deletions rclpy/test/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,86 @@ async def coro2():
self.assertTrue(future1.done())
self.assertEqual('Sentinel Result 1', future1.result())

def test_coroutine_exception_after_await(self) -> None:
"""Exception in a coroutine after awaiting a future must propagate."""
self.assertIsNotNone(self.node.handle)
# EventsExecutor excluded - segfaults on exception propagation (#1641)
for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
with self.subTest(cls=cls):
executor = cls(context=self.context)
executor.add_node(self.node)

first_fut = executor.create_future()
second_fut = executor.create_future()

async def coro_that_raises() -> None:
first_fut.set_result(None)
await second_fut
raise RuntimeError('Expected error after await')

task = executor.create_task(coro_that_raises)

executor.spin_until_future_complete(first_fut, timeout_sec=5)
self.assertFalse(task.done())
# Resolve the inner future — triggers resume
second_fut.set_result(None)

with self.assertRaises(RuntimeError) as cm:
executor.spin_until_future_complete(task, timeout_sec=5)
self.assertIn('Expected error after await', str(cm.exception))

def test_cancel_task_while_awaiting_future(self) -> None:
"""Cancelling a task parked on a future must not crash the dispatch loop."""
self.assertIsNotNone(self.node.handle)
# EventsExecutor excluded - see #1641
for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
with self.subTest(cls=cls):
executor = cls(context=self.context)
executor.add_node(self.node)

first_fut = executor.create_future()
second_fut = executor.create_future()
third_fut = executor.create_future()

async def coro() -> None:
first_fut.set_result(None)
await second_fut
third_fut.set_result(None)

task = executor.create_task(coro)

executor.spin_until_future_complete(first_fut, timeout_sec=5)
self.assertFalse(task.done())

task.cancel()
self.assertTrue(task.cancelled())

second_fut.set_result(None)

executor.spin_until_future_complete(first_fut, timeout_sec=5)
self.assertFalse(third_fut.done())

def test_await_already_completed_future(self) -> None:
"""Awaiting an already-completed future must resume and return its result."""
self.assertIsNotNone(self.node.handle)
# EventsExecutor excluded - see #1641
for cls in [SingleThreadedExecutor, MultiThreadedExecutor]:
with self.subTest(cls=cls):
executor = cls(context=self.context)
executor.add_node(self.node)

fut: Future[str] = executor.create_future()
fut.set_result('done') # complete before the task runs

async def coro() -> str:
return await fut # type: ignore[return-value]

task = executor.create_task(coro)

executor.spin_until_future_complete(task, timeout_sec=5)
self.assertTrue(task.done())
self.assertEqual('done', task.result())

def test_create_task_during_spin(self) -> None:
self.assertIsNotNone(self.node.handle)
for cls in [SingleThreadedExecutor, EventsExecutor]:
Expand Down