diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index 39f50d01..3ccc9786 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -55,6 +55,8 @@ def __init__( context: zmq.Context | None = None, session: Session | None = None, address: t.Union[t.Tuple[str, int], str] = "", + *, + curve_serverkey: bytes | None = None, ) -> None: """Create the heartbeat monitor thread. @@ -66,12 +68,17 @@ def __init__( The session to use. address : zmq url Standard (ip, port) tuple that the kernel is listening on. + curve_serverkey : bytes, optional + CurveZMQ server public key (Z85). When provided, the + heartbeat REQ socket is configured as a CurveZMQ client so it + can communicate with a CurveZMQ-enabled kernel. """ super().__init__() self.daemon = True self.context = context self.session = session + self.curve_serverkey = curve_serverkey if isinstance(address, tuple): if address[1] == 0: message = "The port number for a channel cannot be 0." @@ -104,6 +111,13 @@ def _create_socket(self) -> None: assert self.context is not None self.socket = self.context.socket(zmq.REQ) self.socket.linger = 1000 + if self.curve_serverkey is not None: + # Generate a fresh ephemeral keypair for each socket; only the + # server public key (curve_serverkey) is needed for authentication. + client_pub, client_sec = zmq.curve_keypair() + self.socket.curve_secretkey = client_sec + self.socket.curve_publickey = client_pub + self.socket.curve_serverkey = self.curve_serverkey assert self.address is not None self.socket.connect(self.address) diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 1a0c8b01..7feb38c4 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -392,9 +392,26 @@ def hb_channel(self) -> t.Any: if self._hb_channel is None: url = self._make_url("hb") self.log.debug("connecting heartbeat channel to %s", url) - self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract] - self.context, self.session, url - ) + hb_kwargs = {} + if self.curve_publickey: + hb_kwargs["curve_serverkey"] = self.curve_publickey + try: + self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract] + self.context, + self.session, + url, + **hb_kwargs, + ) + except TypeError as e: + if "curve_serverkey" in str(e): + msg = ( + f"{self.hb_channel_class.__name__} does not support the " + "'curve_serverkey' parameter. Upgrade the heartbeat channel " + "class or disable CurveZMQ encryption." + ) + raise RuntimeError(msg) from e + else: + raise return self._hb_channel @property diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index f4667545..5a2b3cca 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -17,12 +17,13 @@ import tempfile import warnings from getpass import getpass -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast import zmq from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write -from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe +from traitlets import Bool, Bytes, CaselessStrEnum, Instance, Integer, Type, Unicode, observe from traitlets.config import LoggingConfigurable, SingletonConfigurable +from typing_extensions import TypedDict from .localinterfaces import localhost from .utils import _filefind @@ -33,7 +34,22 @@ from .session import Session # Define custom type for kernel connection info -KernelConnectionInfo = dict[str, Union[int, str, bytes]] + + +class KernelConnectionInfo(TypedDict, extra_items=str | bytes | int, total=False): # type: ignore[call-arg] + shell_port: int + iopub_port: int + stdin_port: int + control_port: int + hb_port: int + ip: str + key: str + transport: str + signature_scheme: str + kernel_name: str + session: Session + curve_publickey: str + curve_secretkey: str def write_connection_file( @@ -48,6 +64,8 @@ def write_connection_file( transport: str = "tcp", signature_scheme: str = "hmac-sha256", kernel_name: str = "", + curve_publickey: bytes | None = None, + curve_secretkey: bytes | None = None, **kwargs: Any, ) -> tuple[str, KernelConnectionInfo]: """Generates a JSON config file, including the selection of random ports. @@ -76,7 +94,7 @@ def write_connection_file( ip : str, optional The ip address the kernel will bind to. - key : str, optional + key : bytes, optional The Session key used for message authentication. signature_scheme : str, optional @@ -89,6 +107,12 @@ def write_connection_file( kernel_name : str, optional The name of the kernel currently connected to. + + curve_publickey : bytes, optional + CurveZMQ public key (Z85). + + curve_secretkey : bytes, optional + CurveZMQ secret key (Z85). """ if not ip: ip = localhost() @@ -149,7 +173,11 @@ def write_connection_file( cfg["transport"] = transport cfg["signature_scheme"] = signature_scheme cfg["kernel_name"] = kernel_name - cfg.update(kwargs) + if curve_publickey is not None: + cfg["curve_publickey"] = curve_publickey.decode("ascii") + if curve_secretkey is not None: + cfg["curve_secretkey"] = curve_secretkey.decode("ascii") + cfg.update(kwargs) # type: ignore[typeddict-item] # Only ever write this file as user read/writeable # This would otherwise introduce a vulnerability as a file has secrets @@ -371,6 +399,11 @@ def _ip_changed(self, change: Any) -> None: stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]") control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]") + # Optional CurveZMQ keys loaded from the connection file (Z85-encoded bytes). + # None when the kernel was not started with CurveZMQ enabled. + curve_publickey: Bytes | None = Bytes(allow_none=True, default_value=None) + curve_secretkey: Bytes | None = Bytes(allow_none=True, default_value=None) + # names of the ports with random assignment _random_port_names: list[str] | None = None @@ -405,7 +438,7 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo: connect_info : dict dictionary of connection information. """ - info = { + info: KernelConnectionInfo = { "transport": self.transport, "ip": self.ip, "shell_port": self.shell_port, @@ -426,6 +459,9 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo: "key": self.session.key, } ) + if self.curve_publickey is not None and self.curve_secretkey is not None: + info["curve_publickey"] = self.curve_publickey.decode() + info["curve_secretkey"] = self.curve_secretkey.decode() return info # factory for blocking clients @@ -515,7 +551,7 @@ def write_connection_file(self, **kwargs: Any) -> None: # write_connection_file also sets default ports: self._record_random_port_names() for name in port_names: - setattr(self, name, cfg[name]) + setattr(self, name, cast(int, cfg.get(name))) self._connection_file_written = True @@ -548,23 +584,25 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None: See the connection_file spec for details. """ self.transport = info.get("transport", self.transport) - self.ip = info.get("ip", self._ip_default()) # type:ignore[assignment] + self.ip = info.get("ip", self._ip_default()) self._record_random_port_names() for name in port_names: if getattr(self, name) == 0 and name in info: # not overridden by config or cl_args - setattr(self, name, info[name]) + setattr(self, name, cast(int, info.get(name))) if "key" in info: key = info["key"] - if isinstance(key, str): - key = key.encode() - assert isinstance(key, bytes) - - self.session.key = key + key_bytes = key if isinstance(key, bytes) else key.encode() # type: ignore[redundant-expr,unreachable] + self.session.key = key_bytes if "signature_scheme" in info: self.session.signature_scheme = info["signature_scheme"] + if "curve_publickey" in info and "curve_secretkey" in info: + pub = info["curve_publickey"] + sec = info["curve_secretkey"] + self.curve_publickey = pub.encode() if isinstance(pub, str) else pub # type: ignore[redundant-expr] + self.curve_secretkey = sec.encode() if isinstance(sec, str) else sec # type: ignore[redundant-expr] def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None: """Reconciles the connection information returned from the Provisioner. @@ -618,6 +656,8 @@ def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) pertinent_keys = [ "key", + "curve_publickey", + "curve_secretkey", "ip", "stdin_port", "iopub_port", @@ -657,6 +697,14 @@ def _create_connected_socket( sock.linger = 1000 if identity: sock.identity = identity + if self.curve_publickey is not None: + # The connection file already carries this keypair, so reusing it + # avoids introducing an additional key-distribution mechanism here. + # curve_serverkey authenticates the server; the keypair configures + # encrypted communication for the client socket. + sock.curve_secretkey = self.curve_secretkey + sock.curve_publickey = self.curve_publickey + sock.curve_serverkey = self.curve_publickey sock.connect(url) return sock diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 59fa817f..24e604e9 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -140,6 +140,16 @@ def _context_default(self) -> zmq.Context: ) client_factory: Type = Type(klass=KernelClient, config=True) + transport_encryption: Bool = Bool( + False, + config=True, + help=( + "Enable transport encryption using manager-side provisioning of CurveZMQ server keys for kernels. " + "When True, the provisioner launch path issues and writes Curve credentials " + "before the kernel process starts." + ), + ) + @default("client_factory") def _client_factory_default(self) -> Type: return import_item(self.client_class) @@ -379,7 +389,7 @@ def _close_control_socket(self) -> None: self._control_socket = None async def _async_pre_start_kernel( - self, **kw: t.Any + self, *, transport_encryption: bool | None = None, **kw: t.Any ) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]: """Prepares a kernel for startup in a separate process. @@ -393,6 +403,8 @@ async def _async_pre_start_kernel( and launching the kernel (e.g. Popen kwargs). """ self.shutting_down = False + if transport_encryption is not None: + self.transport_encryption = transport_encryption self.kernel_id = self.kernel_id or kw.pop("kernel_id", str(uuid.uuid4())) # save kwargs for use in restart # assigning Traitlets Dicts to Dict make mypy unhappy but is ok diff --git a/jupyter_client/provisioning/local_provisioner.py b/jupyter_client/provisioning/local_provisioner.py index 20e4802f..40847780 100644 --- a/jupyter_client/provisioning/local_provisioner.py +++ b/jupyter_client/provisioning/local_provisioner.py @@ -9,6 +9,8 @@ import sys from typing import TYPE_CHECKING, Any +import zmq + from ..connect import KernelConnectionInfo, LocalPortCache from ..launcher import launch_kernel from ..localinterfaces import is_local_ip, local_ips @@ -174,6 +176,11 @@ async def pre_launch(self, **kwargs: Any) -> dict[str, Any]: # This should be considered temporary until a better division of labor can be defined. km = self.parent if km: + transport_encryption = bool( + kwargs.pop("transport_encryption", getattr(km, "transport_encryption", False)) + ) + curve_publickey: bytes | None = None + curve_secretkey: bytes | None = None if km.transport == "tcp" and not is_local_ip(km.ip): msg = ( "Can only launch a kernel on a local interface. " @@ -196,11 +203,23 @@ async def pre_launch(self, **kwargs: Any) -> dict[str, Any]: km.hb_port = lpc.find_available_port(km.ip) km.control_port = lpc.find_available_port(km.ip) self.ports_cached = True + + if transport_encryption: + curve_publickey, curve_secretkey = zmq.curve_keypair() + km.curve_publickey = curve_publickey + km.curve_secretkey = curve_secretkey if "env" in kwargs: jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "") - km.write_connection_file(jupyter_session=jupyter_session) + km.write_connection_file( + jupyter_session=jupyter_session, + curve_publickey=curve_publickey, + curve_secretkey=curve_secretkey, + ) else: - km.write_connection_file() + km.write_connection_file( + curve_publickey=curve_publickey, + curve_secretkey=curve_secretkey, + ) self.connection_info = km.get_connection_info() kernel_cmd = km.format_kernel_cmd( diff --git a/pyproject.toml b/pyproject.toml index 9eb1312d..488e5d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "python-dateutil>=2.8.2", "pyzmq>=25.0", "tornado>=6.4.1", + "typing-extensions>=4.13.0", "traitlets>=5.3", ] diff --git a/tests/test_connect.py b/tests/test_connect.py index 148973eb..bee93692 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -84,6 +84,21 @@ def test_write_connection_file(): assert info == sample_info +def test_write_connection_file_normalizes_curve_key_kwargs_to_strings(): + with TemporaryDirectory() as d: + cf = os.path.join(d, "kernel.json") + _fname, cfg = connect.write_connection_file( + cf, + **sample_info, + curve_publickey=b"A" * 40, + curve_secretkey=b"B" * 40, + ) + + assert isinstance(cfg["key"], str) + assert isinstance(cfg["curve_publickey"], str) + assert isinstance(cfg["curve_secretkey"], str) + + def test_load_connection_file_session(): """test load_connection_file() after""" session = Session() @@ -291,3 +306,23 @@ def test_reconcile_connection_info(file_exists, km_matches): km._reconcile_connection_info(provisioner_info) km_info = km.get_connection_info() assert km._equal_connections(km_info, provisioner_info) + + +def test_reconcile_connection_info_with_curve_keys(): + with TemporaryDirectory() as connection_dir: + cf = os.path.join(connection_dir, "kernel.json") + km = KernelManager() + km.connection_file = cf + + _, provisioner_info = connect.write_connection_file( + cf, + **sample_info, + curve_publickey=b"A" * 40, + curve_secretkey=b"B" * 40, + ) + provisioner_info["key"] = provisioner_info["key"].encode() # type:ignore + + km.load_connection_info(provisioner_info) + km._reconcile_connection_info(provisioner_info) + km_info = km.get_connection_info() + assert km._equal_connections(km_info, provisioner_info) diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index 87cd1036..619f9e53 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -101,9 +101,22 @@ async def pre_launch(self, **kwargs: Any) -> dict[str, Any]: km._launch_args = kwargs.copy() # build the Popen cmd extra_arguments = kwargs.pop("extra_arguments", []) - + transport_encryption = bool( + kwargs.pop("transport_encryption", getattr(km, "transport_encryption", False)) + ) + curve_publickey: bytes | None = None + curve_secretkey: bytes | None = None + if transport_encryption: + import zmq + + curve_publickey, curve_secretkey = zmq.curve_keypair() + km.curve_publickey = curve_publickey + km.curve_secretkey = curve_secretkey # write connection file / get default ports - km.write_connection_file() + km.write_connection_file( + curve_publickey=curve_publickey, + curve_secretkey=curve_secretkey, + ) self.connection_info = km.get_connection_info() kernel_cmd = km.format_kernel_cmd( @@ -264,6 +277,7 @@ async def akm_test(self, kernel_mgr): """Starts a kernel, validates the associated provisioner's config, shuts down kernel""" assert kernel_mgr.provisioner is None + if kernel_mgr.kernel_name == "missing_provisioner": with pytest.raises(NoSuchKernel): await kernel_mgr.start_kernel() @@ -276,6 +290,26 @@ async def akm_test(self, kernel_mgr): assert kernel_mgr.provisioner is not None assert kernel_mgr.provisioner.has_process is False + async def test_local_provisioner_pre_launch_generates_curve_keys_under_transport_encryption( + self, monkeypatch, tmp_path + ): + """When transport encryption is enabled, LocalProvisioner seeds curve keys before launch.""" + km = AsyncKernelManager(connection_file=str(tmp_path / "kernel.json")) + km.transport_encryption = True + await km._async_pre_start_kernel() + assert km.provisioner is not None + assert isinstance(km.provisioner, LocalProvisioner) + + monkeypatch.setattr( + "jupyter_client.provisioning.local_provisioner.zmq.curve_keypair", + lambda: (b"A" * 40, b"B" * 40), + ) + + await km.provisioner.pre_launch() + + assert km.provisioner.connection_info["curve_publickey"] == "A" * 40 + assert km.provisioner.connection_info["curve_secretkey"] == "B" * 40 + async def test_existing(self, kpf, akm): await self.akm_test(akm) diff --git a/tests/test_transport_security.py b/tests/test_transport_security.py new file mode 100644 index 00000000..e632060c --- /dev/null +++ b/tests/test_transport_security.py @@ -0,0 +1,295 @@ +"""Tests for the ZMQ transport security.""" + +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import pytest +import zmq + +from jupyter_client import KernelManager +from jupyter_client.channels import HBChannel +from jupyter_client.client import KernelClient +from jupyter_client.connect import ConnectionFileMixin +from jupyter_client.session import Session + + +@pytest.mark.parametrize( + "transport_encryption", + [ + False, + True, + ], +) +def test_iopub_plaintext_visibility_depends_on_curve(transport_encryption, tmp_path): + """An unauthenticated subscriber sees plaintext only when Curve is disabled.""" + + km = KernelManager(connection_file=str(tmp_path / "kernel.json")) + km.cache_ports = False + km.transport_encryption = transport_encryption + km.pre_start_kernel() + + session = Session(key=b"secret-hmac-key") + server = km.context.socket(zmq.XPUB) + eavesdropper_sock = km.context.socket(zmq.SUB) + eavesdropper_sock.setsockopt(zmq.SUBSCRIBE, b"") + + expect_plaintext_visible = not transport_encryption + + server_info = km.get_connection_info() + if "curve_publickey" in server_info and "curve_secretkey" in server_info: + server.curve_secretkey = server_info["curve_secretkey"].encode() + server.curve_publickey = server_info["curve_publickey"].encode() + server.curve_server = True + + try: + server.bind(f"tcp://{km.ip}:{km.iopub_port}") + + # Eavesdropper connects with no authentication or curve keys. + eavesdropper_sock.connect(f"tcp://{km.ip}:{km.iopub_port}") + + # In non-Curve mode, XPUB receives the subscription and we drain it. + # In Curve mode, unauthenticated peers are rejected so no event arrives. + sub_poller = zmq.Poller() + sub_poller.register(server, zmq.POLLIN) + sub_events = dict(sub_poller.poll(timeout=1000)) + if server in sub_events: + server.recv() # discard subscription frame + + # Simulate a kernel publishing stream messages via Session + # (HMAC-signed, but not encrypted at the transport layer). + sensitive_content = {"name": "stdout", "text": "top_secret_output_12345"} + session.send(server, "stream", sensitive_content, ident=b"kernel.stream.stdout") + + # Check if an unauthenticated subscriber can read plaintext payload, + # using the same event polling behavior in both modes. + recv_poller = zmq.Poller() + recv_poller.register(eavesdropper_sock, zmq.POLLIN) + events = dict(recv_poller.poll(timeout=1000)) + + did_receive = eavesdropper_sock in events + assert did_receive is expect_plaintext_visible, ( + "Unexpected unauthenticated visibility result for IOPub payload" + ) + if expect_plaintext_visible: + # Demonstrates that the message content is visible in plaintext frames when Curve is disabled. + raw_frames = eavesdropper_sock.recv_multipart() + raw_bytes = b"".join(raw_frames) + assert b"top_secret_output_12345" in raw_bytes, ( + f"Expected plaintext content in raw frames.\nRaw bytes: {raw_bytes!r}" + ) + assert b"stream" in raw_bytes, "msg_type 'stream' should be visible in plaintext frames" + + finally: + server.close(linger=0) + eavesdropper_sock.close(linger=0) + km.cleanup_connection_file() + km.context.term() + + +def test_connect_shell_to_curve_server_with_curve_keys_succeeds(): + """Public API path: load_connection_info + connect_shell works with curve keys.""" + pub, sec = zmq.curve_keypair() + + # Set up a CurveZMQ server socket (simulating the kernel side). + ctx = zmq.Context() + server = ctx.socket(zmq.ROUTER) + server.curve_secretkey = sec + server.curve_publickey = pub + server.curve_server = True + port = server.bind_to_random_port("tcp://127.0.0.1") + + try: + # Configure through the same public parsing path used for + # connection-file content. + info = { + "ip": "127.0.0.1", + "transport": "tcp", + "shell_port": port, + "key": "abc123", + "signature_scheme": "hmac-sha256", + "curve_publickey": pub.decode("ascii"), + "curve_secretkey": sec.decode("ascii"), + } + mixin = ConnectionFileMixin() + mixin.context = ctx + mixin.load_connection_info(info) + + client_sock = mixin.connect_shell() + try: + client_sock.send(b"probe", flags=zmq.NOBLOCK) + + poller = zmq.Poller() + poller.register(server, zmq.POLLIN) + events = dict(poller.poll(timeout=1000)) + assert server in events, ( + "Authenticated client message was not received - " + "connect_shell() did not produce a working Curve-authenticated socket" + ) + finally: + client_sock.close(linger=0) + finally: + server.close(linger=0) + ctx.term() + + +def test_hb_channel_class_without_curve_support_raises_when_curve_is_active(): + """KernelClient.hb_channel raises RuntimeError when the hb_channel_class + does not accept curve_serverkey but CurveZMQ is active.""" + + class LegacyHBChannel(HBChannel): + """Simulates an old heartbeat channel class that predates curve support.""" + + def __init__(self, context, session, address): # type: ignore[override] + super().__init__(context, session, address) + + pub, _sec = zmq.curve_keypair() + + client = KernelClient() + client.hb_channel_class = LegacyHBChannel # type: ignore[assignment] + client.load_connection_info( + { + "ip": "127.0.0.1", + "transport": "tcp", + "hb_port": 5555, + "key": "abc123", + "signature_scheme": "hmac-sha256", + "curve_publickey": pub.decode("ascii"), + "curve_secretkey": pub.decode("ascii"), + } + ) + + with pytest.raises(RuntimeError, match=r"LegacyHBChannel.*curve_serverkey"): + _ = client.hb_channel + + +def test_hb_channel_class_without_curve_support_does_not_raise_when_curve_disabled(): + """KernelClient.hb_channel remains usable with legacy hb_channel_class when Curve is off.""" + + class LegacyHBChannel(HBChannel): + """Simulates an old heartbeat channel class that predates curve support.""" + + def __init__(self, context, session, address): # type: ignore[override] + super().__init__(context, session, address) + + client = KernelClient() + client.hb_channel_class = LegacyHBChannel # type: ignore[assignment] + client.load_connection_info( + { + "ip": "127.0.0.1", + "transport": "tcp", + "hb_port": 5555, + "key": "abc123", + "signature_scheme": "hmac-sha256", + } + ) + + hb = client.hb_channel + assert isinstance(hb, LegacyHBChannel) + + +def test_hb_channel_class_unrelated_typeerror_propagates_unchanged(): + """TypeError unrelated to curve_serverkey is not swallowed or re-wrapped.""" + + class BrokenHBChannel(HBChannel): + def __init__(self, context, session, address, **kwargs): # type: ignore[override] + raise TypeError("totally unrelated constructor error") + + pub, _sec = zmq.curve_keypair() + + client = KernelClient() + client.hb_channel_class = BrokenHBChannel # type: ignore[assignment] + client.load_connection_info( + { + "ip": "127.0.0.1", + "transport": "tcp", + "hb_port": 5555, + "key": "abc123", + "signature_scheme": "hmac-sha256", + "curve_publickey": pub.decode("ascii"), + "curve_secretkey": pub.decode("ascii"), + } + ) + + with pytest.raises(TypeError, match="totally unrelated constructor error"): + _ = client.hb_channel + + +def test_connect_shell_to_curve_server_without_curve_keys_is_rejected(): + """Public API path: without curve keys, shell traffic to a Curve server is dropped.""" + pub, sec = zmq.curve_keypair() + + ctx = zmq.Context() + server = ctx.socket(zmq.ROUTER) + server.curve_secretkey = sec + server.curve_publickey = pub + server.curve_server = True + port = server.bind_to_random_port("tcp://127.0.0.1") + + try: + info = { + "ip": "127.0.0.1", + "transport": "tcp", + "shell_port": port, + "key": "abc123", + "signature_scheme": "hmac-sha256", + } + mixin = ConnectionFileMixin() + mixin.context = ctx + mixin.load_connection_info(info) + + client_sock = mixin.connect_shell() + try: + client_sock.send(b"probe", flags=zmq.NOBLOCK) + poller = zmq.Poller() + poller.register(server, zmq.POLLIN) + events = dict(poller.poll(timeout=300)) + assert server not in events, ( + "Unauthenticated message reached Curve server - expected drop without curve keys" + ) + finally: + client_sock.close(linger=0) + finally: + server.close(linger=0) + ctx.term() + + +def test_connect_shell_to_curve_server_with_wrong_curve_keys_is_rejected(): + """Public API path: mismatched curve keys - traffic to a Curve server is dropped.""" + pub, sec = zmq.curve_keypair() + wrong_pub, wrong_sec = zmq.curve_keypair() + + ctx = zmq.Context() + server = ctx.socket(zmq.ROUTER) + server.curve_secretkey = sec + server.curve_publickey = pub + server.curve_server = True + port = server.bind_to_random_port("tcp://127.0.0.1") + + try: + info = { + "ip": "127.0.0.1", + "transport": "tcp", + "shell_port": port, + "key": "abc123", + "signature_scheme": "hmac-sha256", + "curve_publickey": wrong_pub.decode("ascii"), + "curve_secretkey": wrong_sec.decode("ascii"), + } + mixin = ConnectionFileMixin() + mixin.context = ctx + mixin.load_connection_info(info) + + client_sock = mixin.connect_shell() + try: + client_sock.send(b"probe", flags=zmq.NOBLOCK) + poller = zmq.Poller() + poller.register(server, zmq.POLLIN) + events = dict(poller.poll(timeout=300)) + assert server not in events, ( + "Message with wrong curve keys reached Curve server - expected drop" + ) + finally: + client_sock.close(linger=0) + finally: + server.close(linger=0) + ctx.term()