Skip to content

Commit 545d2b8

Browse files
Tbruno25tj-bruno
andauthored
fix: Add timeout support in FunctionCommand (#1663)
Co-authored-by: TJ Bruno <tj.bruno@tensordyne.ai>
1 parent 09c07a4 commit 545d2b8

3 files changed

Lines changed: 25 additions & 2 deletions

File tree

src/pyinfra/api/operations.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> bool:
103103

104104
if isinstance(command, FunctionCommand):
105105
try:
106-
status = command.execute(state, host, connector_arguments)
106+
with gevent.Timeout(timeout, exception=TimeoutError):
107+
status = command.execute(state, host, connector_arguments)
108+
109+
except TimeoutError as e:
110+
log_host_command_error(host, e, timeout=timeout)
111+
107112
except NestedOperationError:
108113
host.log_styled("Error in nested operation", fg="red", log_func=logger.error)
109114
except Exception as e:

src/pyinfra/api/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def log_error_or_warning(
261261

262262

263263
def log_host_command_error(host: "Host", e: Exception, timeout: int | None = 0) -> None:
264-
if isinstance(e, timeout_error):
264+
if isinstance(e, (TimeoutError, timeout_error)):
265265
logger.error(
266266
"{0}{1}".format(
267267
host.print_prefix,

tests/test_api/test_api_operations.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from os import path
33
from unittest import TestCase
44
from unittest.mock import mock_open, patch
5+
import time
56

67
import pyinfra
78
from pyinfra.api import (
@@ -258,6 +259,23 @@ def mocked_function(*args, **kwargs):
258259

259260
assert is_called
260261

262+
def test_function_call_op_timeout(self):
263+
inventory = make_inventory()
264+
state = State(inventory, Config())
265+
state.current_stage = StateStage.Prepare
266+
connect_all(state)
267+
268+
timeout = 1
269+
270+
def mocked_function(*args, **kwargs):
271+
time.sleep(timeout + 1)
272+
273+
add_op(state, python.call, mocked_function, _timeout=timeout)
274+
275+
# Timeout should cause the operation to fail and hosts to be removed
276+
with self.assertRaises(PyinfraError) as context:
277+
run_ops(state)
278+
261279
def test_run_once_serial_op(self):
262280
inventory = make_inventory()
263281
state = State(inventory, Config())

0 commit comments

Comments
 (0)