Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions sdks/python/apache_beam/runners/worker/data_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils.byte_limited_queue import ByteLimitedQueue

if TYPE_CHECKING:
import apache_beam.coders.slow_stream
Expand Down Expand Up @@ -455,10 +456,20 @@ class _GrpcDataChannel(DataChannel):

def __init__(self, data_buffer_time_limit_ms=0):
# type: (int) -> None
def _element_weight(element):
if isinstance(element, beam_fn_api_pb2.Elements.Data):
return len(element.data)
elif isinstance(element, beam_fn_api_pb2.Elements.Timers):
return len(element.timers)
return 0

self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers]
self._to_send = ByteLimitedQueue(
maxsize=10000, maxweight=100 << 20,
weighing_fn=_element_weight) # type: queue.Queue[DataOrTimers]
self._received = collections.defaultdict(
lambda: queue.Queue(maxsize=5)
lambda: ByteLimitedQueue(
maxsize=5, maxweight=100 << 20, weighing_fn=_element_weight)
) # type: DefaultDict[str, queue.Queue[DataOrTimers]]
Comment thread
scwhittle marked this conversation as resolved.
Outdated

# Keep a cache of completed instructions. Data for completed instructions
Expand Down
95 changes: 95 additions & 0 deletions sdks/python/apache_beam/utils/byte_limited_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A thread-safe queue that limits capacity by total byte size."""

import queue
import time
from typing import Any
from typing import Callable


class ByteLimitedQueue(queue.Queue):
  """A queue.Queue that limits by both element count and total weight.

  A single element is allowed to exceed the maxweight to avoid deadlock: an
  item is always admitted into an empty queue, whatever its weight.

  Blocked producers are woken one at a time as items are removed, and the
  producer that wakes up is not necessarily one whose item fits. An item is
  therefore not guaranteed to be admitted as soon as enough space opens up.
  This is fine as long as the queue is continuously being drained.
  """
  def __init__(
      self,
      weighing_fn,  # type: Callable[[Any], int]
      maxsize=0,  # type: int
      maxweight=0,  # type: int
  ):
    # type: (...) -> None

    """Initializes a ByteLimitedQueue.

    Args:
      weighing_fn: A Callable that accepts an item and returns its integer
        weight. It is called exactly once per item, when the item is put.
      maxsize: The maximum number of items allowed in the queue. If 0 or
        negative, there is no limit on the number of elements.
      maxweight: The maximum accumulated weight allowed in the queue. If 0 or
        negative, there is no limit on the accumulated weight.
    """
    # The base class element limit is disabled; both limits are enforced by
    # _is_full() so that they can be evaluated together in put().
    super().__init__(maxsize=0)
    self.max_elements = maxsize
    self.max_weight = maxweight
    self.weighing_fn = weighing_fn
    self._byte_size = 0

  def _is_full(self, item_size):
    """Returns True if an item of weight item_size must wait for space.

    Must be called with the queue mutex held.
    """
    # An empty queue always accepts the next item, even an oversized one.
    if self._qsize() == 0:
      return False
    if self.max_elements > 0 and self._qsize() >= self.max_elements:
      return True
    if self.max_weight > 0 and self._byte_size + item_size > self.max_weight:
      return True
    return False

  def _raise_if_shut_down(self):
    """Raises queue.ShutDown if the queue has been shut down.

    Queue.shutdown() only exists on Python 3.13+; on older versions the
    is_shutdown attribute is absent and this is a no-op.
    """
    if getattr(self, 'is_shutdown', False):
      raise queue.ShutDown

  def put(self, item, block=True, timeout=None):
    """Puts an item into the queue, honoring both count and weight limits.

    Args:
      item: The element to enqueue.
      block: If False, raise queue.Full immediately when the item does not
        fit instead of waiting.
      timeout: When blocking, the maximum number of seconds to wait for
        space. None waits indefinitely.

    Raises:
      queue.Full: If the item does not fit and block is False, or the timeout
        expires before space becomes available.
      ValueError: If block is True and timeout is negative.
      queue.ShutDown: If the queue has been shut down (Python 3.13+).
    """
    # Weigh once, outside the lock. Every item counts for at least 1 so that
    # zero-weight items still make progress against max_weight.
    item_size = max(1, self.weighing_fn(item))
    with self.not_full:
      self._raise_if_shut_down()
      if not block:
        if self._is_full(item_size):
          raise queue.Full
      elif timeout is None:
        while self._is_full(item_size):
          self.not_full.wait()
          # shutdown() wakes all waiters; without this check a producer
          # blocked on a full queue would go back to waiting forever.
          self._raise_if_shut_down()
      elif timeout < 0:
        raise ValueError("'timeout' must be a non-negative number")
      else:
        # Use a monotonic clock so wall-clock adjustments cannot shorten or
        # extend the wait.
        deadline = time.monotonic() + timeout
        while self._is_full(item_size):
          remaining = deadline - time.monotonic()
          if remaining <= 0.0:
            raise queue.Full
          self.not_full.wait(remaining)
          self._raise_if_shut_down()

      # Store the weight alongside the item so _get() subtracts exactly what
      # was added, even if weighing_fn is not deterministic.
      self._put((item, item_size))
      self._byte_size += item_size
      self.unfinished_tasks += 1
      self.not_empty.notify()

  def _get(self):
    """Removes and returns the next item, releasing its recorded weight.

    Called by the base class with the queue mutex held.
    """
    item, item_weight = super()._get()
    self._byte_size -= item_weight
    return item

  def byte_size(self):
    """Return the total byte weight of elements in the queue."""
    with self.mutex:
      return self._byte_size
168 changes: 168 additions & 0 deletions sdks/python/apache_beam/utils/byte_limited_queue_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for byte-limited queue."""

import queue
import sys
import threading
import time
import unittest

from apache_beam.utils.byte_limited_queue import ByteLimitedQueue


class FakeItem:
  """Test element that reports a fixed weight chosen at construction."""
  def __init__(self, size):
    # Kept private; tests read it back only through weight().
    self._size = size

  def weight(self):
    """Return the weight this item was constructed with."""
    return self._size


class ByteLimitedQueueTest(unittest.TestCase):
  """Tests ByteLimitedQueue count limits, weight limits and blocking."""
  def test_unbounded(self):
    bq = ByteLimitedQueue(lambda x: x.weight())
    for i in range(200):
      bq.put(FakeItem(i))
    # Add 1 since weight of zero is set to 1
    self.assertEqual(bq.byte_size(), sum(range(200)) + 1)
    self.assertEqual(bq.qsize(), 200)

  def test_put_and_get(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=200)
    bq.put(FakeItem(50))
    bq.put(FakeItem(140))
    self.assertEqual(bq.byte_size(), 190)
    self.assertEqual(bq.qsize(), 2)
    # Putting another 20 would exceed 200.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(20), block=False)
    # An item that lands exactly on the limit is still accepted.
    bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.byte_size(), 200)
    self.assertEqual(bq.qsize(), 3)

    self.assertEqual(bq.get().weight(), 50)
    self.assertEqual(bq.byte_size(), 150)
    self.assertEqual(bq.qsize(), 2)
    bq.put(FakeItem(20), block=False)

  def test_dual_limit(self):
    # Queue limits: at most 3 items, OR at most 100 weight.
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=3, maxweight=100)
    bq.put(FakeItem(30))
    bq.put(FakeItem(40))
    bq.put(FakeItem(20))
    self.assertEqual(bq.byte_size(), 90)
    self.assertEqual(bq.qsize(), 3)
    # Full on element count (maxsize=3), even though the weight would fit.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.get().weight(), 30)
    self.assertEqual(bq.get().weight(), 40)
    bq.put(FakeItem(10))
    # Full on byte count
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(90), block=False)
    self.assertEqual(bq.get().weight(), 20)
    bq.put(FakeItem(90), block=False)

  @unittest.skipIf(sys.version_info < (3, 13), 'Queue.ShutDown added in 3.13.')
  def test_multithreading(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=100)
    received = []

    def producer():
      for i in range(101):
        bq.put(FakeItem(i))

    def consumer():
      # Drain until the queue is shut down and empty.
      while True:
        try:
          received.append(bq.get().weight())
        except queue.ShutDown:
          break

    t1 = threading.Thread(target=producer)
    t2 = threading.Thread(target=producer)
    t3 = threading.Thread(target=consumer)

    t1.start()
    t2.start()
    t3.start()

    t1.join()
    t2.join()
    # Both producers are done, so shutting down lets the consumer exit once
    # the remaining items have been drained.
    bq.shutdown()

    t3.join()

    self.assertEqual(len(received), 202)
    self.assertEqual(sum(received), 2 * sum(range(101)))

  def test_multithreading_timeout(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=10)
    bq.put(FakeItem(10))

    # The queue is completely full. A timeout put should raise queue.Full.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(5), timeout=0.01)

    def delayed_consumer():
      time.sleep(0.05)
      bq.get()

    # Start a thread that will free up space after 50ms.
    t = threading.Thread(target=delayed_consumer)
    t.start()

    # The put should succeed once the consumer runs; use a high timeout to
    # avoid flakiness.
    bq.put(FakeItem(5), timeout=60)
    t.join()

  def test_get_timeout(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10)
    # Getting from an empty queue raises queue.Empty, with or without a
    # timeout.
    with self.assertRaises(queue.Empty):
      bq.get(timeout=0.01)
    with self.assertRaises(queue.Empty):
      bq.get(block=False)

    # Once an item is available a timed get returns it and releases its
    # weight.
    bq.put(FakeItem(5))
    self.assertEqual(bq.get(timeout=60).weight(), 5)
    self.assertEqual(bq.byte_size(), 0)

  def test_negative_timeout(self):
    bq = ByteLimitedQueue(lambda x: x.weight())
    # Putting an item with a negative timeout should raise ValueError.
    with self.assertRaises(ValueError):
      bq.put(FakeItem(5), timeout=-1)

  def test_single_element_override(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10)
    # An item of size 50 exceeds maxweight 10, but should be admitted
    # immediately without blocking since the queue is currently empty!
    bq.put(FakeItem(50), block=False)
    self.assertEqual(bq.qsize(), 1)
    self.assertEqual(bq.byte_size(), 50)

  def test_inconsistent_weighing_fn(self):
    # Return a different weight for the same item.
    weights = [10, 5]
    bq = ByteLimitedQueue(lambda x: weights.pop(0), maxweight=100)

    bq.put(1)
    self.assertEqual(bq.byte_size(), 10)

    # Upon popping, the weighing function (if called) would have returned 5,
    # but the stored weight prevents corruption and cleanly reduces the size to
    # 0.
    bq.get()
    self.assertEqual(bq.byte_size(), 0)


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could also have a test for get with a timeout

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call, done

if __name__ == '__main__':
unittest.main()
Loading