diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py
index a65f9a9960b4..9c1620603fd3 100644
--- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py
@@ -19,7 +19,10 @@
 import argparse
 import logging
 import os.path
+import queue
 import shlex
+import threading
+import time
 import typing
 import unittest
 import zipfile
@@ -37,8 +40,10 @@
 from apache_beam.options.pipeline_options import PortableOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.options.pipeline_options import TypeOptions
+from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners.portability import portable_runner_test
 from apache_beam.runners.portability import prism_runner
+from apache_beam.runners.worker import worker_pool_main
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import trigger
@@ -488,6 +493,65 @@ def test_singleton(self, enable_singleton):
     else:
       mock_prism_server.assert_called_once()
 
+  def test_loopback_worker_daemon_thread_accumulation(self):
+    """Verifies that in LOOPBACK mode the external worker pool servicer
+    tracks active thread-based SdkHarness workers and shuts them down in
+    StopWorker via a sentinel message, so background daemon threads do not
+    accumulate (and leak resources) across sequential pipeline executions.
+    """
+    servicer = worker_pool_main.BeamFnExternalWorkerPoolServicer(
+        use_process=False, state_cache_size=0, data_buffer_time_limit_ms=0)
+
+    active_workers = []
+    mock_responses = queue.Queue()
+
+    def mock_run(self_worker):
+      # Stand-in for SdkHarness.run(): block until a (sentinel) message
+      # arrives on the response queue, then unregister the worker.
+      active_workers.append(self_worker)
+      mock_responses.get()
+      active_workers.remove(self_worker)
+
+    def wait_for_workers(expected_count, timeout=5.0):
+      start = time.time()
+      while time.time() - start < timeout:
+        if len(active_workers) == expected_count:
+          return
+        time.sleep(0.01)
+      self.assertEqual(len(active_workers), expected_count)
+
+    with mock.patch(
+        'apache_beam.runners.worker.sdk_worker.SdkHarness') as mock_harness:
+      mock_harness.return_value._responses = mock_responses
+      mock_harness.return_value.run = lambda: mock_run(mock_harness)
+
+      # Simulate starting Worker 1 for Pipeline 1.
+      req1 = beam_fn_api_pb2.StartWorkerRequest(worker_id="worker_1")
+      req1.control_endpoint.url = "localhost:12345"
+      servicer.StartWorker(req1, None)
+
+      wait_for_workers(1)
+
+      # Simulate stopping Worker 1 at the end of Pipeline 1.
+      stop_req1 = beam_fn_api_pb2.StopWorkerRequest(worker_id="worker_1")
+      servicer.StopWorker(stop_req1, None)
+
+      # Verify the fix: StopWorker signals the thread harness to shut
+      # down, so the worker thread exits instead of leaking.
+      wait_for_workers(0)
+
+      # Simulate starting Worker 2 for Pipeline 2.
+      req2 = beam_fn_api_pb2.StartWorkerRequest(worker_id="worker_2")
+      req2.control_endpoint.url = "localhost:12345"
+      servicer.StartWorker(req2, None)
+
+      wait_for_workers(1)
+
+      # Clean up the second worker.
+      servicer.StopWorker(
+          beam_fn_api_pb2.StopWorkerRequest(worker_id="worker_2"), None)
+      wait_for_workers(0)
+
 
 if __name__ == '__main__':
   # Run the tests.
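Note: before this fix, each loopback run could strand a daemon thread. StartWorker
names these threads 'run_worker_<worker_id>' (see worker_pool_main.py below), so a
quick way to observe accumulation between runs is to count live threads with that
prefix. A hedged diagnostic sketch, not part of the patch; the helper name is
hypothetical:

    import threading

    def live_loopback_worker_threads():
      # Threads created by BeamFnExternalWorkerPoolServicer.StartWorker are
      # named 'run_worker_%s' % worker_id.
      return [
          t for t in threading.enumerate()
          if t.name.startswith('run_worker_') and t.is_alive()
      ]

    # e.g. assert not live_loopback_worker_threads() after StopWorker returns.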
diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
index 425a9fc57752..efe927b729c1 100644
--- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py
+++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py
@@ -45,6 +45,7 @@
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.runners.worker import sdk_worker
 from apache_beam.utils import thread_pool_executor
+from apache_beam.utils.sentinel import Sentinel
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -82,6 +83,7 @@ def __init__(
     self._state_cache_size = state_cache_size
     self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
     self._worker_processes: dict[str, subprocess.Popen] = {}
+    self._worker_threads: dict[str, sdk_worker.SdkHarness] = {}
 
   @classmethod
   def start(
@@ -166,6 +168,7 @@ def StartWorker(
             worker_id=start_worker_request.worker_id,
             state_cache_size=self._state_cache_size,
             data_buffer_time_limit_ms=self._data_buffer_time_limit_ms)
+        self._worker_threads[start_worker_request.worker_id] = worker
         worker_thread = threading.Thread(
             name='run_worker_%s' % start_worker_request.worker_id,
             target=worker.run)
@@ -188,6 +191,14 @@ def StopWorker(
       _LOGGER.info("Stopping worker %s" % stop_worker_request.worker_id)
       kill_process_gracefully(worker_process)
 
+    # In thread mode, ensure the worker thread is cleaned up by
+    # unblocking the harness's control response stream.
+    worker_thread_harness = self._worker_threads.pop(
+        stop_worker_request.worker_id, None)
+    if worker_thread_harness:
+      _LOGGER.info("Stopping thread worker %s" % stop_worker_request.worker_id)
+      worker_thread_harness._responses.put(Sentinel.sentinel)
+
     return beam_fn_api_pb2.StopWorkerResponse()
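For context, the shutdown handshake relies on SdkHarness.run() draining its
_responses queue and returning once it sees the sentinel. A minimal sketch of
that pattern (an illustrative toy, not Beam's actual control-loop code; the
local Sentinel class stands in for apache_beam.utils.sentinel.Sentinel):

    import queue
    from enum import Enum

    class Sentinel(Enum):
      # Beam ships an equivalent unique marker in apache_beam.utils.sentinel.
      sentinel = 0

    class ToyHarness:
      def __init__(self):
        self._responses = queue.Queue()

      def run(self):
        # Blocks on the response queue, like the harness control loop.
        while True:
          item = self._responses.get()
          if item is Sentinel.sentinel:
            return  # poison pill received: the worker thread exits
          # ...otherwise, forward `item` to the runner...

    # StopWorker's `worker_thread_harness._responses.put(Sentinel.sentinel)`
    # unblocks run(), so the daemon thread finishes instead of lingering.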