Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,34 +1065,43 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
raise

conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise

if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)

return conn
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
await conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
Comment on lines +1097 to +1104

@contextlib.asynccontextmanager
async def checkout(
Expand Down
91 changes: 54 additions & 37 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
sock_returned = False
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
Expand All @@ -223,14 +224,18 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout
)
sock.settimeout(timeout)
# Set immediately before return. Do not insert an await between this and the return
sock_returned = True
return sock
except asyncio.TimeoutError as e:
sock.close()
err = socket.timeout("timed out")
err.__cause__ = e
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
finally:
# Always close the socket if it wasn't returned to avoid leaks.
if not sock_returned:
sock.close()
Comment on lines +235 to +238

if err is not None:
raise err
Expand Down Expand Up @@ -307,48 +312,60 @@ async def _configured_protocol_interface(
Sets protocol's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
timeout = options.socket_timeout
sock_adopted = False
try:
ssl_context = options._ssl_context
timeout = options.socket_timeout

if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
if ssl_context is None:
result = await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
sock_adopted = True
return AsyncNetworkingInterface(result)

host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
except _CertificateError:
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
host = address[0]
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
sock_adopted = True
except _CertificateError:
transport.abort()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise

return AsyncNetworkingInterface((transport, protocol))
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(
address, exc, "SSL handshake failed: ", timeout_details=details
)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
finally:
# If cancellation or any exception lands between socket creation and
# transport adoption, asyncio.create_connection has not registered
# cleanup for the sock.
# Close it ourselves to prevent leaks.
if not sock_adopted:
sock.close()


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
57 changes: 33 additions & 24 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,34 +1061,43 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
raise

conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise

if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)
if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)

return conn
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise

@contextlib.contextmanager
def checkout(
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ filterwarnings = [
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/1032
"module:.*WindowsSelectorEventLoopPolicy:DeprecationWarning",
"module:.*et_event_loop_policy:DeprecationWarning",
# TODO: Remove as part of PYTHON-3923.
"module:unclosed <socket.socket:ResourceWarning",
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
"module:unclosed event loop:ResourceWarning",
# https://github.com/dateutil/dateutil/issues/1314
Expand Down
2 changes: 2 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,8 @@ def setup():

def teardown():
global_knobs.disable()
if client_context.client is not None:
client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")
Expand Down
2 changes: 2 additions & 0 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,8 @@ async def async_setup():

async def async_teardown():
global_knobs.disable()
if async_client_context.client is not None:
await async_client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")
Expand Down
Loading