Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions jupyter_client/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
context: zmq.Context | None = None,
session: Session | None = None,
address: t.Union[t.Tuple[str, int], str] = "",
*,
curve_serverkey: bytes | None = None,
) -> None:
"""Create the heartbeat monitor thread.

Expand All @@ -66,12 +68,17 @@ def __init__(
The session to use.
address : zmq url
Standard (ip, port) tuple that the kernel is listening on.
curve_serverkey : bytes, optional
CurveZMQ server public key (Z85). When provided, the
heartbeat REQ socket is configured as a CurveZMQ client so it
can communicate with a CurveZMQ-enabled kernel.
Comment thread
krassowski marked this conversation as resolved.
"""
super().__init__()
self.daemon = True

self.context = context
self.session = session
self.curve_serverkey = curve_serverkey
if isinstance(address, tuple):
if address[1] == 0:
message = "The port number for a channel cannot be 0."
Expand Down Expand Up @@ -104,6 +111,13 @@ def _create_socket(self) -> None:
assert self.context is not None
self.socket = self.context.socket(zmq.REQ)
self.socket.linger = 1000
if self.curve_serverkey is not None:
# Generate a fresh ephemeral keypair for each socket; only the
# server public key (curve_serverkey) is needed for authentication.
client_pub, client_sec = zmq.curve_keypair()
self.socket.curve_secretkey = client_sec
self.socket.curve_publickey = client_pub
self.socket.curve_serverkey = self.curve_serverkey
assert self.address is not None
self.socket.connect(self.address)

Expand Down
23 changes: 20 additions & 3 deletions jupyter_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,26 @@ def hb_channel(self) -> t.Any:
if self._hb_channel is None:
url = self._make_url("hb")
self.log.debug("connecting heartbeat channel to %s", url)
self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract]
self.context, self.session, url
)
hb_kwargs = {}
if self.curve_publickey:
hb_kwargs["curve_serverkey"] = self.curve_publickey
try:
self._hb_channel = self.hb_channel_class( # type:ignore[call-arg,abstract]
self.context,
self.session,
url,
**hb_kwargs,
)
except TypeError as e:
if "curve_serverkey" in str(e):
msg = (
f"{self.hb_channel_class.__name__} does not support the "
"'curve_serverkey' parameter. Upgrade the heartbeat channel "
"class or disable CurveZMQ encryption."
)
raise RuntimeError(msg) from e
else:
raise
Comment thread
krassowski marked this conversation as resolved.
return self._hb_channel

@property
Expand Down
76 changes: 62 additions & 14 deletions jupyter_client/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import tempfile
import warnings
from getpass import getpass
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, cast

import zmq
from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
from traitlets import Bool, Bytes, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
from traitlets.config import LoggingConfigurable, SingletonConfigurable
from typing_extensions import TypedDict

from .localinterfaces import localhost
from .utils import _filefind
Expand All @@ -33,7 +34,22 @@
from .session import Session

# Define custom type for kernel connection info
KernelConnectionInfo = dict[str, Union[int, str, bytes]]


class KernelConnectionInfo(TypedDict, extra_items=str | bytes | int, total=False): # type: ignore[call-arg]
    """Typed description of the connection information for a kernel.

    Every key is optional (``total=False``).  Keys that are not declared
    below are still accepted through ``extra_items``, provided their values
    are ``str``, ``bytes`` or ``int``.
    """

    # Port numbers of the individual kernel channels.
    shell_port: int
    iopub_port: int
    stdin_port: int
    control_port: int
    hb_port: int
    # Address and transport the ports above belong to.
    ip: str
    # NOTE(review): declared as ``str``, but ``load_connection_info`` also
    # accepts a ``bytes`` key at runtime -- confirm which form is intended.
    key: str
    transport: str
    signature_scheme: str
    kernel_name: str
    # NOTE(review): presumably only filled in when session details are
    # requested from ``get_connection_info`` -- confirm the value type there.
    session: Session
    # CurveZMQ keys in text form: ``write_connection_file`` ASCII-decodes the
    # key bytes it is given before storing them under these names.
    curve_publickey: str
    curve_secretkey: str
Comment thread
krassowski marked this conversation as resolved.


def write_connection_file(
Expand All @@ -48,6 +64,8 @@ def write_connection_file(
transport: str = "tcp",
signature_scheme: str = "hmac-sha256",
kernel_name: str = "",
curve_publickey: bytes | None = None,
curve_secretkey: bytes | None = None,
**kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
"""Generates a JSON config file, including the selection of random ports.
Expand Down Expand Up @@ -76,7 +94,7 @@ def write_connection_file(
ip : str, optional
The ip address the kernel will bind to.

key : str, optional
key : bytes, optional
The Session key used for message authentication.

signature_scheme : str, optional
Expand All @@ -89,6 +107,12 @@ def write_connection_file(

kernel_name : str, optional
The name of the kernel currently connected to.

curve_publickey : bytes, optional
CurveZMQ public key (Z85).

curve_secretkey : bytes, optional
CurveZMQ secret key (Z85).
"""
if not ip:
ip = localhost()
Expand Down Expand Up @@ -149,7 +173,11 @@ def write_connection_file(
cfg["transport"] = transport
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name
cfg.update(kwargs)
if curve_publickey is not None:
cfg["curve_publickey"] = curve_publickey.decode("ascii")
if curve_secretkey is not None:
cfg["curve_secretkey"] = curve_secretkey.decode("ascii")
cfg.update(kwargs) # type: ignore[typeddict-item]

# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability as a file has secrets
Expand Down Expand Up @@ -371,6 +399,11 @@ def _ip_changed(self, change: Any) -> None:
stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

# Optional CurveZMQ keys loaded from the connection file (Z85-encoded bytes).
# None when the kernel was not started with CurveZMQ enabled.
curve_publickey: Bytes | None = Bytes(allow_none=True, default_value=None)
curve_secretkey: Bytes | None = Bytes(allow_none=True, default_value=None)

# names of the ports with random assignment
_random_port_names: list[str] | None = None

Expand Down Expand Up @@ -405,7 +438,7 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
connect_info : dict
dictionary of connection information.
"""
info = {
info: KernelConnectionInfo = {
"transport": self.transport,
"ip": self.ip,
"shell_port": self.shell_port,
Expand All @@ -426,6 +459,9 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
"key": self.session.key,
}
)
if self.curve_publickey is not None and self.curve_secretkey is not None:
info["curve_publickey"] = self.curve_publickey.decode()
info["curve_secretkey"] = self.curve_secretkey.decode()
return info

# factory for blocking clients
Expand Down Expand Up @@ -515,7 +551,7 @@ def write_connection_file(self, **kwargs: Any) -> None:
# write_connection_file also sets default ports:
self._record_random_port_names()
for name in port_names:
setattr(self, name, cfg[name])
setattr(self, name, cast(int, cfg.get(name)))

self._connection_file_written = True

Expand Down Expand Up @@ -548,23 +584,25 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None:
See the connection_file spec for details.
"""
self.transport = info.get("transport", self.transport)
self.ip = info.get("ip", self._ip_default()) # type:ignore[assignment]
self.ip = info.get("ip", self._ip_default())

self._record_random_port_names()
for name in port_names:
if getattr(self, name) == 0 and name in info:
# not overridden by config or cl_args
setattr(self, name, info[name])
setattr(self, name, cast(int, info.get(name)))

if "key" in info:
key = info["key"]
if isinstance(key, str):
key = key.encode()
assert isinstance(key, bytes)

self.session.key = key
key_bytes = key if isinstance(key, bytes) else key.encode() # type: ignore[redundant-expr,unreachable]
self.session.key = key_bytes
if "signature_scheme" in info:
self.session.signature_scheme = info["signature_scheme"]
if "curve_publickey" in info and "curve_secretkey" in info:
pub = info["curve_publickey"]
sec = info["curve_secretkey"]
self.curve_publickey = pub.encode() if isinstance(pub, str) else pub # type: ignore[redundant-expr]
self.curve_secretkey = sec.encode() if isinstance(sec, str) else sec # type: ignore[redundant-expr]

def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
"""Reconciles the connection information returned from the Provisioner.
Expand Down Expand Up @@ -618,6 +656,8 @@ def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo)

pertinent_keys = [
"key",
"curve_publickey",
"curve_secretkey",
"ip",
"stdin_port",
"iopub_port",
Expand Down Expand Up @@ -657,6 +697,14 @@ def _create_connected_socket(
sock.linger = 1000
if identity:
sock.identity = identity
if self.curve_publickey is not None:
# The connection file already carries this keypair, so reusing it
# avoids introducing an additional key-distribution mechanism here.
# curve_serverkey authenticates the server; the keypair configures
# encrypted communication for the client socket.
sock.curve_secretkey = self.curve_secretkey
sock.curve_publickey = self.curve_publickey
sock.curve_serverkey = self.curve_publickey
Comment thread
krassowski marked this conversation as resolved.
sock.connect(url)
return sock

Expand Down
14 changes: 13 additions & 1 deletion jupyter_client/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def _context_default(self) -> zmq.Context:
)
client_factory: Type = Type(klass=KernelClient, config=True)

transport_encryption: Bool = Bool(
False,
config=True,
help=(
"Enable transport encryption using manager-side provisioning of CurveZMQ server keys for kernels. "
"When True, the provisioner launch path issues and writes Curve credentials "
"before the kernel process starts."
),
)

@default("client_factory")
def _client_factory_default(self) -> Type:
return import_item(self.client_class)
Expand Down Expand Up @@ -379,7 +389,7 @@ def _close_control_socket(self) -> None:
self._control_socket = None

async def _async_pre_start_kernel(
self, **kw: t.Any
self, *, transport_encryption: bool | None = None, **kw: t.Any
) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]:
"""Prepares a kernel for startup in a separate process.

Expand All @@ -393,6 +403,8 @@ async def _async_pre_start_kernel(
and launching the kernel (e.g. Popen kwargs).
"""
self.shutting_down = False
if transport_encryption is not None:
self.transport_encryption = transport_encryption
self.kernel_id = self.kernel_id or kw.pop("kernel_id", str(uuid.uuid4()))
# save kwargs for use in restart
# assigning Traitlets Dicts to Dict make mypy unhappy but is ok
Expand Down
23 changes: 21 additions & 2 deletions jupyter_client/provisioning/local_provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import sys
from typing import TYPE_CHECKING, Any

import zmq

from ..connect import KernelConnectionInfo, LocalPortCache
from ..launcher import launch_kernel
from ..localinterfaces import is_local_ip, local_ips
Expand Down Expand Up @@ -174,6 +176,11 @@ async def pre_launch(self, **kwargs: Any) -> dict[str, Any]:
# This should be considered temporary until a better division of labor can be defined.
km = self.parent
if km:
transport_encryption = bool(
kwargs.pop("transport_encryption", getattr(km, "transport_encryption", False))
)
Comment on lines +179 to +181
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking for myself,

Is that bool() there so that the type annotation does not fail?
Should we be more conservative and check that what we get back actually is a bool (or None), and not any other non-falsy value? Or are we thinking that km.transport_encryption could one day become an Enum of different types of encryption?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or are we thinking that km.transport_encryption could one day become an Enum of different types of encryption

Yes, this is exactly why I called it transport_encryption (rather than enable_transport_encryption or transport_encryption_on).

Should we be more conservative and check that what we get back actually is a bool (or None), and not any other non-falsy value

Maybe for now? I do not have a strong opinion but happy to make this change.

curve_publickey: bytes | None = None
curve_secretkey: bytes | None = None
if km.transport == "tcp" and not is_local_ip(km.ip):
msg = (
"Can only launch a kernel on a local interface. "
Expand All @@ -196,11 +203,23 @@ async def pre_launch(self, **kwargs: Any) -> dict[str, Any]:
km.hb_port = lpc.find_available_port(km.ip)
km.control_port = lpc.find_available_port(km.ip)
self.ports_cached = True

if transport_encryption:
curve_publickey, curve_secretkey = zmq.curve_keypair()
km.curve_publickey = curve_publickey
km.curve_secretkey = curve_secretkey
if "env" in kwargs:
jupyter_session = kwargs["env"].get("JPY_SESSION_NAME", "")
km.write_connection_file(jupyter_session=jupyter_session)
km.write_connection_file(
jupyter_session=jupyter_session,
curve_publickey=curve_publickey,
curve_secretkey=curve_secretkey,
)
else:
km.write_connection_file()
km.write_connection_file(
curve_publickey=curve_publickey,
curve_secretkey=curve_secretkey,
)
Comment thread
Carreau marked this conversation as resolved.
self.connection_info = km.get_connection_info()

kernel_cmd = km.format_kernel_cmd(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"python-dateutil>=2.8.2",
"pyzmq>=25.0",
"tornado>=6.4.1",
"typing-extensions>=4.13.0",
"traitlets>=5.3",
]

Expand Down
35 changes: 35 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def test_write_connection_file():
assert info == sample_info


def test_write_connection_file_normalizes_curve_key_kwargs_to_strings():
    """Curve keys handed over as bytes must come back as ``str`` in the config."""
    with TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "kernel.json")
        result = connect.write_connection_file(
            target,
            **sample_info,
            curve_publickey=b"A" * 40,
            curve_secretkey=b"B" * 40,
        )

    written_cfg = result[1]
    # The session key and both Curve keys are expected in text form.
    for field in ("key", "curve_publickey", "curve_secretkey"):
        assert isinstance(written_cfg[field], str)


def test_load_connection_file_session():
"""test load_connection_file() after"""
session = Session()
Expand Down Expand Up @@ -291,3 +306,23 @@ def test_reconcile_connection_info(file_exists, km_matches):
km._reconcile_connection_info(provisioner_info)
km_info = km.get_connection_info()
assert km._equal_connections(km_info, provisioner_info)


def test_reconcile_connection_info_with_curve_keys():
    """After load + reconcile, the manager's info must equal the provisioner's."""
    curve_keys = {"curve_publickey": b"A" * 40, "curve_secretkey": b"B" * 40}
    with TemporaryDirectory() as runtime_dir:
        conn_path = os.path.join(runtime_dir, "kernel.json")
        manager = KernelManager()
        manager.connection_file = conn_path

        written = connect.write_connection_file(conn_path, **sample_info, **curve_keys)
        expected_info = written[1]
        # Hand the session key to the manager as bytes rather than str.
        expected_info["key"] = expected_info["key"].encode()  # type:ignore

        manager.load_connection_info(expected_info)
        manager._reconcile_connection_info(expected_info)
        actual_info = manager.get_connection_info()
        assert manager._equal_connections(actual_info, expected_info)
Loading
Loading