Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions sdks/python/apache_beam/runners/portability/prism_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import argparse
import logging
import os.path
import queue
import shlex
import threading
import time
import typing
import unittest
import zipfile
Expand All @@ -37,8 +40,10 @@
from apache_beam.options.pipeline_options import PortableOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.portability import portable_runner_test
from apache_beam.runners.portability import prism_runner
from apache_beam.runners.worker import worker_pool_main
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
Expand Down Expand Up @@ -488,6 +493,63 @@ def test_singleton(self, enable_singleton):
else:
mock_prism_server.assert_called_once()

def test_loopback_worker_daemon_thread_accumulation(self):
  """Verifies that in LOOPBACK mode, the external worker pool servicer properly
  tracks active thread-based SdkHarness workers and cleanly shuts them down in
  StopWorker via sentinel messages. This prevents background daemon threads
  from accumulating across sequential pipeline executions and leaking
  resources.
  """
  pool_servicer = worker_pool_main.BeamFnExternalWorkerPoolServicer(
      use_process=False, state_cache_size=0, data_buffer_time_limit_ms=0)

  running = []
  responses = queue.Queue()

  def fake_run(harness):
    # Stand-in for SdkHarness.run(): register as live, block until a
    # message (the shutdown sentinel) arrives, then deregister.
    running.append(harness)
    responses.get()
    running.remove(harness)

  def await_worker_count(expected, timeout=5.0):
    # Poll until the number of live fake workers matches `expected`,
    # failing the test if that does not happen within `timeout` seconds.
    deadline = time.time() + timeout
    while len(running) != expected and time.time() < deadline:
      time.sleep(0.01)
    self.assertEqual(len(running), expected)

  with mock.patch(
      'apache_beam.runners.worker.sdk_worker.SdkHarness') as mock_harness:
    mock_harness.return_value._responses = responses
    mock_harness.return_value.run = lambda: fake_run(mock_harness)

    # Pipeline 1: start a first worker and wait until its thread is live.
    start_req = beam_fn_api_pb2.StartWorkerRequest(worker_id="worker_1")
    start_req.control_endpoint.url = "localhost:12345"
    pool_servicer.StartWorker(start_req, None)
    await_worker_count(1)

    # Stopping the worker must push the sentinel so the harness thread
    # unblocks and exits — no daemon thread left behind.
    pool_servicer.StopWorker(
        beam_fn_api_pb2.StopWorkerRequest(worker_id="worker_1"), None)
    await_worker_count(0)

    # Pipeline 2: a fresh worker starts and stops just as cleanly.
    start_req = beam_fn_api_pb2.StartWorkerRequest(worker_id="worker_2")
    start_req.control_endpoint.url = "localhost:12345"
    pool_servicer.StartWorker(start_req, None)
    await_worker_count(1)

    pool_servicer.StopWorker(
        beam_fn_api_pb2.StopWorkerRequest(worker_id="worker_2"), None)
    await_worker_count(0)


if __name__ == '__main__':
# Run the tests.
Expand Down
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/runners/worker/worker_pool_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import sdk_worker
from apache_beam.utils import thread_pool_executor
from apache_beam.utils.sentinel import Sentinel

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
self._state_cache_size = state_cache_size
self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
self._worker_processes: dict[str, subprocess.Popen] = {}
self._worker_threads: dict[str, sdk_worker.SdkHarness] = {}

@classmethod
def start(
Expand Down Expand Up @@ -166,6 +168,7 @@ def StartWorker(
worker_id=start_worker_request.worker_id,
state_cache_size=self._state_cache_size,
data_buffer_time_limit_ms=self._data_buffer_time_limit_ms)
self._worker_threads[start_worker_request.worker_id] = worker
Comment thread
shunping marked this conversation as resolved.
worker_thread = threading.Thread(
name='run_worker_%s' % start_worker_request.worker_id,
target=worker.run)
Expand All @@ -188,6 +191,14 @@ def StopWorker(
_LOGGER.info("Stopping worker %s" % stop_worker_request.worker_id)
kill_process_gracefully(worker_process)

# applicable for thread mode to ensure thread cleanup by
# unblocking the harness request stream.
worker_thread_harness = self._worker_threads.pop(
stop_worker_request.worker_id, None)
if worker_thread_harness:
_LOGGER.info("Stopping thread worker %s" % stop_worker_request.worker_id)
worker_thread_harness._responses.put(Sentinel.sentinel)

return beam_fn_api_pb2.StopWorkerResponse()


Expand Down
Loading