Skip to content

Commit 8798dd1

Browse files
committed
api,connectors: honour operation-level _temp_dir in askpass helpers
Closes #1623. The SUDO_ASKPASS / SU_ASKPASS helper script is now placed under the same directory the caller requested, fixing the case where an operation-level _temp_dir (or a per-host default cascaded through the _temp_dir global argument on host.data) was respected for file ops but ignored by the internal askpass helpers. Resolution order (unchanged outside the askpass path): 1. Operation-level _temp_dir (explicit caller) 2. config.TEMP_DIR / host.data._temp_dir via the existing global-argument cascade 3. TmpDir fact (_get_temp_directory only) 4. config.DEFAULT_TEMP_DIR Changes: - _ensure_askpass_set_for_host accepts a temp_dir override and tracks the resolved directory next to the cached path; if a later call resolves a different temp_dir the cached path is invalidated so the script gets regenerated under the correct directory. - make_unix_command_for_host threads _temp_dir through to the askpass helpers so the op-level override takes effect. Tests cover the askpass path for default, config, and op-level temp directories, cache invalidation on temp_dir change, and the full path through make_unix_command_for_host. Tests use unique hostnames to avoid sharing the process-global memoize cache on _get_temp_directory.
1 parent 0e8c348 commit 8798dd1

2 files changed

Lines changed: 133 additions & 10 deletions

File tree

src/pyinfra/connectors/util.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,18 +280,37 @@ def extract_control_arguments(arguments: "ConnectorArguments") -> "ConnectorArgu
280280
return control_arguments
281281

282282

283-
def _ensure_sudo_askpass_set_for_host(host: "Host"):
284-
return _ensure_askpass_set_for_host(host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR)
283+
def _ensure_sudo_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None):
284+
return _ensure_askpass_set_for_host(
285+
host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir
286+
)
285287

286288

287-
def _ensure_su_askpass_set_for_host(host: "Host"):
288-
return _ensure_askpass_set_for_host(host, "su_askpass_path", SU_ASKPASS_ENV_VAR)
289+
def _ensure_su_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None):
290+
return _ensure_askpass_set_for_host(
291+
host, "su_askpass_path", SU_ASKPASS_ENV_VAR, temp_dir=temp_dir
292+
)
289293

290294

291-
def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
292-
if host.connector_data.get(key):
295+
def _ensure_askpass_set_for_host(
296+
host: "Host", key: str, env_var: str, temp_dir: Optional[str] = None
297+
):
298+
# Operation-level _temp_dir (if any) overrides the host-level/global
299+
# temp directory resolution so `server.shell(..., _temp_dir=X)` places
300+
# the askpass script under X rather than /tmp.
301+
effective_temp_dir = temp_dir or host.get_temp_dir_config()
302+
303+
# Invalidate the cache if the resolved temp_dir changed since the path
304+
# was created, otherwise we'd hand out a stale path under the wrong dir.
305+
# If the tracker is missing (older code path or external population),
306+
# trust the existing path to preserve backward compatibility.
307+
temp_dir_cache_key = "{0}_temp_dir".format(key)
308+
existing_path = host.connector_data.get(key)
309+
existing_temp_dir = host.connector_data.get(temp_dir_cache_key)
310+
if existing_path and (existing_temp_dir is None or existing_temp_dir == effective_temp_dir):
293311
return
294-
ok, output = host.run_shell_command(ASKPASS_COMMAND.format(host.get_temp_dir_config(), env_var))
312+
313+
ok, output = host.run_shell_command(ASKPASS_COMMAND.format(effective_temp_dir, env_var))
295314

296315
if not ok:
297316
raise PyinfraError("Failed to create sudo_askpass command: {0}".format(output.output))
@@ -304,6 +323,7 @@ def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
304323
)
305324

306325
host.connector_data[key] = output.stdout_lines[0]
326+
host.connector_data[temp_dir_cache_key] = effective_temp_dir
307327

308328

309329
def make_unix_command_for_host(
@@ -312,6 +332,11 @@ def make_unix_command_for_host(
312332
command: StringCommand,
313333
**command_arguments,
314334
) -> StringCommand:
335+
# Operation-level temp directory override, if any. Passed through to the
336+
# askpass helpers so the generated SUDO_ASKPASS / SU_ASKPASS script lands
337+
# under the same directory the operation asked for.
338+
op_temp_dir = command_arguments.get("_temp_dir")
339+
315340
# Handle sudo password
316341
if command_arguments.get("_sudo"):
317342
# If the sudo password is not set in the direct arguments,
@@ -321,13 +346,13 @@ def make_unix_command_for_host(
321346

322347
if command_arguments.get("_sudo_password"):
323348
# Ensure the askpass path is correctly set and passed through
324-
_ensure_sudo_askpass_set_for_host(host)
349+
_ensure_sudo_askpass_set_for_host(host, temp_dir=op_temp_dir)
325350
command_arguments["_sudo_askpass_path"] = host.connector_data["sudo_askpass_path"]
326351

327352
# Handle su password
328353
if command_arguments.get("_su_user"):
329354
if command_arguments.get("_su_password"):
330-
_ensure_su_askpass_set_for_host(host)
355+
_ensure_su_askpass_set_for_host(host, temp_dir=op_temp_dir)
331356
command_arguments["_su_askpass_path"] = host.connector_data["su_askpass_path"]
332357

333358
return make_unix_command(command, **command_arguments)

tests/test_connectors/test_util.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# encoding: utf-8
22

33
from unittest import TestCase
4-
from unittest.mock import MagicMock
4+
from unittest.mock import MagicMock, patch
55

66
from pyinfra.api import Config, State
77
from pyinfra.connectors.util import (
8+
CommandOutput,
9+
OutputLine,
10+
_ensure_askpass_set_for_host,
811
make_unix_command,
912
make_unix_command_for_host,
1013
remove_any_sudo_askpass_file,
@@ -280,3 +283,98 @@ def test_noop_when_no_state(self):
280283
remove_any_sudo_askpass_file(host)
281284

282285
host.run_shell_command.assert_not_called()
286+
287+
288+
class TestEnsureAskpassTempDir(TestCase):
289+
"""
290+
The askpass helper must honour the resolved temp directory (issue #1623):
291+
operation-level ``_temp_dir`` > ``config.TEMP_DIR`` > ``config.DEFAULT_TEMP_DIR``.
292+
Per-host defaults for ``_temp_dir`` come through the standard global-argument
293+
cascade (``host.data._temp_dir``), not via a separate code path here.
294+
"""
295+
296+
_counter = 0
297+
298+
@classmethod
299+
def _next_host(cls):
300+
cls._counter += 1
301+
return "askpass-temp-dir-test-host-{0}".format(cls._counter)
302+
303+
def _make_host(self, config=None):
304+
name = self._next_host()
305+
state = State(make_inventory(hosts=(name,)), config or Config())
306+
host = state.inventory.get_host(name)
307+
host.init(state)
308+
return host
309+
310+
def _captured_script_run(self, host, temp_dir=None, stdout="/some/askpass/path"):
311+
"""
312+
Call ``_ensure_askpass_set_for_host`` with ``host.run_shell_command``
313+
patched so we can assert on the remote script text (whose first
314+
argument to the mkstemp template is the temp directory).
315+
"""
316+
captured = {}
317+
318+
def fake_run(command, *args, **kwargs):
319+
captured["command"] = command
320+
return (True, CommandOutput([OutputLine("stdout", stdout)]))
321+
322+
host.run_shell_command = fake_run # type: ignore[method-assign]
323+
_ensure_askpass_set_for_host(
324+
host,
325+
key="sudo_askpass_path",
326+
env_var="PYINFRA_SUDO_PASSWORD",
327+
temp_dir=temp_dir,
328+
)
329+
return captured["command"]
330+
331+
def test_default_temp_dir(self):
332+
host = self._make_host()
333+
script = self._captured_script_run(host)
334+
assert "${TMPDIR:=/tmp}" in script
335+
336+
def test_config_temp_dir(self):
337+
host = self._make_host(Config(TEMP_DIR="/var/tmp"))
338+
script = self._captured_script_run(host)
339+
assert "${TMPDIR:=/var/tmp}" in script
340+
341+
def test_op_temp_dir_wins_over_config(self):
342+
host = self._make_host(Config(TEMP_DIR="/var/tmp"))
343+
script = self._captured_script_run(host, temp_dir="/dev/shm/pyinfra")
344+
assert "${TMPDIR:=/dev/shm/pyinfra}" in script
345+
346+
def test_cache_invalidates_when_temp_dir_changes(self):
347+
host = self._make_host()
348+
first = self._captured_script_run(host, temp_dir="/a", stdout="/a/askpass")
349+
assert "${TMPDIR:=/a}" in first
350+
second = self._captured_script_run(host, temp_dir="/b", stdout="/b/askpass")
351+
assert "${TMPDIR:=/b}" in second
352+
assert host.connector_data["sudo_askpass_path"] == "/b/askpass"
353+
354+
def test_make_unix_command_for_host_threads_temp_dir(self):
355+
host = self._make_host()
356+
host.connector_data["prompted_sudo_password"] = "supersecret"
357+
358+
captured = {}
359+
360+
def fake_run(command, *args, **kwargs):
361+
captured["command"] = command
362+
return (
363+
True,
364+
CommandOutput([OutputLine("stdout", "/op/tmp/pyinfra-askpass-XYZ")]),
365+
)
366+
367+
host.run_shell_command = fake_run # type: ignore[method-assign]
368+
369+
with patch("pyinfra.connectors.util.make_unix_command") as fake_make:
370+
fake_make.return_value = "mocked"
371+
make_unix_command_for_host(
372+
host.state,
373+
host,
374+
"uptime",
375+
_sudo=True,
376+
_sudo_password="supersecret",
377+
_temp_dir="/op/tmp",
378+
)
379+
380+
assert "${TMPDIR:=/op/tmp}" in captured["command"]

0 commit comments

Comments
 (0)