Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions jupyter_client/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
context: zmq.Context | None = None,
session: Session | None = None,
address: t.Union[t.Tuple[str, int], str] = "",
*,
curve_serverkey: bytes | None = None,
) -> None:
"""Create the heartbeat monitor thread.

Expand All @@ -66,12 +68,17 @@ def __init__(
The session to use.
address : zmq url
Standard (ip, port) tuple that the kernel is listening on.
curve_serverkey : bytes, optional
CurveZMQ server public key (Z85). When provided, the
heartbeat REQ socket is configured as a CurveZMQ client so it
can communicate with a CurveZMQ-enabled kernel.
Comment thread
krassowski marked this conversation as resolved.
"""
super().__init__()
self.daemon = True

self.context = context
self.session = session
self.curve_serverkey = curve_serverkey
if isinstance(address, tuple):
if address[1] == 0:
message = "The port number for a channel cannot be 0."
Expand Down Expand Up @@ -104,6 +111,13 @@ def _create_socket(self) -> None:
assert self.context is not None
self.socket = self.context.socket(zmq.REQ)
self.socket.linger = 1000
if self.curve_serverkey is not None:
# Generate a fresh ephemeral keypair for each socket; only the
# server public key (curve_serverkey) is needed for authentication.
client_pub, client_sec = zmq.curve_keypair()
self.socket.curve_secretkey = client_sec
self.socket.curve_publickey = client_pub
self.socket.curve_serverkey = self.curve_serverkey
assert self.address is not None
self.socket.connect(self.address)

Expand Down
15 changes: 14 additions & 1 deletion jupyter_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,21 @@ def hb_channel(self) -> t.Any:
if self._hb_channel is None:
url = self._make_url("hb")
self.log.debug("connecting heartbeat channel to %s", url)
hb_supports_curve = (
"curve_serverkey" in inspect.signature(self.hb_channel_class.__init__).parameters
)
if self._curve_publickey is not None and not hb_supports_curve:
msg = (
f"{self.hb_channel_class.__name__} does not support the "
"'curve_serverkey' parameter. Upgrade the heartbeat channel "
"class or disable CurveZMQ encryption."
)
raise RuntimeError(msg)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to inspect for hb_supports_curve, since we can pass the argument if it's required and let the standard unsupported argument error raise:

hb_kwargs = {}
if self._curve_publickey:
    hb_kwargs["curve_serverkey"] = self._curve_publickey
...
hb_channel_class(...**hb_kwargs)

but fine if you want to keep the more detailed error. But if you hit that error, a lot of other things are not going to work before we get to the hb channel, I think.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with both; as long as we know there are some backward compatibility story here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does 8b5d6d6 work?

self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract]
self.context, self.session, url
self.context,
self.session,
url,
**({"curve_serverkey": self._curve_publickey} if hb_supports_curve else {}),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That triggers me to want to resume working on a pep for undefined/void parameter that is striped when calling function.

)
return self._hb_channel

Expand Down
75 changes: 62 additions & 13 deletions jupyter_client/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import tempfile
import warnings
from getpass import getpass
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, cast

import zmq
from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
from traitlets.config import LoggingConfigurable, SingletonConfigurable
from typing_extensions import TypedDict

from .localinterfaces import localhost
from .utils import _filefind
Expand All @@ -33,7 +34,22 @@
from .session import Session

# Define custom type for kernel connection info
KernelConnectionInfo = dict[str, Union[int, str, bytes]]


class KernelConnectionInfo(TypedDict, extra_items=str | bytes | int, total=False): # type: ignore[call-arg]
shell_port: int
iopub_port: int
stdin_port: int
control_port: int
hb_port: int
ip: str
key: str
transport: str
signature_scheme: str
kernel_name: str
session: Session
curve_publickey: str
curve_secretkey: str
Comment thread
krassowski marked this conversation as resolved.


def write_connection_file(
Expand All @@ -48,6 +64,8 @@ def write_connection_file(
transport: str = "tcp",
signature_scheme: str = "hmac-sha256",
kernel_name: str = "",
curve_publickey: bytes | None = None,
curve_secretkey: bytes | None = None,
**kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
"""Generates a JSON config file, including the selection of random ports.
Expand Down Expand Up @@ -76,7 +94,7 @@ def write_connection_file(
ip : str, optional
The ip address the kernel will bind to.

key : str, optional
key : bytes, optional
The Session key used for message authentication.

signature_scheme : str, optional
Expand All @@ -89,6 +107,12 @@ def write_connection_file(

kernel_name : str, optional
The name of the kernel currently connected to.

curve_publickey : bytes, optional
CurveZMQ public key (Z85).

curve_secretkey : bytes, optional
CurveZMQ secret key (Z85).
"""
if not ip:
ip = localhost()
Expand Down Expand Up @@ -149,7 +173,11 @@ def write_connection_file(
cfg["transport"] = transport
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name
cfg.update(kwargs)
if curve_publickey is not None:
cfg["curve_publickey"] = curve_publickey.decode("ascii")
if curve_secretkey is not None:
cfg["curve_secretkey"] = curve_secretkey.decode("ascii")
cfg.update(kwargs) # type: ignore[typeddict-item]

# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability as a file has secrets
Expand Down Expand Up @@ -318,6 +346,11 @@ def tunnel_to_kernel(
class ConnectionFileMixin(LoggingConfigurable):
"""Mixin for configurable classes that work with connection files"""

# Optional CurveZMQ keys loaded from the connection file (Z85-encoded bytes).
# None when the kernel was not started with CurveZMQ enabled.
_curve_publickey: bytes | None = None
_curve_secretkey: bytes | None = None

data_dir: str | Unicode = Unicode()

def _data_dir_default(self) -> str:
Expand Down Expand Up @@ -405,7 +438,7 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
connect_info : dict
dictionary of connection information.
"""
info = {
info: KernelConnectionInfo = {
"transport": self.transport,
"ip": self.ip,
"shell_port": self.shell_port,
Expand Down Expand Up @@ -515,7 +548,7 @@ def write_connection_file(self, **kwargs: Any) -> None:
# write_connection_file also sets default ports:
self._record_random_port_names()
for name in port_names:
setattr(self, name, cfg[name])
setattr(self, name, cast(int, cfg.get(name)))

self._connection_file_written = True

Expand Down Expand Up @@ -548,23 +581,25 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None:
See the connection_file spec for details.
"""
self.transport = info.get("transport", self.transport)
self.ip = info.get("ip", self._ip_default()) # type:ignore[assignment]
self.ip = info.get("ip", self._ip_default())

self._record_random_port_names()
for name in port_names:
if getattr(self, name) == 0 and name in info:
# not overridden by config or cl_args
setattr(self, name, info[name])
setattr(self, name, cast(int, info.get(name)))

if "key" in info:
key = info["key"]
if isinstance(key, str):
key = key.encode()
assert isinstance(key, bytes)

self.session.key = key
key_bytes = key if isinstance(key, bytes) else key.encode() # type: ignore[redundant-expr,unreachable]
self.session.key = key_bytes
if "signature_scheme" in info:
self.session.signature_scheme = info["signature_scheme"]
if "curve_publickey" in info and "curve_secretkey" in info:
pub = info["curve_publickey"]
sec = info["curve_secretkey"]
self._curve_publickey = pub.encode()
self._curve_secretkey = sec.encode()

def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
"""Reconciles the connection information returned from the Provisioner.
Expand All @@ -589,6 +624,10 @@ def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
# Prior to the following comparison, we need to adjust the value of "key" to
# be bytes, otherwise the comparison below will fail.
file_info["key"] = file_info["key"].encode()
if "curve_publickey" in file_info:
file_info["curve_publickey"] = file_info["curve_publickey"].encode()
if "curve_secretkey" in file_info:
file_info["curve_secretkey"] = file_info["curve_secretkey"].encode()
if not self._equal_connections(info, file_info):
os.remove(self.connection_file) # Contents mismatch - remove the file
self._connection_file_written = False
Expand Down Expand Up @@ -618,6 +657,8 @@ def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo)

pertinent_keys = [
"key",
"curve_publickey",
"curve_secretkey",
"ip",
"stdin_port",
"iopub_port",
Expand Down Expand Up @@ -657,6 +698,14 @@ def _create_connected_socket(
sock.linger = 1000
if identity:
sock.identity = identity
if self._curve_publickey is not None:
# The connection file already carries this keypair, so reusing it
# avoids introducing an additional key-distribution mechanism here.
# curve_serverkey authenticates the server; the keypair configures
# encrypted communication for the client socket.
sock.curve_secretkey = self._curve_secretkey
sock.curve_publickey = self._curve_publickey
sock.curve_serverkey = self._curve_publickey
sock.connect(url)
return sock

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"python-dateutil>=2.8.2",
"pyzmq>=25.0",
"tornado>=6.4.1",
"typing-extensions>=4.13.0",
"traitlets>=5.3",
]

Expand Down
15 changes: 15 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def test_write_connection_file():
assert info == sample_info


def test_write_connection_file_normalizes_curve_key_kwargs_to_strings():
with TemporaryDirectory() as d:
cf = os.path.join(d, "kernel.json")
_fname, cfg = connect.write_connection_file(
cf,
**sample_info,
curve_publickey=b"A" * 40,
curve_secretkey=b"B" * 40,
)

assert isinstance(cfg["key"], str)
assert isinstance(cfg["curve_publickey"], str)
assert isinstance(cfg["curve_secretkey"], str)


def test_load_connection_file_session():
"""test load_connection_file() after"""
session = Session()
Expand Down
Loading
Loading