Skip to content

Commit 7f0552a

Browse files
authored
connectors.ssh: honor ConnectTimeout through ProxyJump (#1679)
1 parent 1362dbb commit 7f0552a

4 files changed

Lines changed: 97 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ license-files = ["LICENSE.md"]
1313
requires-python = ">=3.10,<4.0"
1414
dependencies = [
1515
"gevent>=1.5",
16-
"paramiko>=2.7,<5", # 2.7 (2019) adds OpenSSH key format + Match SSH config
16+
"paramiko>=2.11,<5", # 2.11 (2022) adds Transport.open_channel(timeout=...) for ProxyJump timeout (#971)
1717
"click>2",
1818
"jinja2>3,<4",
1919
"python-dateutil>2,<3",

src/pyinfra/connectors/sshuserclient/client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,14 @@ def connect( # type: ignore[override]
220220
session = transport.open_session()
221221
AgentRequestHandler(session)
222222

223-
def gateway(self, hostname, host_port, target, target_port):
223+
def gateway(self, hostname, host_port, target, target_port, timeout=None):
224224
transport = self.get_transport()
225225
assert transport is not None, "No transport"
226226
return transport.open_channel(
227227
"direct-tcpip",
228228
(target, target_port),
229229
(hostname, host_port),
230+
timeout=timeout,
230231
)
231232

232233
def parse_config(
@@ -284,6 +285,12 @@ def parse_config(
284285
if "port" in host_config:
285286
cfg["port"] = int(host_config["port"])
286287

288+
# Respect ``ConnectTimeout`` from ssh_config (issue #971): without this,
289+
# paramiko waits on its own default and a ProxyJump hop can hang for
290+
# minutes before failing.
291+
if "connecttimeout" in host_config and "timeout" not in cfg:
292+
cfg["timeout"] = int(host_config["connecttimeout"])
293+
287294
if "serveraliveinterval" in host_config:
288295
keep_alive = int(host_config["serveraliveinterval"])
289296

@@ -298,14 +305,26 @@ def parse_config(
298305
elif "proxyjump" in host_config:
299306
hops = host_config["proxyjump"].split(",")
300307
sock = None
308+
# Propagate the target's timeout down so hop connections and the
309+
# direct-tcpip channel don't hang forever when the network misbehaves
310+
# (issue #971). Individual hops can still override via their own
311+
# ``ConnectTimeout`` in ssh_config.
312+
target_timeout = cfg.get("timeout")
301313

302314
for i, hop in enumerate(hops):
303315
hop_hostname, hop_config = self.derive_shorthand(ssh_config, hop)
304316
logger.debug("SSH ProxyJump through %s:%s", hop_hostname, hop_config["port"])
305317

318+
hop_connect_kwargs = dict(hop_config)
319+
if "timeout" not in hop_connect_kwargs and target_timeout is not None:
320+
hop_connect_kwargs["timeout"] = target_timeout
321+
306322
c = SSHClient()
307323
c.connect(
308-
hop_hostname, _pyinfra_ssh_config_file=ssh_config_file, sock=sock, **hop_config
324+
hop_hostname,
325+
_pyinfra_ssh_config_file=ssh_config_file,
326+
sock=sock,
327+
**hop_connect_kwargs,
309328
)
310329

311330
if i == len(hops) - 1:
@@ -314,7 +333,13 @@ def parse_config(
314333
else:
315334
target, target_config = self.derive_shorthand(ssh_config, hops[i + 1])
316335

317-
sock = c.gateway(hostname, cfg["port"], target, target_config["port"])
336+
sock = c.gateway(
337+
hostname,
338+
cfg["port"],
339+
target,
340+
target_config["port"],
341+
timeout=target_timeout,
342+
)
318343
cfg["sock"] = sock
319344

320345
return (
@@ -354,6 +379,8 @@ def derive_shorthand(ssh_config, host_string):
354379
"port": base_config.get("port", 22),
355380
"username": base_config.get("user"),
356381
}
382+
if "connecttimeout" in base_config:
383+
config["timeout"] = int(base_config["connecttimeout"])
357384
config.update(shorthand_config)
358385

359386
return hostname, config

tests/test_connectors/test_sshuserclient.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@
4141
ForwardAgent yes
4242
"""
4343

44+
SSH_CONFIG_PROXYJUMP_CONNECTTIMEOUT = """
45+
Host jump
46+
HostName jump.example.com
47+
User jumpuser
48+
ConnectTimeout 7
49+
50+
Host device
51+
HostName 10.0.0.170
52+
ProxyJump jump
53+
ConnectTimeout 5
54+
User deviceuser
55+
"""
56+
57+
SSH_CONFIG_CONNECTTIMEOUT = """
58+
Host slowhost
59+
HostName slow.example.com
60+
User slowuser
61+
ConnectTimeout 12
62+
"""
63+
4464
SSH_CONFIG_MULTIPLE_KNOWN_HOSTS = """
4565
Host 192.168.1.3
4666
UserKnownHostsFile ~/.ssh/known_hosts ~/.ssh/known_hosts.infra ~/.ssh/known_hosts.webservers
@@ -337,7 +357,51 @@ def test_load_ssh_config_proxyjump(self, fake_gateway, fake_ssh_connect):
337357
sock=None,
338358
username="nottestuser",
339359
)
340-
fake_gateway.assert_called_once_with("192.168.1.2", 1022, "192.168.1.2", 1022)
360+
fake_gateway.assert_called_once_with("192.168.1.2", 1022, "192.168.1.2", 1022, timeout=None)
361+
362+
@patch(
363+
"pyinfra.connectors.sshuserclient.client.open",
364+
mock_open(read_data=SSH_CONFIG_CONNECTTIMEOUT),
365+
create=True,
366+
)
367+
def test_connecttimeout_sets_timeout_kwarg(self):
368+
"""Regression test for #971: ``ConnectTimeout`` in ssh_config must be
369+
propagated so paramiko doesn't hang on its own default."""
370+
client = SSHClient()
371+
_, config, *_ = client.parse_config("slowhost")
372+
assert config.get("timeout") == 12
373+
374+
@patch(
375+
"pyinfra.connectors.sshuserclient.client.open",
376+
mock_open(read_data=SSH_CONFIG_PROXYJUMP_CONNECTTIMEOUT),
377+
create=True,
378+
)
379+
@patch(
380+
"pyinfra.connectors.sshuserclient.config.open",
381+
mock_open(read_data=SSH_CONFIG_PROXYJUMP_CONNECTTIMEOUT),
382+
create=True,
383+
)
384+
@patch("pyinfra.connectors.sshuserclient.SSHClient.connect")
385+
@patch("pyinfra.connectors.sshuserclient.SSHClient.gateway")
386+
def test_proxyjump_propagates_connecttimeout(self, fake_gateway, fake_ssh_connect):
387+
"""Regression test for #971: ``ConnectTimeout`` on both the target and
388+
the hop must be honored so neither the hop connect nor the direct-tcpip
389+
channel can hang forever."""
390+
client = SSHClient()
391+
392+
_, config, *_ = client.parse_config("device")
393+
394+
# Target's ConnectTimeout wins for the channel open.
395+
assert config.get("timeout") == 5
396+
# Hop connect receives the hop's own ConnectTimeout (7s, per its own
397+
# ssh_config block) rather than inheriting the target's 5s.
398+
fake_ssh_connect.assert_called_once()
399+
_, kwargs = fake_ssh_connect.call_args
400+
assert kwargs["timeout"] == 7
401+
# Channel open (gateway) uses the target's ConnectTimeout.
402+
fake_gateway.assert_called_once()
403+
_, gw_kwargs = fake_gateway.call_args
404+
assert gw_kwargs["timeout"] == 5
341405

342406
@patch("pyinfra.connectors.sshuserclient.client.open", mock_open(), create=True)
343407
@patch("pyinfra.connectors.sshuserclient.client.ParamikoClient.connect")

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)