Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions rclpy/rclpy/action/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import threading
import time
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Dict
from typing import Generic
from typing import Optional
Expand All @@ -25,6 +28,7 @@
from typing import TYPE_CHECKING
from typing import TypedDict
from typing import TypeVar
from typing import Union

import uuid
import weakref
Expand Down Expand Up @@ -69,18 +73,22 @@ class ClientGoalHandleDict(TypedDict,
result: Tuple[int, GetResultServiceResponse[ClientGoalHandleDictResultT]]
feedback: FeedbackMessage[ClientGoalHandleDictFeedbackT]
status: GoalStatusArray

FeedbackCallbackUnion: TypeAlias = Union[
Callable[[FeedbackMessage[FeedbackT]], None],
Callable[[FeedbackMessage[FeedbackT]], Coroutine[Any, Any, None]],
]

class SendGoalKWargs(TypedDict, Generic[FeedbackT]):
feedback_callback: Optional[FeedbackCallbackUnion[FeedbackT]]
goal_uuid: Optional[UUID]
else:
ClientGoalHandleDict: 'TypeAlias' = Dict[str, object]


T = TypeVar('T')


class SendGoalKWargs(TypedDict):
feedback_callback: Optional[Callable[[FeedbackT], None]]
goal_uuid: Optional[UUID]


class ClientGoalHandle(Generic[GoalT, ResultT, FeedbackT]):
"""Goal handle for working with Action Clients."""

Expand Down Expand Up @@ -237,7 +245,7 @@ def __init__(
# key: result request sequence_number, value: UUID
self._result_sequence_number_to_goal_id: Dict[int, UUID] = {}
# key: UUID in bytes, value: callback function
self._feedback_callbacks: Dict[bytes, Callable[[FeedbackT], None]] = {}
self._feedback_callbacks: Dict[bytes, FeedbackCallbackUnion[FeedbackT]] = {}

self._logger = self._node.get_logger().get_child('action_client')
self._lock = threading.Lock()
Expand Down Expand Up @@ -444,7 +452,7 @@ def __exit__(self, t: Optional[Type[BaseException]], v: Optional[BaseException],

# End Waitable API

def send_goal(self, goal: GoalT, **kwargs: 'Unpack[SendGoalKWargs]'
def send_goal(self, goal: GoalT, **kwargs: 'Unpack[SendGoalKWargs[FeedbackT]]'
) -> Optional[GetResultServiceResponse[ResultT]]:
"""
Send a goal and wait for the result.
Expand Down Expand Up @@ -494,7 +502,7 @@ def unblock(future: Future[Any]) -> None:
def send_goal_async(
self,
goal: GoalT,
feedback_callback: Optional[Callable[[FeedbackT], None]] = None,
feedback_callback: Optional[FeedbackCallbackUnion[FeedbackT]] = None,
goal_uuid: Optional[UUID] = None
) -> Future[ClientGoalHandle[GoalT, ResultT, FeedbackT]]:
"""
Expand Down
68 changes: 44 additions & 24 deletions rclpy/rclpy/action/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from enum import Enum
import functools
import threading
import traceback

from types import TracebackType
from typing import (Any, Callable, Dict, Generic, Literal, Optional, Tuple, Type,
TYPE_CHECKING, TypedDict, TypeVar)
from typing import (Any, Callable, Coroutine, Dict, Generic, Literal, Optional, Tuple, Type,
TYPE_CHECKING, TypedDict, TypeVar, Union)


from action_msgs.msg import GoalInfo, GoalStatus
Expand All @@ -35,8 +37,8 @@
from rclpy.task import Future
from rclpy.task import Task
from rclpy.type_support import (Action, check_for_type_support, FeedbackMessage, FeedbackT,
GetResultServiceRequest, GetResultServiceResponse, GoalT, ResultT,
SendGoalServiceRequest)
GetResultServiceRequest, GetResultServiceResponse, GoalT, Msg,
ResultT, SendGoalServiceRequest)
from rclpy.waitable import NumberOfEntities, Waitable
from typing_extensions import TypeAlias
from unique_identifier_msgs.msg import UUID
Expand All @@ -55,6 +57,26 @@ class ServerGoalHandleDict(TypedDict,
cancel: Tuple['_rclpy.rmw_request_id_t', CancelGoal.Request]
result: Tuple['_rclpy.rmw_request_id_t', GetResultServiceRequest]
expired: Tuple[GoalInfo, ...]

ExecuteCallbackUnion: TypeAlias = Union[
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'], ResultT],
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'],
Coroutine[Any, Any, ResultT]],
]
GoalCallbackUnion: TypeAlias = Union[
Callable[[GoalT], 'GoalResponse'],
Callable[[GoalT], Coroutine[Any, Any, 'GoalResponse']],
]
HandleAcceptedCallbackUnion: TypeAlias = Union[
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'], None],
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'],
Coroutine[Any, Any, None]],
]
CancelCallbackUnion: TypeAlias = Union[
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'], 'CancelResponse'],
Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'],
Coroutine[Any, Any, 'CancelResponse']],
]
else:
ServerGoalHandleDict: TypeAlias = Dict[str, object]

Expand Down Expand Up @@ -173,8 +195,7 @@ def _set_result(self, response: Optional[ResultT]) -> None:

def execute(
self,
execute_callback: Optional[Callable[['ServerGoalHandle[GoalT, ResultT, FeedbackT]'],
ResultT]] = None
execute_callback: Optional[ExecuteCallbackUnion[GoalT, ResultT, FeedbackT]] = None
) -> None:
# It's possible that there has been a request to cancel the goal prior to executing.
# In this case we want to avoid the illegal state transition to EXECUTING
Expand Down Expand Up @@ -233,13 +254,15 @@ def default_handle_accepted_callback(goal_handle: ServerGoalHandle[Any, Any, Any


def default_goal_callback(
goal_request: SendGoalServiceRequest[Any]
goal_request: Msg
) -> Literal[GoalResponse.ACCEPT]:
"""Accept all goals."""
return GoalResponse.ACCEPT


def default_cancel_callback(cancel_request: CancelGoal.Request) -> Literal[CancelResponse.REJECT]:
def default_cancel_callback(
goal_handle: ServerGoalHandle[Any, Any, Any]
) -> Literal[CancelResponse.REJECT]:
"""No cancellations."""
return CancelResponse.REJECT

Expand All @@ -252,16 +275,14 @@ def __init__(
node: 'Node',
action_type: Type[Action],
action_name: str,
execute_callback: Optional[Callable[[ServerGoalHandle[GoalT, ResultT, FeedbackT]],
ResultT]] = None,
execute_callback: Optional[ExecuteCallbackUnion[GoalT, ResultT, FeedbackT]] = None,
*,
callback_group: 'Optional[CallbackGroup]' = None,
goal_callback: Callable[[CancelGoal.Request], GoalResponse] = default_goal_callback,
handle_accepted_callback: Callable[[ServerGoalHandle[GoalT,
ResultT,
FeedbackT]],
None] = default_handle_accepted_callback,
cancel_callback: Callable[[CancelGoal.Request], CancelResponse] = default_cancel_callback,
goal_callback: GoalCallbackUnion[GoalT] = default_goal_callback,
handle_accepted_callback: HandleAcceptedCallbackUnion[
GoalT, ResultT, FeedbackT] = default_handle_accepted_callback,
cancel_callback: CancelCallbackUnion[
GoalT, ResultT, FeedbackT] = default_cancel_callback,
goal_service_qos_profile: QoSProfile = qos_profile_services_default,
result_service_qos_profile: QoSProfile = qos_profile_services_default,
cancel_service_qos_profile: QoSProfile = qos_profile_services_default,
Expand Down Expand Up @@ -409,7 +430,7 @@ async def _execute_goal_request(

async def _execute_goal(
self,
execute_callback: Callable[[ServerGoalHandle[GoalT, ResultT, FeedbackT]], ResultT],
execute_callback: ExecuteCallbackUnion[GoalT, ResultT, FeedbackT],
goal_handle: ServerGoalHandle[GoalT, ResultT, FeedbackT]
) -> None:
goal_uuid = goal_handle.goal_id.uuid
Expand Down Expand Up @@ -632,8 +653,7 @@ def __exit__(self, t: Optional[Type[BaseException]],
def notify_execute(
self,
goal_handle: ServerGoalHandle[GoalT, ResultT, FeedbackT],
execute_callback: Optional[Callable[[ServerGoalHandle[GoalT, ResultT, FeedbackT]],
ResultT]]
execute_callback: Optional[ExecuteCallbackUnion[GoalT, ResultT, FeedbackT]] = None
) -> None:
# Use provided callback, defaulting to a previously registered callback
if execute_callback is None:
Expand All @@ -652,8 +672,8 @@ def notify_goal_done(self) -> None:

def register_handle_accepted_callback(
self,
handle_accepted_callback: Optional[Callable[[
ServerGoalHandle[GoalT, ResultT, FeedbackT]], None]]
handle_accepted_callback: Optional[
HandleAcceptedCallbackUnion[GoalT, ResultT, FeedbackT]] = None
) -> None:
"""
Register a callback for handling newly accepted goals.
Expand All @@ -677,7 +697,7 @@ def register_handle_accepted_callback(

def register_goal_callback(
self,
goal_callback: Optional[Callable[[SendGoalServiceRequest[GoalT]], GoalResponse]]
goal_callback: Optional[GoalCallbackUnion[GoalT]] = None
) -> None:
"""
Register a callback for handling new goal requests.
Expand All @@ -699,7 +719,7 @@ def register_goal_callback(

def register_cancel_callback(
self,
cancel_callback: Optional[Callable[[CancelGoal.Request], CancelResponse]]
cancel_callback: Optional[CancelCallbackUnion[GoalT, ResultT, FeedbackT]] = None
) -> None:
"""
Register a callback for handling cancel requests.
Expand All @@ -721,7 +741,7 @@ def register_cancel_callback(

def register_execute_callback(
self,
execute_callback: Callable[[ServerGoalHandle[GoalT, ResultT, FeedbackT]], ResultT]
execute_callback: ExecuteCallbackUnion[GoalT, ResultT, FeedbackT]
) -> None:
"""
Register a callback for executing action goals.
Expand Down
18 changes: 15 additions & 3 deletions rclpy/test/test_action_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import time
from typing import TYPE_CHECKING
import unittest
import uuid

Expand All @@ -29,6 +32,8 @@

from unique_identifier_msgs.msg import UUID

if TYPE_CHECKING:
from rclpy.type_support import FeedbackMessage

# TODO(jacobperron) Reduce fudge once wait_for_service uses node graph events
TIME_FUDGE = 0.3
Expand Down Expand Up @@ -72,6 +77,13 @@ def publish_feedback(self, goal_id):

class TestActionClient(unittest.TestCase):

if TYPE_CHECKING:
context: rclpy.context.Context
executor: SingleThreadedExecutor
node: rclpy.node.Node
mock_action_server: MockActionServer
feedback: FeedbackMessage[Fibonacci.Feedback] | None

@classmethod
def setUpClass(cls):
cls.context = rclpy.context.Context()
Expand All @@ -88,7 +100,7 @@ def tearDownClass(cls):
def setUp(self) -> None:
self.feedback = None

def feedback_callback(self, feedback):
def feedback_callback(self, feedback: FeedbackMessage[Fibonacci.Feedback]) -> None:
self.feedback = feedback

def timed_spin(self, duration):
Expand Down Expand Up @@ -365,9 +377,9 @@ def test_different_type_raises(self) -> None:
ac = ActionClient(self.node, Fibonacci, 'fibonacci')
try:
with self.assertRaises(TypeError):
ac.send_goal('different goal type')
ac.send_goal('different goal type') # type: ignore[call-arg,arg-type]
with self.assertRaises(TypeError):
ac.send_goal_async('different goal type')
ac.send_goal_async('different goal type') # type: ignore[arg-type]
finally:
ac.destroy()

Expand Down
7 changes: 4 additions & 3 deletions rclpy/test/test_action_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import time
from typing import Any
import unittest
import uuid

Expand Down Expand Up @@ -291,7 +292,7 @@ def execute_callback(goal_handle):
goal_handle.canceled()
return Fibonacci.Result()

def cancel_callback(request):
def cancel_callback(goal_handle: ServerGoalHandle[Any, Any, Any]) -> CancelResponse:
return CancelResponse.ACCEPT

executor = MultiThreadedExecutor(context=self.context)
Expand Down Expand Up @@ -337,7 +338,7 @@ def execute_callback(goal_handle):
goal_handle.canceled()
return Fibonacci.Result()

def cancel_callback(request):
def cancel_callback(goal_handle: ServerGoalHandle[Any, Any, Any]) -> CancelResponse:
return CancelResponse.REJECT

executor = MultiThreadedExecutor(context=self.context)
Expand Down Expand Up @@ -380,7 +381,7 @@ def handle_accepted_callback(gh):
nonlocal server_goal_handle
server_goal_handle = gh

def cancel_callback(request):
def cancel_callback(goal_handle: ServerGoalHandle[Any, Any, Any]) -> CancelResponse:
return CancelResponse.ACCEPT

def execute_callback(gh):
Expand Down