diff --git a/distributed/core.py b/distributed/core.py index dc5f7733fe..3d792844cd 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -16,7 +16,6 @@ import weakref from collections import defaultdict, deque from collections.abc import ( - Awaitable, Callable, Container, Coroutine, @@ -560,33 +559,18 @@ def start_periodic_callbacks(self): if not pc.is_running(): pc.start() - def _stop_listeners(self) -> asyncio.Future: - listeners_to_stop: set[Awaitable] = set() - + def _stop_listeners(self) -> None: for listener in self.listeners: - future = listener.stop() - if inspect.isawaitable(future): - warnings.warn( - f"{type(listener)} is using an asynchronous `stop` method. " - "Support for asynchronous `Listener.stop` has been deprecated and " - "will be removed in a future version", - DeprecationWarning, - ) - listeners_to_stop.add(future) - elif hasattr(listener, "abort_handshaking_comms"): + listener.stop() + if hasattr(listener, "abort_handshaking_comms"): listener.abort_handshaking_comms() - return asyncio.gather(*listeners_to_stop) - def stop(self) -> None: if self.__stopped: return self.__stopped = True self.monitor.close() - if not (stop_listeners := self._stop_listeners()).done(): - self._ongoing_background_tasks.call_soon( - asyncio.wait_for(stop_listeners, timeout=None) # type: ignore[arg-type] - ) + self._stop_listeners() if self._workdir is not None: self._workdir.release() @@ -935,7 +919,7 @@ async def close(self, timeout: float | None = None, reason: str = "") -> None: self.__stopped = True self.monitor.close() - await self._stop_listeners() + self._stop_listeners() # TODO: Deal with exceptions await self._ongoing_background_tasks.stop() diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 78768f4bdd..73ee0f28ba 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -9,6 +9,7 @@ import sys import threading import weakref +from contextlib import asynccontextmanager from unittest import mock import pytest @@ -16,10 +17,11 @@ import dask +from distributed import core from distributed.batched import BatchedSend from distributed.comm.core import CommClosedError, FatalCommClosedError from distributed.comm.registry import backends -from distributed.comm.tcp import TCPBackend, TCPListener +from distributed.comm.tcp import TCPBackend, TCPConnector from distributed.core import ( ConnectionPool, RPCClosed, @@ -171,13 +173,11 @@ class MyServer(Server): default_port = 8756 -@pytest.mark.slow @gen_test() async def test_server_listen(): """ Test various Server.listen() arguments and their effect. """ - import socket try: EXTERNAL_IP4 = get_ip() @@ -186,8 +186,6 @@ async def test_server_listen(): except socket.gaierror: pytest.skip("no network access") - from contextlib import asynccontextmanager - @asynccontextmanager async def listen_on(cls, *args, **kwargs): server = cls({}) @@ -631,8 +629,6 @@ async def test_connection_pool_close_while_connecting(monkeypatch): Ensure a closed connection pool guarantees to have no connections left open even if it is closed mid-connecting """ - from distributed.comm.registry import backends - from distributed.comm.tcp import TCPBackend, TCPConnector class SlowConnector(TCPConnector): async def connect(self, address, deserialize, **connection_args): @@ -672,8 +668,6 @@ async def connect_to_server(): @gen_test() async def test_connection_pool_outside_cancellation(monkeypatch): # Ensure cancellation errors are properly reraised - from distributed.comm.registry import backends - from distributed.comm.tcp import TCPBackend, TCPConnector class SlowConnector(TCPConnector): async def connect(self, address, deserialize, **connection_args): @@ -707,11 +701,9 @@ async def connect_to_server(): assert all(t.cancelled() for t in tasks) +@pytest.mark.slow @gen_test() async def test_connection_pool_catch_all_cancellederrors(monkeypatch): - from distributed.comm.registry import backends - from distributed.comm.tcp import TCPBackend, TCPConnector - in_connect = asyncio.Event() block_connect = asyncio.Event() @@ -922,7 +914,6 @@ async def test_ticks(s, a, b): @gen_cluster(config={"distributed.admin.tick.interval": "20 ms"}) async def test_tick_logging(s, a, b): pytest.importorskip("crick") - from distributed import core old = core.tick_maximum_delay core.tick_maximum_delay = 0.001 @@ -1289,25 +1280,6 @@ def stream_not_leading_position(self, other, stream): ... assert not _expects_comm(instance.stream_not_leading_position) -class AsyncStopTCPListener(TCPListener): - async def stop(self): - await asyncio.sleep(0) - super().stop() - - -class TCPAsyncListenerBackend(TCPBackend): - _listener_class = AsyncStopTCPListener - - -@gen_test() -async def test_async_listener_stop(monkeypatch): - monkeypatch.setitem(backends, "tcp", TCPAsyncListenerBackend()) - with pytest.warns(DeprecationWarning): - async with Server({}) as s: - await s.listen(0) - assert s.listeners - - @gen_test() async def test_messages_are_ordered_bsend(): ledger = [] diff --git a/distributed/worker.py b/distributed/worker.py index 79f24b22da..0c55a83aac 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1620,7 +1620,7 @@ async def close( # type: ignore # otherwise c.close() - await self._stop_listeners() + self._stop_listeners() await self.rpc.close() # Give some time for a UCX scheduler to complete closing endpoints