Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ progress:
sr, err := b.Split(ctx, wk, 0.5 /* fraction of remainder */, nil /* allowed splits */)
if err != nil {
slog.Warn("SDK Error from split, aborting splits and failing bundle", "bundle", rb, "error", err.Error())
if b.BundleErr != nil {
if b.BundleErr == nil {
b.BundleErr = err
}
return b.BundleErr
Expand Down
7 changes: 2 additions & 5 deletions sdks/python/apache_beam/dataframe/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,8 @@ def generate_proxy(element_type: type) -> pd.DataFrame:
else:
fields = named_fields_from_element_type(element_type)
proxy = pd.DataFrame(columns=[name for name, _ in fields])
for name, typehint in fields:
dtype = dtype_from_typehint(typehint)
proxy[name] = proxy[name].astype(dtype)

return proxy
dtypes = {name: dtype_from_typehint(typehint) for name, typehint in fields}
return proxy.astype(dtypes)


def element_type_from_dataframe(
Expand Down
92 changes: 92 additions & 0 deletions sdks/python/apache_beam/runners/portability/prism_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,98 @@ def test_after_count_trigger_streaming(self):
('B-3', {10, 15, 16}),
])))

def test_sdf_split_exception(self):
  """Checks that an error raised while splitting an SDF fails the pipeline.

  ``FailingSplitProvider.split_and_size`` always raises, so running the
  pipeline is expected to raise an exception whose message contains
  "invalid split".
  """
  from apache_beam.io.iobase import RestrictionTracker

  class SimpleTracker(RestrictionTracker):
    """Minimal tracker: reports the given restriction, accepts all claims."""
    def __init__(self, rest):
      self._rest = rest

    def current_restriction(self):
      return self._rest

    def try_claim(self, position):
      return True

    def check_done(self):
      pass

    def is_bounded(self):
      return True

  class FailingSplitProvider(beam.RestrictionProvider):
    """Restriction provider whose ``split_and_size`` always raises."""
    def initial_restriction(self, element):
      return (0, 10)

    def create_tracker(self, restriction):
      return SimpleTracker(restriction)

    def restriction_size(self, element, restriction):
      return 10

    def split_and_size(self, element, restriction):
      raise RuntimeError("400 invalid split")

  class SplittableFn(beam.DoFn):
    def process(
        self,
        element,
        restriction=beam.DoFn.RestrictionParam(FailingSplitProvider())):
      yield element

  # assertRaisesRegex fails the test when nothing is raised and searches the
  # pattern in str(exception), so no manual try/except/else is needed.
  with self.assertRaisesRegex(Exception, "invalid split"):
    with self.create_pipeline() as p:
      _ = p | beam.Create([1]) | beam.ParDo(SplittableFn())

def test_sdf_dynamic_split_exception(self):
  """Checks that an error raised from ``try_split`` fails the pipeline.

  ``DynamicSplitTracker.try_split`` always raises, and ``SleepingSDF``
  processes slowly so that the runner has time to request a split. Running
  the pipeline is expected to raise an exception whose message contains
  "dynamic runtime split failed".
  """
  from apache_beam.io.iobase import RestrictionTracker
  from apache_beam.io.iobase import RestrictionProgress
  import time

  class DynamicSplitTracker(RestrictionTracker):
    """Tracker that accepts all claims but raises on any split request."""
    def __init__(self, rest):
      self._rest = rest

    def current_restriction(self):
      return self._rest

    def current_progress(self):
      return RestrictionProgress(fraction=0.5)

    def try_claim(self, position):
      return True

    def check_done(self):
      pass

    def is_bounded(self):
      return True

    def try_split(self, fraction_of_remainder):
      # Raised when the runner sends a dynamic runtime splitting request.
      raise RuntimeError("dynamic runtime split failed")

  class DynamicSplitProvider(beam.RestrictionProvider):
    """Restriction provider that hands out ``DynamicSplitTracker``s."""
    def initial_restriction(self, element):
      return (0, 100)

    def create_tracker(self, restriction):
      return DynamicSplitTracker(restriction)

    def restriction_size(self, element, restriction):
      return 100

  class SleepingSDF(beam.DoFn):
    def process(
        self,
        element,
        restriction=beam.DoFn.RestrictionParam(DynamicSplitProvider())):
      # Sleep long enough that Prism is expected to send a dynamic split
      # request due to slow progress.
      for i in range(10):
        time.sleep(0.5)
        yield element + i

  # assertRaisesRegex fails the test when nothing is raised and searches the
  # pattern in str(exception), so no manual try/except/else is needed.
  with self.assertRaisesRegex(Exception, "dynamic runtime split failed"):
    with self.create_pipeline() as p:
      _ = p | beam.Create([1]) | beam.ParDo(SleepingSDF())


class PrismJobServerTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down
29 changes: 24 additions & 5 deletions sdks/python/apache_beam/transforms/async_dofn.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,39 @@ def _run_event_loop():

@staticmethod
def reset_state():
event_loop_thread_to_join = None
with AsyncWrapper._lock:
if AsyncWrapper._event_loop:
AsyncWrapper._event_loop.call_soon_threadsafe(
AsyncWrapper._event_loop.stop)
if AsyncWrapper._event_loop_thread:
AsyncWrapper._event_loop_thread.join()
event_loop_thread_to_join = AsyncWrapper._event_loop_thread

AsyncWrapper._event_loop = None
AsyncWrapper._event_loop_thread = None
if AsyncWrapper._loop_started is not None:
AsyncWrapper._loop_started.clear()

for pool in AsyncWrapper._pool.values():
pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
wait=True, cancel_futures=True)
pools = list(AsyncWrapper._pool.values())

# We must join the asyncio event loop thread outside of the lock block.
# If joined inside the lock, the waiting thread holds the lock while blocking,
# preventing active coroutines' done callbacks from acquiring the lock on the
# event loop thread, resulting in a deadlock.
if event_loop_thread_to_join:
event_loop_thread_to_join.join()

# We must acquire and shut down the thread pools outside of the lock block.
# If shutdown(wait=True) is called inside the lock, the caller blocks holding
# the lock, preventing active worker threads from acquiring the lock to run
# their done callbacks, resulting in a deadlock.
pools_to_shutdown = [
pool.acquire(AsyncWrapper.initialize_pool(1)) for pool in pools
]

for pool in pools_to_shutdown:
pool.shutdown(wait=True, cancel_futures=True)

with AsyncWrapper._lock:
AsyncWrapper._pool = {}
AsyncWrapper._processing_elements = {}
Expand Down Expand Up @@ -268,7 +286,8 @@ async def _collect(result):

def decrement_items_in_buffer(self, future):
with AsyncWrapper._lock:
AsyncWrapper._items_in_buffer[self._uuid] -= 1
if self._uuid in AsyncWrapper._items_in_buffer:
AsyncWrapper._items_in_buffer[self._uuid] -= 1

def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
"""Schedules an item to be processed asynchronously if there is room.
Expand Down
35 changes: 35 additions & 0 deletions sdks/python/apache_beam/transforms/async_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import logging
import multiprocessing
import random
import time
import unittest
Expand Down Expand Up @@ -487,6 +488,40 @@ def add_item(i):
self.check_output(results[i], expected_outputs['key' + str(i)])
self.assertEqual(bag_states['key' + str(i)].items, [])

@staticmethod
def _run_reset_state_concurrent_teardown(use_asyncio):
  """Calls ``AsyncWrapper.reset_state()`` while an element is in flight.

  Args:
    use_asyncio: forwarded to ``AsyncWrapper`` as its ``use_asyncio`` option.
  """
  wrapper = async_lib.AsyncWrapper(
      BasicDofn(sleep_time=0.5), use_asyncio=use_asyncio)
  wrapper.setup()

  bag_state = FakeBagState([])
  timer = FakeTimer(0)

  # Kick off one element (the wrapped DoFn is built with sleep_time=0.5), then
  # pause briefly so the work is already under way when teardown begins.
  wrapper.process(('key1', 1), to_process=bag_state, timer=timer)
  time.sleep(0.05)

  # reset_state() is expected to return cleanly, without deadlocking, even
  # though background work is still running.
  async_lib.AsyncWrapper.reset_state()

def test_reset_state_concurrent_teardown(self):
  """Runs the concurrent-teardown scenario in a child process.

  Using a separate process means a hang in ``reset_state()`` can be detected
  with a join timeout and terminated, rather than freezing the test runner.
  """
  worker = multiprocessing.Process(
      target=AsyncTest._run_reset_state_concurrent_teardown,
      args=(self.use_asyncio, ))
  worker.start()
  worker.join(timeout=10.0)

  if not worker.is_alive():
    # The child finished within the timeout; it must have exited cleanly.
    self.assertEqual(worker.exitcode, 0)
    return

  # Still running after the timeout: clean up the child and report a hang.
  worker.terminate()
  worker.join()
  self.fail(
      "reset_state() deadlocked/hung waiting for active threads/tasks to finish"
  )


if __name__ == '__main__':
unittest.main()
20 changes: 17 additions & 3 deletions sdks/python/apache_beam/transforms/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ def _approx_quantile_generator(size, num_of_quantiles, absoluteError):
quantiles.append(size - 1)
return quantiles

@staticmethod
def _sum_and_second(x):
  """Returns ``(sum(x), x[1])`` for use as a sort key.

  The second component lets elements with equal sums be ordered by ``x[1]``.
  """
  total = sum(x)
  tie_breaker = x[1]
  return total, tie_breaker

def test_quantiles_globaly(self):
with TestPipeline() as p:
pc = p | Create(list(range(101)))
Expand Down Expand Up @@ -490,22 +494,32 @@ def test_batched_quantiles(self):
3, input_batched=True))
with_key = (
pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally(
3, key=sum, input_batched=True))
3,
key=ApproximateQuantilesTest._sum_and_second,
input_batched=True))
key_with_reversed = (
pc | 'Globally with key and reversed' >>
beam.ApproximateQuantiles.Globally(
3, key=sum, reverse=True, input_batched=True))
3,
key=ApproximateQuantilesTest._sum_and_second,
reverse=True,
input_batched=True))
assert_that(
globally,
equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]),
label='checkGlobally')
# When key is present, both (72.5, 225) and (22.5, 275) produce the exact same
# sum (297.5). If we just use key=sum, tie-breaking is sensitive to bundle merging
# order and shared class-level jitter state, leading to flaky test failures.
# With the secondary key (defined in _sum_and_second), we can break ties
# deterministically.
assert_that(
with_key,
equal_to([[(50.0, 0), (72.5, 225), (99.9, 499)]]),
label='checkGloballyWithKey')
assert_that(
key_with_reversed,
equal_to([[(99.9, 499), (72.5, 225), (50.0, 0)]]),
equal_to([[(99.9, 499), (22.5, 275), (50.0, 0)]]),
label='checkGloballyWithKeyAndReversed')

def test_batched_weighted_quantiles(self):
Expand Down
14 changes: 10 additions & 4 deletions sdks/python/apache_beam/typehints/pandas_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,11 @@ def _get_series(self, batch: pd.DataFrame):
def produce_batch(self, elements):
batch = pd.DataFrame.from_records(elements, columns=self._columns)

for column, typehint in self._element_type._fields:
batch[column] = batch[column].astype(dtype_from_typehint(typehint))
dtypes = {
column: dtype_from_typehint(typehint)
for column, typehint in self._element_type._fields
}
batch = batch.astype(dtypes)

return batch

Expand All @@ -249,8 +252,11 @@ def produce_batch(self, elements):
# Note from_records has an index= parameter
batch = pd.DataFrame.from_records(elements, columns=self._columns)

for column, typehint in self._element_type._fields:
batch[column] = batch[column].astype(dtype_from_typehint(typehint))
dtypes = {
column: dtype_from_typehint(typehint)
for column, typehint in self._element_type._fields
}
batch = batch.astype(dtypes)

return batch.set_index(self._index_columns)

Expand Down
Loading