diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index c4758984af83..9e5034b58c00 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -240,7 +240,7 @@ progress: sr, err := b.Split(ctx, wk, 0.5 /* fraction of remainder */, nil /* allowed splits */) if err != nil { slog.Warn("SDK Error from split, aborting splits and failing bundle", "bundle", rb, "error", err.Error()) - if b.BundleErr != nil { + if b.BundleErr == nil { b.BundleErr = err } return b.BundleErr diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index f849ab11e77c..67759a1b1b72 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -95,11 +95,8 @@ def generate_proxy(element_type: type) -> pd.DataFrame: else: fields = named_fields_from_element_type(element_type) proxy = pd.DataFrame(columns=[name for name, _ in fields]) - for name, typehint in fields: - dtype = dtype_from_typehint(typehint) - proxy[name] = proxy[name].astype(dtype) - - return proxy + dtypes = {name: dtype_from_typehint(typehint) for name, typehint in fields} + return proxy.astype(dtypes) def element_type_from_dataframe( diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py index a65f9a9960b4..4735950d77aa 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py @@ -295,6 +295,98 @@ def test_after_count_trigger_streaming(self): ('B-3', {10, 15, 16}), ]))) + def test_sdf_split_exception(self): + from apache_beam.io.iobase import RestrictionTracker + + class SimpleTracker(RestrictionTracker): + def __init__(self, rest): + self._rest = rest + def current_restriction(self): + return self._rest + def try_claim(self, position): + return True + def check_done(self): + pass + def is_bounded(self): + return True + + class FailingSplitProvider(beam.RestrictionProvider): + def initial_restriction(self, element): + return (0, 10) + def create_tracker(self, restriction): + return SimpleTracker(restriction) + def restriction_size(self, element, restriction): + return 10 + def split_and_size(self, element, restriction): + raise RuntimeError("400 invalid split") + + class SplittableFn(beam.DoFn): + def process(self, element, restriction=beam.DoFn.RestrictionParam(FailingSplitProvider())): + yield element + + try: + with self.create_pipeline() as p: + _ = p | beam.Create([1]) | beam.ParDo(SplittableFn()) + except Exception as e: + print("\n[ACTUAL EXCEPTION RAISED IN STATIC SPLIT]:\n%s" % e) + self.assertRegex(str(e), "invalid split") + else: + self.fail("Exception not raised") + + def test_sdf_dynamic_split_exception(self): + from apache_beam.io.iobase import RestrictionTracker + from apache_beam.io.iobase import RestrictionProgress + import time + + class DynamicSplitTracker(RestrictionTracker): + def __init__(self, rest): + self._rest = rest + + def current_restriction(self): + return self._rest + + def current_progress(self): + return RestrictionProgress(fraction=0.5) + + def try_claim(self, position): + return True + + def check_done(self): + pass + + def is_bounded(self): + return True + + def try_split(self, fraction_of_remainder): + # Raised when the runner sends a dynamic runtime splitting request + raise RuntimeError("dynamic runtime split failed") + + class DynamicSplitProvider(beam.RestrictionProvider): + def initial_restriction(self, element): + return (0, 100) + + def create_tracker(self, restriction): + return DynamicSplitTracker(restriction) + + def restriction_size(self, element, restriction): + return 100 + + class SleepingSDF(beam.DoFn): + def process(self, element, restriction=beam.DoFn.RestrictionParam(DynamicSplitProvider())): + # Sleep enough to guarantee that Prism sends a dynamic split request due to slow progress + for i in range(10): + time.sleep(0.5) + yield element + i + + try: + with self.create_pipeline() as p: + _ = p | beam.Create([1]) | beam.ParDo(SleepingSDF()) + except Exception as e: + print("\n[ACTUAL EXCEPTION RAISED IN DYNAMIC SPLIT]:\n%s" % e) + self.assertRegex(str(e), "dynamic runtime split failed") + else: + self.fail("Exception not raised") + class PrismJobServerTest(unittest.TestCase): def setUp(self) -> None: diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 28568bd893c5..ad3d5bc66469 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -153,21 +153,39 @@ def _run_event_loop(): @staticmethod def reset_state(): + event_loop_thread_to_join = None with AsyncWrapper._lock: if AsyncWrapper._event_loop: AsyncWrapper._event_loop.call_soon_threadsafe( AsyncWrapper._event_loop.stop) if AsyncWrapper._event_loop_thread: - AsyncWrapper._event_loop_thread.join() + event_loop_thread_to_join = AsyncWrapper._event_loop_thread AsyncWrapper._event_loop = None AsyncWrapper._event_loop_thread = None if AsyncWrapper._loop_started is not None: AsyncWrapper._loop_started.clear() - for pool in AsyncWrapper._pool.values(): - pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown( - wait=True, cancel_futures=True) + pools = list(AsyncWrapper._pool.values()) + + # We must join the asyncio event loop thread outside of the lock block. + # If joined inside the lock, the waiting thread holds the lock while blocking, + # preventing active coroutines' done callbacks from acquiring the lock on the + # event loop thread, resulting in a deadlock. + if event_loop_thread_to_join: + event_loop_thread_to_join.join() + + # We must acquire and shut down the thread pools outside of the lock block. + # If shutdown(wait=True) is called inside the lock, the caller blocks holding + # the lock, preventing active worker threads from acquiring the lock to run + # their done callbacks, resulting in a deadlock. + pools_to_shutdown = [ + pool.acquire(AsyncWrapper.initialize_pool(1)) for pool in pools + ] + + for pool in pools_to_shutdown: + pool.shutdown(wait=True, cancel_futures=True) + with AsyncWrapper._lock: AsyncWrapper._pool = {} AsyncWrapper._processing_elements = {} @@ -268,7 +286,8 @@ async def _collect(result): def decrement_items_in_buffer(self, future): with AsyncWrapper._lock: - AsyncWrapper._items_in_buffer[self._uuid] -= 1 + if self._uuid in AsyncWrapper._items_in_buffer: + AsyncWrapper._items_in_buffer[self._uuid] -= 1 def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): """Schedules an item to be processed asynchronously if there is room. diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 81c7b8e163ff..39901d791fb9 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -16,6 +16,7 @@ # import logging +import multiprocessing import random import time import unittest @@ -487,6 +488,40 @@ def add_item(i): self.check_output(results[i], expected_outputs['key' + str(i)]) self.assertEqual(bag_states['key' + str(i)].items, []) + @staticmethod + def _run_reset_state_concurrent_teardown(use_asyncio): + dofn = BasicDofn(sleep_time=0.5) + async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=use_asyncio) + async_dofn.setup() + fake_bag_state = FakeBagState([]) + fake_timer = FakeTimer(0) + + # Start processing an item. This starts a worker thread/coroutine sleeping for 0.5s. + async_dofn.process(('key1', 1), to_process=fake_bag_state, timer=fake_timer) + time.sleep(0.05) + + # Verify that calling reset_state() while background tasks are actively running + # completes cleanly without causing lock-ordering deadlocks. + async_lib.AsyncWrapper.reset_state() + + def test_reset_state_concurrent_teardown(self): + # Verify concurrent teardown safety in a separate process to prevent any potential + # regressions from freezing the main pytest process at exit. + p = multiprocessing.Process( + target=AsyncTest._run_reset_state_concurrent_teardown, + args=(self.use_asyncio, )) + p.start() + p.join(timeout=10.0) + + if p.is_alive(): + p.terminate() + p.join() + self.fail( + "reset_state() deadlocked/hung waiting for active threads/tasks to finish" + ) + else: + self.assertEqual(p.exitcode, 0) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py index bf634c003a07..b236c7e3d5ac 100644 --- a/sdks/python/apache_beam/transforms/stats_test.py +++ b/sdks/python/apache_beam/transforms/stats_test.py @@ -264,6 +264,10 @@ def _approx_quantile_generator(size, num_of_quantiles, absoluteError): quantiles.append(size - 1) return quantiles + @staticmethod + def _sum_and_second(x): + return (sum(x), x[1]) + def test_quantiles_globaly(self): with TestPipeline() as p: pc = p | Create(list(range(101))) @@ -490,22 +494,32 @@ def test_batched_quantiles(self): 3, input_batched=True)) with_key = ( pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally( - 3, key=sum, input_batched=True)) + 3, + key=ApproximateQuantilesTest._sum_and_second, + input_batched=True)) key_with_reversed = ( pc | 'Globally with key and reversed' >> beam.ApproximateQuantiles.Globally( - 3, key=sum, reverse=True, input_batched=True)) + 3, + key=ApproximateQuantilesTest._sum_and_second, + reverse=True, + input_batched=True)) assert_that( globally, equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]), label='checkGlobally') + # When key is present, both (72.5, 225) and (22.5, 275) produce the exact same + # sum (297.5). If we just use key=sum, tie-breaking is sensitive to bundle merging + # order and shared class-level jitter state, leading to flaky test failures. + # With the secondary key (defined in _sum_and_second), we can break ties + # deterministically. assert_that( with_key, equal_to([[(50.0, 0), (72.5, 225), (99.9, 499)]]), label='checkGloballyWithKey') assert_that( key_with_reversed, - equal_to([[(99.9, 499), (72.5, 225), (50.0, 0)]]), + equal_to([[(99.9, 499), (22.5, 275), (50.0, 0)]]), label='checkGloballyWithKeyAndReversed') def test_batched_weighted_quantiles(self): diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py index 45ae27baffe7..8158b4443e1a 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py @@ -223,8 +223,11 @@ def _get_series(self, batch: pd.DataFrame): def produce_batch(self, elements): batch = pd.DataFrame.from_records(elements, columns=self._columns) - for column, typehint in self._element_type._fields: - batch[column] = batch[column].astype(dtype_from_typehint(typehint)) + dtypes = { + column: dtype_from_typehint(typehint) + for column, typehint in self._element_type._fields + } + batch = batch.astype(dtypes) return batch @@ -249,8 +252,11 @@ def produce_batch(self, elements): # Note from_records has an index= parameter batch = pd.DataFrame.from_records(elements, columns=self._columns) - for column, typehint in self._element_type._fields: - batch[column] = batch[column].astype(dtype_from_typehint(typehint)) + dtypes = { + column: dtype_from_typehint(typehint) + for column, typehint in self._element_type._fields + } + batch = batch.astype(dtypes) return batch.set_index(self._index_columns)