From 5ff5046b7e64f586111b6178f39e101f961f98aa Mon Sep 17 00:00:00 2001 From: Plamen Dimitrov Date: Fri, 29 May 2026 20:38:44 +0800 Subject: [PATCH 1/2] Allow retries for all major remote copy methods For certain categories of transient network errors allowing more resilience to remote copy operations is highly desirable where the copy should at least be retried a configurable number of times in (standard) one second intervals. Add corresponding test contract and reorder the test cases in sync with the original API functions ordering. Signed-off-by: Plamen Dimitrov --- aexpect/remote.py | 191 ++++++++++++++++++++++++++++++++++++------- tests/test_remote.py | 84 ++++++++++++++++--- 2 files changed, 234 insertions(+), 41 deletions(-) diff --git a/aexpect/remote.py b/aexpect/remote.py index efc9e61..8385a4a 100644 --- a/aexpect/remote.py +++ b/aexpect/remote.py @@ -662,6 +662,7 @@ def remote_copy( log_function=None, transfer_timeout=600, login_timeout=300, + tries=1, ): """ Transfer files using rsync or SCP, given a command line. @@ -677,25 +678,67 @@ def remote_copy( :param login_timeout: The maximal time duration (in seconds) to wait for each step of the login procedure (i.e. the "Are you sure" prompt or the password prompt) + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. 
""" - LOG.debug( - "Trying to copy with command '%s', timeout %ss", - command, - transfer_timeout, - ) - if log_filename: - output_func = log_function - output_params = (log_filename,) - else: - output_func = None - output_params = () method = "rsync" if "rsync" in command else "scp" - with Expect( - command, output_func=output_func, output_params=output_params - ) as session: - _remote_copy( - session, password_list, transfer_timeout, login_timeout, method - ) + + for attempt in range(tries): + try: + LOG.debug( + "Trying to copy with command '%s', timeout %ss (attempt %d/%d)", + command, + transfer_timeout, + attempt + 1, + tries, + ) + if log_filename: + output_func = log_function + output_params = (log_filename,) + else: + output_func = None + output_params = () + with Expect( + command, output_func=output_func, output_params=output_params + ) as session: + _remote_copy( + session, + password_list, + transfer_timeout, + login_timeout, + method, + ) + return # transfer is successful + except ( + TransferTimeoutError, + AuthenticationTimeoutError, + ExpectTimeoutError, + ) as error: + if attempt < tries - 1: + LOG.debug( + "Transient error on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise + except (TransferFailedError, SCPError, RsyncError) as error: + # For transfer failures, only retry on specific conditions + if "Connection" in str(error) or "timeout" in str(error).lower(): + if attempt < tries - 1: + LOG.debug( + "Connection error on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise + else: + raise def scp_to_remote( @@ -711,6 +754,7 @@ def scp_to_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files to a remote host (guest) through scp. @@ -729,6 +773,8 @@ def scp_to_remote( to complete. 
:param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if limit: limit = f"-l {limit}" @@ -753,7 +799,12 @@ def scp_to_remote( ) password_list = [password] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -770,6 +821,7 @@ def scp_from_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files from a remote host (guest). @@ -788,6 +840,8 @@ def scp_from_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if limit: limit = f"-l {limit}" @@ -810,7 +864,14 @@ def scp_from_remote( rf"{shlex.quote(local_path)}" ) password_list = [password] - remote_copy(command, password_list, log_filename, log_function, timeout) + remote_copy( + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, + ) def scp_between_remotes( @@ -830,6 +891,7 @@ def scp_between_remotes( timeout=600, src_inter=None, dst_inter=None, + tries=1, ): """ Copy files from a remote host (guest) to another remote host (guest). @@ -851,6 +913,8 @@ def scp_between_remotes( to complete. :param src_inter: The interface on local that the src neighbour attached :param dst_inter: The interface on the src that the dst neighbour attached + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :return: True on success and False on failure. 
""" @@ -883,7 +947,12 @@ def scp_between_remotes( ) password_list = [s_passwd, d_passwd] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -900,6 +969,7 @@ def rsync_to_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files to a remote host (guest) through rsync. @@ -918,6 +988,8 @@ def rsync_to_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :raise: Whatever remote_rsync() raises """ if limit: @@ -941,7 +1013,12 @@ def rsync_to_remote( ) password_list = [password] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -958,6 +1035,7 @@ def rsync_from_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files from a remote host (guest) through rsync. @@ -976,6 +1054,8 @@ def rsync_from_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. 
:raise: Whatever remote_rsync() raises """ if limit: @@ -997,7 +1077,14 @@ def rsync_from_remote( f"{username}@{host}:{quote_path(remote_path)} {shlex.quote(local_path)}" ) password_list = [password] - remote_copy(command, password_list, log_filename, log_function, timeout) + remote_copy( + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, + ) # noinspection PyBroadException @@ -1284,6 +1371,7 @@ def scp_to_session( log_function=None, timeout=600, interface=None, + tries=1, ): """ Secure copy a filepath (w/o wildcard) to a remote location with the same @@ -1299,6 +1387,8 @@ def scp_to_session( :param log_function: Function to perform logging :param timeout: Timeout for the scp operation :param interface: Interface used for the transfer + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. The rest of the arguments are identical to scp_to_remote(). """ @@ -1315,6 +1405,7 @@ def scp_to_session( log_function, timeout, interface, + tries, ) @@ -1328,6 +1419,7 @@ def scp_from_session( log_function=None, timeout=600, interface=None, + tries=1, ): """ Secure copy a filepath (w/o wildcard) from a remote location with the same @@ -1343,6 +1435,8 @@ def scp_from_session( :param log_function: Function to perform logging :param timeout: Timeout for the scp operation :param interface: Interface used for the transfer + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. The rest of the arguments are identical to scp_from_remote(). """ @@ -1359,6 +1453,7 @@ def scp_from_session( log_function, timeout, interface, + tries, ) @@ -1406,6 +1501,7 @@ def copy_files_to( timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument + tries=1, ): """ Copy files to a remote host (guest) using the selected client. 
@@ -1427,6 +1523,8 @@ def copy_files_to( :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address.) :param filesize: size of file will be transferred + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if client == "scp": scp_to_remote( @@ -1442,6 +1540,7 @@ def copy_files_to( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rsync": rsync_to_remote( @@ -1457,6 +1556,7 @@ def copy_files_to( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rss": log_func = None @@ -1464,9 +1564,23 @@ def copy_files_to( log_func = LOG.debug if interface: address = f"{address}%{interface}" - fdclient = rss_client.FileUploadClient(address, port, log_func) - fdclient.upload(local_path, remote_path, timeout) - fdclient.close() + for attempt in range(tries): + try: + fdclient = rss_client.FileUploadClient(address, port, log_func) + fdclient.upload(local_path, remote_path, timeout) + fdclient.close() + return # transfer is successful + except Exception as error: # pylint: disable=broad-except + if attempt < tries - 1: + LOG.debug( + "RSS upload failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise else: raise TransferBadClientError(client) @@ -1489,6 +1603,7 @@ def copy_files_from( timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument + tries=1, ): """ Copy files from a remote host (guest) using the selected client. @@ -1510,6 +1625,8 @@ def copy_files_from( :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address.) :param filesize: size of file will be transferred + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. 
""" if client == "scp": scp_from_remote( @@ -1525,6 +1642,7 @@ def copy_files_from( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rsync": rsync_from_remote( @@ -1540,6 +1658,7 @@ def copy_files_from( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rss": log_func = None @@ -1547,8 +1666,24 @@ def copy_files_from( log_func = LOG.debug if interface: address = f"{address}%{interface}" - fdclient = rss_client.FileDownloadClient(address, port, log_func) - fdclient.download(remote_path, local_path, timeout) - fdclient.close() + for attempt in range(tries): + try: + fdclient = rss_client.FileDownloadClient( + address, port, log_func + ) + fdclient.download(remote_path, local_path, timeout) + fdclient.close() + return # transfer is successful + except Exception as error: # pylint: disable=broad-except + if attempt < tries - 1: + LOG.debug( + "RSS download failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise else: raise TransferBadClientError(client) diff --git a/tests/test_remote.py b/tests/test_remote.py index a590360..65b8cbb 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -59,6 +59,64 @@ def test_wait_for_login(self): " -o PreferredAuthentications=password user@127.0.0.1", ) + @mock.patch("aexpect.remote._remote_copy") + def test_remote_copy(self, mock_remote_copy): + remote.remote_copy("cp a b", ["pass"], "/local/path", "/remote/path") + mock_remote_copy.assert_called_once_with( + mock.ANY, + ["pass"], + 600, + 300, + "scp", + ) + self.assertEqual( + mock_remote_copy.call_args[0][0].command, + r"cp a b", + ) + + @mock.patch("aexpect.remote._remote_copy") + def test_remote_copy_retry(self, mock_remote_copy): + remote.remote_copy( + "cp a b", + ["pass"], + "/local/path", + "/remote/path", + tries=2, + ) + mock_remote_copy.assert_called_once_with( + mock.ANY, + ["pass"], + 600, + 300, + "scp", + ) + self.assertEqual( 
+ mock_remote_copy.call_args[0][0].command, + r"cp a b", + ) + + mock_remote_copy.reset_mock() + mock_remote_copy.side_effect = [ + remote.SCPError("Copy failed", "Connection lost"), + None, + ] + remote.remote_copy( + "cp a b", + ["pass"], + "/local/path", + "/remote/path", + tries=2, + ) + self.assertEqual(mock_remote_copy.call_count, 2) + self.assertEqual( + mock_remote_copy.call_args_list[0][0][0].command, + r"cp a b", + ) + self.assertEqual( + mock_remote_copy.call_args_list[1][0][0].command, + r"cp a b", + ) + @mock.patch("aexpect.remote._remote_copy") def test_scp_to_remote(self, mock_remote_copy): remote.scp_to_remote( @@ -85,19 +143,6 @@ def test_scp_from_remote(self, mock_remote_copy): r"scp -r -v -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o PreferredAuthentications=password -P 22 user@\[127.0.0.1\]:/remote/path /local/path", ) - @mock.patch("aexpect.remote._remote_copy") - def test_rsync_to_remote(self, mock_remote_copy): - remote.rsync_to_remote( - "127.0.0.1", 22, "user", "pass", "/local/path", "/remote/path" - ) - mock_remote_copy.assert_called_once_with( - mock.ANY, ["pass"], 600, 300, "rsync" - ) - self.assertEqual( - mock_remote_copy.call_args[0][0].command, - r"rsync -r -avz -e 'ssh -Tp 22 -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' /local/path user@127.0.0.1:/remote/path", - ) - @mock.patch("aexpect.remote._remote_copy") def test_scp_between_remotes(self, mock_remote_copy): remote.scp_between_remotes( @@ -119,6 +164,19 @@ def test_scp_between_remotes(self, mock_remote_copy): r"scp -r -v -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o PreferredAuthentications=password -P 22 src_user@\[src_host\]:/src/path dst_user@\[dst_host\]:/dst/path", ) + @mock.patch("aexpect.remote._remote_copy") + def test_rsync_to_remote(self, mock_remote_copy): + remote.rsync_to_remote( + "127.0.0.1", 22, "user", "pass", "/local/path", "/remote/path" + ) + mock_remote_copy.assert_called_once_with( + mock.ANY, ["pass"], 600, 
300, "rsync" + ) + self.assertEqual( + mock_remote_copy.call_args[0][0].command, + r"rsync -r -avz -e 'ssh -Tp 22 -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' /local/path user@127.0.0.1:/remote/path", + ) + @mock.patch("aexpect.remote._remote_copy") def test_rsync_from_remote(self, mock_remote_copy): remote.rsync_from_remote( From c7e931943fedda9d3e52c5e2f18230a8b00f9f45 Mon Sep 17 00:00:00 2001 From: Plamen Dimitrov Date: Fri, 29 May 2026 23:24:26 +0800 Subject: [PATCH 2/2] feedback: also add retries to ncat and UDP based copying Signed-off-by: Plamen Dimitrov --- aexpect/remote.py | 186 ++++++++++++++++++++++++++++++---------------- 1 file changed, 120 insertions(+), 66 deletions(-) diff --git a/aexpect/remote.py b/aexpect/remote.py index 8385a4a..7df757f 100644 --- a/aexpect/remote.py +++ b/aexpect/remote.py @@ -1107,6 +1107,7 @@ def nc_copy_between_remotes( s_session=None, d_session=None, file_transfer_timeout=600, + tries=1, ): """ Copy files from guest to guest using netcat. @@ -1132,59 +1133,91 @@ def nc_copy_between_remotes( :param d_session: A shell session object for dst or None. :param check_sum: Whether to run checksum for the operation. :param file_transfer_timeout: Timeout for file transfer. - + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :return: True on success and False on failure. 
""" - check_string = "NCFT" - if not s_session: - s_session = remote_login( - c_type, src, s_port, s_name, s_passwd, c_prompt - ) - if not d_session: - d_session = remote_login( - c_type, dst, s_port, d_name, d_passwd, c_prompt - ) + for attempt in range(tries): + try: + check_string = "NCFT" + if not s_session: + s_session = remote_login( + c_type, src, s_port, s_name, s_passwd, c_prompt + ) + if not d_session: + d_session = remote_login( + c_type, dst, s_port, d_name, d_passwd, c_prompt + ) - try: - s_session.cmd(f"iptables -I INPUT -p {d_protocol} -j ACCEPT") - d_session.cmd(f"iptables -I OUTPUT -p {d_protocol} -j ACCEPT") - except Exception: # pylint: disable=W0703 - pass - - LOG.info("Transfer data using netcat from %s to %s", src, dst) - cmd = f"nc -w {timeout}" - if d_protocol == "udp": - cmd += " -u" - receive_cmd = f"echo {check_string} | {cmd} -l {d_port} > {d_path}" - d_session.sendline(receive_cmd) - send_cmd = f"{cmd} {dst} {d_port} < {s_path}" - status, output = s_session.cmd_status_output( - send_cmd, timeout=file_transfer_timeout - ) - if status: - err = f"Fail to transfer file between {src} -> {dst}." - if check_string not in output: - err += ( - "src did not receive check " - f"string {check_string} sent by dst." + try: + s_session.cmd(f"iptables -I INPUT -p {d_protocol} -j ACCEPT") + d_session.cmd(f"iptables -I OUTPUT -p {d_protocol} -j ACCEPT") + except Exception: # pylint: disable=W0703 + pass + + LOG.info( + "Transfer data using netcat from %s to %s (attempt %d/%d)", + src, + dst, + attempt + 1, + tries, ) - err += f"send nc command {send_cmd}, output {output}" - err += f"Receive nc command {receive_cmd}." 
- raise NetcatTransferFailedError(status, err) - - if check_sum: - LOG.info("md5sum cmd = md5sum %s", s_path) - output = s_session.cmd(f"md5sum {s_path}") - src_md5 = output.split()[0] - dst_md5 = d_session.cmd(f"md5sum {d_path}").split()[0] - if src_md5.strip() != dst_md5.strip(): - err_msg = ( - "Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}" + cmd = f"nc -w {timeout}" + if d_protocol == "udp": + cmd += " -u" + receive_cmd = f"echo {check_string} | {cmd} -l {d_port} > {d_path}" + d_session.sendline(receive_cmd) + send_cmd = f"{cmd} {dst} {d_port} < {s_path}" + status, output = s_session.cmd_status_output( + send_cmd, timeout=file_transfer_timeout ) - raise NetcatTransferIntegrityError(err_msg) - return True + if status: + err = f"Fail to transfer file between {src} -> {dst}." + if check_string not in output: + err += ( + "src did not receive check " + f"string {check_string} sent by dst." + ) + err += f"send nc command {send_cmd}, output {output}" + err += f"Receive nc command {receive_cmd}." 
+ raise NetcatTransferFailedError(status, err) + + if check_sum: + LOG.info("md5sum cmd = md5sum %s", s_path) + output = s_session.cmd(f"md5sum {s_path}") + src_md5 = output.split()[0] + dst_md5 = d_session.cmd(f"md5sum {d_path}").split()[0] + if src_md5.strip() != dst_md5.strip(): + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) + raise NetcatTransferIntegrityError(err_msg) + return True + except ( + NetcatTransferTimeoutError, + NetcatTransferFailedError, + UDPError, + ) as error: + if attempt < tries - 1: + LOG.debug( + "Transfer failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + # reset sessions for retry + if s_session: + s_session.close() + if d_session: + d_session.close() + s_session = None + d_session = None + else: + raise + return False def udp_copy_between_remotes( @@ -1201,6 +1234,7 @@ def udp_copy_between_remotes( c_prompt="\n", d_port="9000", timeout=600, + tries=1, ): """ Copy files from guest to guest using udp. @@ -1218,9 +1252,9 @@ def udp_copy_between_remotes( :param c_prompt: command line prompt of remote host(guest) :param d_port: the port data transfer :param timeout: data transfer timeout + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. 
""" - s_session = remote_login(c_type, src, s_port, s_name, s_passwd, c_prompt) - d_session = remote_login(c_type, dst, s_port, d_name, d_passwd, c_prompt) def get_abs_path(session, filename, extension): """Return file path drive+path.""" @@ -1302,23 +1336,43 @@ def stop_server(session): if server_alive(session): session.cmd_output_safe(stop_cmd) - try: - src_md5 = get_file_md5(s_session, s_path) - if not server_alive(s_session): - start_server(s_session) - start_client(d_session) - dst_md5 = get_file_md5(d_session, d_path) - if src_md5 != dst_md5: - err_msg = ( - "Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}" + for attempt in range(tries): + try: + s_session = remote_login( + c_type, src, s_port, s_name, s_passwd, c_prompt + ) + d_session = remote_login( + c_type, dst, s_port, d_name, d_passwd, c_prompt ) - raise UDPError(err_msg) - finally: - stop_server(s_session) - s_session.close() - d_session.close() + try: + src_md5 = get_file_md5(s_session, s_path) + if not server_alive(s_session): + start_server(s_session) + start_client(d_session) + dst_md5 = get_file_md5(d_session, d_path) + if src_md5 != dst_md5: + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) + raise UDPError(err_msg) + finally: + stop_server(s_session) + s_session.close() + d_session.close() + return # transfer is successful + except UDPError as error: + if attempt < tries - 1: + LOG.debug( + "UDP transfer failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise def login_from_session(