diff --git a/Cargo.lock b/Cargo.lock
index 012573deb452d..331683ebbe15c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2090,6 +2090,7 @@ dependencies = [
  "parquet",
  "rand 0.9.4",
  "tempfile",
+ "tokio",
  "url",
 ]
diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml
index 06c84d8acb493..482c9fff17a1b 100644
--- a/datafusion/execution/Cargo.toml
+++ b/datafusion/execution/Cargo.toml
@@ -66,6 +66,7 @@ parking_lot = { workspace = true }
 parquet = { workspace = true, optional = true }
 rand = { workspace = true }
 tempfile = { workspace = true }
+tokio = { workspace = true, features = ["time"] }
 url = { workspace = true }
 
 [dev-dependencies]
diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index 829e313d2381e..b04fd95d410a1 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -20,10 +20,13 @@
 use datafusion_common::{Result, internal_datafusion_err};
 use std::fmt::Display;
+use std::future::Future;
 use std::hash::{Hash, Hasher};
+use std::pin::Pin;
 use std::{cmp::Ordering, sync::Arc, sync::atomic};
 
 mod pool;
+mod reclaimer;
 
 #[cfg(feature = "arrow_buffer_pool")]
 pub mod arrow;
@@ -36,6 +39,7 @@ pub use datafusion_common::{
     human_readable_count, human_readable_duration, human_readable_size, units,
 };
 pub use pool::*;
+pub use reclaimer::{MemoryReclaimer, reclaimer_state};
 
 /// Tracks and potentially limits memory use across operators during execution.
 ///
@@ -209,6 +213,17 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug + Display {
     /// On error the `allocation` will not be increased in size
     fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>;
 
+    /// Async variant of [`Self::try_grow`]. Default delegates to the
+    /// sync version; reclaim-aware pools (e.g. [`TrackConsumersPool`])
+    /// override to invoke registered [`MemoryReclaimer`]s on OOM.
+    fn try_grow_async<'a>(
+        &'a self,
+        reservation: &'a MemoryReservation,
+        additional: usize,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
+        Box::pin(async move { self.try_grow(reservation, additional) })
+    }
+
     /// Return the total amount of memory reserved
     fn reserved(&self) -> usize;
 
@@ -249,6 +264,9 @@ pub struct MemoryConsumer {
     name: String,
     can_spill: bool,
     id: usize,
+    /// Reclaimer collected by reclaim-aware pools at register time. Not
+    /// part of consumer identity (excluded from `Eq`/`Hash`).
+    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
 }
 
 impl PartialEq for MemoryConsumer {
@@ -287,20 +305,39 @@ impl MemoryConsumer {
             name: name.into(),
             can_spill: false,
             id: Self::new_unique_id(),
+            reclaimer: None,
         }
     }
 
-    /// Returns a clone of this [`MemoryConsumer`] with a new unique id,
-    /// which can be registered with a [`MemoryPool`],
-    /// This new consumer is separate from the original.
+    /// Clone this [`MemoryConsumer`] with a new unique id.
+    ///
+    /// Drops any attached reclaimer: it is bound to the original operator's
+    /// state and would target the wrong owner under a new id (and bypass
+    /// the id-keyed requestor-self-skip in `try_grow_async`).
    pub fn clone_with_new_id(&self) -> Self {
        Self {
            name: self.name.clone(),
            can_spill: self.can_spill,
            id: Self::new_unique_id(),
+           reclaimer: None,
+       }
+   }
+
+   /// Attach a [`MemoryReclaimer`] and mark this consumer spill-capable.
+   /// Pools without reclaim support ignore the reclaimer.
+   pub fn with_reclaimer(self, reclaimer: Arc<dyn MemoryReclaimer>) -> Self {
+       Self {
+           can_spill: true,
+           reclaimer: Some(reclaimer),
+           ..self
        }
    }
 
+    /// Returns the attached [`MemoryReclaimer`], if any.
+    pub fn reclaimer(&self) -> Option<&Arc<dyn MemoryReclaimer>> {
+        self.reclaimer.as_ref()
+    }
+
     /// Return the unique id of this [`MemoryConsumer`]
     pub fn id(&self) -> usize {
         self.id
@@ -461,6 +498,17 @@ impl MemoryReservation {
         Ok(())
     }
 
+    /// Async variant of [`Self::try_grow`]. On a reclaim-aware pool,
+    /// triggers registered [`MemoryReclaimer`]s before surfacing OOM.
+    pub async fn try_grow_async(&self, capacity: usize) -> Result<()> {
+        self.registration
+            .pool
+            .try_grow_async(self, capacity)
+            .await?;
+        self.size.fetch_add(capacity, atomic::Ordering::Relaxed);
+        Ok(())
+    }
+
     /// Splits off `capacity` bytes from this [`MemoryReservation`]
     /// into a new [`MemoryReservation`] with the same
     /// [`MemoryConsumer`].
diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs
index aac95b9d6a81f..b55fa9bb14291 100644
--- a/datafusion/execution/src/memory_pool/pool.rs
+++ b/datafusion/execution/src/memory_pool/pool.rs
@@ -15,19 +15,37 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::memory_pool::reclaimer::reclaimer_state;
 use crate::memory_pool::{
-    MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size,
+    MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation,
+    human_readable_size,
 };
 use datafusion_common::HashMap;
 use datafusion_common::{DataFusionError, Result, resources_datafusion_err};
 use log::debug;
-use parking_lot::Mutex;
+use parking_lot::{Mutex, RwLock};
 use std::fmt::{Display, Formatter};
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::time::Duration;
 use std::{
     num::NonZeroUsize,
-    sync::atomic::{AtomicUsize, Ordering},
+    sync::atomic::{AtomicU8, AtomicUsize, Ordering},
 };
 
+/// How long [`TrackConsumersPool::try_grow_async`] waits for an
+/// in-flight sibling to finish reclaiming before retrying. Kept short
+/// so we don't stall the requestor longer than the typical reclaim
+/// (mpsc send + spill commit).
+const RECLAIM_RETRY_SLEEP: Duration = Duration::from_millis(50);
+
+/// Maximum number of times [`TrackConsumersPool::try_grow_async`]
+/// retries the candidate walk while siblings are still in-flight.
+/// Bounds the total wait at `MAX_RECLAIM_RETRIES * RECLAIM_RETRY_SLEEP`
+/// so a livelock surfaces as OOM rather than a hang.
+const MAX_RECLAIM_RETRIES: usize = 3;
+
 /// A [`MemoryPool`] that enforces no limit
 #[derive(Debug, Default)]
 pub struct UnboundedMemoryPool {
@@ -324,6 +342,51 @@ struct TrackedConsumer {
     can_spill: bool,
     reserved: AtomicUsize,
     peak: AtomicUsize,
+    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
+    /// Tri-state eligibility flag for [`reclaimer`], encoded per
+    /// [`reclaimer_state`]. The pool flips `AVAILABLE` ↔ `IN_FLIGHT`
+    /// for dedup; the reclaimer's owner may sticky-set `DISABLED` once
+    /// it can no longer free memory. Shared `Arc` so the reclaimer
+    /// side and the pool see the same cell. `None` reclaimer ⇒ flag
+    /// is unused but still allocated.
+    reclaimer_state: Arc<AtomicU8>,
+}
+
+/// RAII guard for the [`IN_FLIGHT`] slot of a [`TrackedConsumer`]'s
+/// `reclaimer_state` flag. `Drop` only restores `AVAILABLE` if the
+/// state is still `IN_FLIGHT` — leaves a sticky `DISABLED` alone.
+///
+/// [`IN_FLIGHT`]: reclaimer_state::IN_FLIGHT
+struct ReclaimerStateGuard {
+    flag: Arc<AtomicU8>,
+}
+
+impl Drop for ReclaimerStateGuard {
+    fn drop(&mut self) {
+        let _ = self.flag.compare_exchange(
+            reclaimer_state::IN_FLIGHT,
+            reclaimer_state::AVAILABLE,
+            Ordering::AcqRel,
+            Ordering::Relaxed,
+        );
+    }
+}
+
+impl ReclaimerStateGuard {
+    /// Try to transition the flag from `AVAILABLE` to `IN_FLIGHT`.
+    /// Fails on contention or on a sticky `DISABLED`.
+    fn try_acquire(flag: &Arc<AtomicU8>) -> Option<Self> {
+        flag.compare_exchange(
+            reclaimer_state::AVAILABLE,
+            reclaimer_state::IN_FLIGHT,
+            Ordering::AcqRel,
+            Ordering::Relaxed,
+        )
+        .ok()
+        .map(|_| Self {
+            flag: Arc::clone(flag),
+        })
+    }
 }
 
 impl TrackedConsumer {
@@ -339,9 +402,29 @@ impl TrackedConsumer {
     /// Grows the tracked consumer's reserved size,
     /// should be called after the pool has successfully performed the grow().
+    ///
+    /// Uses the value that `reserved` holds immediately after this
+    /// thread's `fetch_add` as the peak candidate, then bumps `peak` via a
+    /// monotone-max CAS loop. This avoids the race in the previous
+    /// `peak.fetch_max(self.reserved())` form, where a concurrent `shrink`
+    /// between the load of `reserved` and the max-write to `peak` could
+    /// record a peak below the true high-water mark.
     fn grow(&self, additional: usize) {
-        self.reserved.fetch_add(additional, Ordering::Relaxed);
-        self.peak.fetch_max(self.reserved(), Ordering::Relaxed);
+        let prev = self.reserved.fetch_add(additional, Ordering::Relaxed);
+        let new = prev + additional;
+
+        let mut peak = self.peak.load(Ordering::Relaxed);
+        while peak < new {
+            match self.peak.compare_exchange_weak(
+                peak,
+                new,
+                Ordering::Relaxed,
+                Ordering::Relaxed,
+            ) {
+                Ok(_) => break,
+                Err(actual) => peak = actual,
+            }
+        }
     }
 
     /// Reduce the tracked consumer's reserved size,
@@ -407,8 +490,19 @@ pub struct TrackConsumersPool<I: MemoryPool> {
     inner: I,
     /// The amount of consumers to report(ordered top to bottom by reservation size)
     top: NonZeroUsize,
-    /// Maps consumer_id --> TrackedConsumer
-    tracked_consumers: Mutex<HashMap<usize, TrackedConsumer>>,
+    /// Cap on the number of reclaim candidates considered per
+    /// [`try_grow_async`] call. Bounds reclaim work when many consumers
+    /// are registered. Defaults to 4; override with
+    /// [`Self::with_reclaim_candidate_limit`].
+    reclaim_candidate_limit: NonZeroUsize,
+    /// Maps consumer_id --> TrackedConsumer.
+    ///
+    /// Protected by an [`RwLock`] rather than a [`Mutex`]: registration
+    /// (insert) and unregistration (remove) take the write lock; grow,
+    /// shrink, try_grow, metrics, and report_top take the read lock and run
+    /// concurrently. The per-consumer [`AtomicUsize`] fields are mutated
+    /// under the shared read lock — see [`TrackedConsumer::grow`].
+    tracked_consumers: RwLock<HashMap<usize, TrackedConsumer>>,
 }
 
 impl<I: MemoryPool> Display for TrackConsumersPool<I> {
@@ -464,10 +558,18 @@ impl<I: MemoryPool> TrackConsumersPool<I> {
         Self {
             inner,
             top,
+            reclaim_candidate_limit: NonZeroUsize::new(4).unwrap(),
             tracked_consumers: Default::default(),
         }
     }
 
+    /// Override the cap on reclaim candidates considered per
+    /// [`try_grow_async`] call (default `4`).
+    pub fn with_reclaim_candidate_limit(mut self, n: NonZeroUsize) -> Self {
+        self.reclaim_candidate_limit = n;
+        self
+    }
+
     /// Returns a reference to the wrapped inner [`MemoryPool`].
     pub fn inner(&self) -> &I {
         &self.inner
@@ -476,7 +578,7 @@ impl<I: MemoryPool> TrackConsumersPool<I> {
     /// Returns a snapshot of all currently tracked consumers.
     pub fn metrics(&self) -> Vec {
         self.tracked_consumers
-            .lock()
+            .read()
             .values()
             .map(Into::into)
             .collect()
     }
@@ -486,7 +588,7 @@
     pub fn report_top(&self, top: usize) -> String {
         let mut consumers = self
             .tracked_consumers
-            .lock()
+            .read()
             .iter()
             .map(|(consumer_id, tracked_consumer)| {
                 (
@@ -525,7 +627,17 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
     fn register(&self, consumer: &MemoryConsumer) {
         self.inner.register(consumer);
 
-        let mut guard = self.tracked_consumers.lock();
+        let reclaimer = consumer.reclaimer().cloned();
+        // Reuse the reclaimer's own flag when it provides one — that
+        // way the reclaimer side can sticky-set `DISABLED` and the
+        // pool sees it on the next filter pass. Otherwise allocate a
+        // fresh `AVAILABLE` flag for in-flight dedup only.
+        let state = reclaimer
+            .as_ref()
+            .and_then(|r| r.reclaimer_state())
+            .unwrap_or_else(|| Arc::new(AtomicU8::new(reclaimer_state::AVAILABLE)));
+
+        let mut guard = self.tracked_consumers.write();
         let existing = guard.insert(
             consumer.id(),
             TrackedConsumer {
@@ -533,6 +645,8 @@
                 can_spill: consumer.can_spill(),
                 reserved: Default::default(),
                 peak: Default::default(),
+                reclaimer,
+                reclaimer_state: state,
             },
         );
 
@@ -544,27 +658,29 @@
     fn unregister(&self, consumer: &MemoryConsumer) {
         self.inner.unregister(consumer);
-        self.tracked_consumers.lock().remove(&consumer.id());
+        self.tracked_consumers.write().remove(&consumer.id());
     }
 
     fn grow(&self, reservation: &MemoryReservation, additional: usize) {
         self.inner.grow(reservation, additional);
-        self.tracked_consumers
-            .lock()
-            .entry(reservation.consumer().id())
-            .and_modify(|tracked_consumer| {
-                tracked_consumer.grow(additional);
-            });
+        if let Some(tracked) = self
+            .tracked_consumers
+            .read()
+            .get(&reservation.consumer().id())
+        {
+            tracked.grow(additional);
+        }
     }
 
     fn shrink(&self, reservation: &MemoryReservation, shrink: usize) {
         self.inner.shrink(reservation, shrink);
-        self.tracked_consumers
-            .lock()
-            .entry(reservation.consumer().id())
-            .and_modify(|tracked_consumer| {
-                tracked_consumer.shrink(shrink);
-            });
+        if let Some(tracked) = self
+            .tracked_consumers
+            .read()
+            .get(&reservation.consumer().id())
+        {
+            tracked.shrink(shrink);
+        }
     }
 
     fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> {
@@ -584,15 +700,185 @@
                 _ => e,
             })?;
 
-        self.tracked_consumers
-            .lock()
-            .entry(reservation.consumer().id())
-            .and_modify(|tracked_consumer| {
-                tracked_consumer.grow(additional);
-            });
+        if let Some(tracked) = self
+            .tracked_consumers
+            .read()
+            .get(&reservation.consumer().id())
+        {
+            tracked.grow(additional);
+        }
         Ok(())
     }
 
+    fn try_grow_async<'a>(
+        &'a self,
+        reservation: &'a MemoryReservation,
+        additional: usize,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
+        Box::pin(async move {
+            // Fast path.
+            let initial_err = match self.try_grow(reservation, additional) {
+                Ok(()) => return Ok(()),
+                Err(e) => e,
+            };
+
+            // We deliberately do NOT mark the requestor as IN_FLIGHT
+            // here. Doing so would cause N walkers running concurrently
+            // (e.g. many sort partitions all hitting `try_grow_async`
+            // in lock-step) to all set themselves IN_FLIGHT, at which
+            // point the candidate filter rejects every one of them and
+            // no walk has any victims — leading to spurious OOM after
+            // the retry budget elapses.
+            // Cycle defenses are still in place: the candidate filter
+            // rejects the requestor by id; `strictly-larger` (for
+            // reclaimer requestors) breaks size-ordered A↔B
+            // mutual-reclaim cycles; and a caller that drains its own
+            // `reclaim_rx` while awaiting `try_grow_async` (e.g. via a
+            // `select!`) cooperatively services any recursive reclaim
+            // that targets it.
+            let requestor_id = reservation.consumer().id();
+            let requestor_has_reclaimer = self
+                .tracked_consumers
+                .read()
+                .get(&requestor_id)
+                .map(|tc| tc.reclaimer.is_some())
+                .unwrap_or(false);
+
+            let mut retries: usize = 0;
+            loop {
+                // Snapshot reclaimers. When the requestor has its own reclaimer,
+                // only consumers strictly larger than the requestor are
+                // eligible: smaller-or-equal siblings would free less than the
+                // requestor itself can, so the requestor should self-spill
+                // instead. This size-ordering also breaks mutual-reclaim
+                // cycles among reclaimable consumers (A targets B while B
+                // targets A) — at most one side of any pair can hold strictly
+                // more memory, so the other side has no candidates and
+                // surfaces an error for the caller's self-spill fallback.
+                //
+                // When the requestor has no reclaimer it cannot self-spill,
+                // so this filter is skipped — any reclaimable sibling is a
+                // valid victim. No cycle is possible through a no-reclaimer
+                // requestor: it can never be selected as a candidate (the
+                // `reclaimer.as_ref()?` guard below rejects it), so no other
+                // walk can target it.
+                //
+                // Filter out anyone whose `reclaimer_state` flag is not
+                // `AVAILABLE` (in-flight or sticky-disabled). Also count
+                // IN_FLIGHT siblings so we know whether to wait briefly for
+                // them to finish before giving up. Drop the read guard
+                // before awaiting any reclaim.
+                let requestor_reserved = {
+                    let guard = self.tracked_consumers.read();
+                    guard
+                        .get(&requestor_id)
+                        .map(|tc| tc.reserved())
+                        .unwrap_or(0)
+                };
+                let mut in_flight_seen: usize = 0;
+                let mut candidates: Vec<(
+                    usize,
+                    Arc<dyn MemoryReclaimer>,
+                    Arc<AtomicU8>,
+                )> = {
+                    let guard = self.tracked_consumers.read();
+                    guard
+                        .iter()
+                        .filter_map(|(cid, tc)| {
+                            if *cid == requestor_id {
+                                return None;
+                            }
+                            // Track in-flight siblings (any size) so we can
+                            // decide whether a retry has any chance of helping.
+                            let state = tc.reclaimer_state.load(Ordering::Acquire);
+                            if state == reclaimer_state::IN_FLIGHT {
+                                in_flight_seen += 1;
+                            }
+                            let reclaimer = tc.reclaimer.as_ref()?;
+                            if requestor_has_reclaimer
+                                && tc.reserved() <= requestor_reserved
+                            {
+                                return None;
+                            }
+                            if state != reclaimer_state::AVAILABLE {
+                                return None;
+                            }
+                            Some((
+                                tc.reserved(),
+                                Arc::clone(reclaimer),
+                                Arc::clone(&tc.reclaimer_state),
+                            ))
+                        })
+                        .collect()
+                };
+                // Order: priority desc, then reservation size desc.
+                candidates.sort_by(|(lr, l, _), (rr, r, _)| {
+                    r.priority().cmp(&l.priority()).then_with(|| rr.cmp(lr))
+                });
+                // Cap reclaim work — only consider the top-ranked candidates.
+                candidates.truncate(self.reclaim_candidate_limit.get());
+
+                // For each candidate: try to claim its in-flight slot
+                // (skip on contention or sticky-disabled so we work on a
+                // different victim rather than serializing behind a
+                // sibling's reclaim); re-check `try_grow` before reclaiming
+                // in case a sibling already freed enough; reclaim; retry
+                // `try_grow`. The retry path goes through `self.try_grow`,
+                // which already updates the tracked consumer's atomic
+                // reservation — no manual accounting needed here.
+                for (_, reclaimer, flag) in candidates {
+                    let _g = match ReclaimerStateGuard::try_acquire(&flag) {
+                        Some(g) => g,
+                        None => continue,
+                    };
+                    if self.try_grow(reservation, additional).is_ok() {
+                        return Ok(());
+                    }
+                    if let Err(e) = reclaimer.reclaim(additional).await {
+                        debug!("memory reclaimer returned error: {e}");
+                        continue;
+                    }
+                    if self.try_grow(reservation, additional).is_ok() {
+                        return Ok(());
+                    }
+                }
+
+                // Walk produced nothing usable. If other consumers are
+                // currently reclaiming for someone else, their freed bytes
+                // may land in the pool shortly — wait briefly and retry
+                // before falling through to OOM. Bounded so we don't stall
+                // forever on a livelock.
+                if in_flight_seen > 0 && retries < MAX_RECLAIM_RETRIES {
+                    retries += 1;
+                    tokio::time::sleep(RECLAIM_RETRY_SLEEP).await;
+                    // Quick fast-path retry: an in-flight sibling may have
+                    // freed bytes during the sleep.
+                    if self.try_grow(reservation, additional).is_ok() {
+                        return Ok(());
+                    }
+                    continue;
+                }
+                break;
+            }
+
+            // Fall through to the inner pool's own reclaim path, if any.
+            // The default impl just re-runs `inner.try_grow`, which
+            // bypasses `TrackConsumersPool::try_grow`, so the
+            // consumer-side update is still required.
+            self.inner
+                .try_grow_async(reservation, additional)
+                .await
+                .map_err(|_| initial_err)?;
+            if let Some(tracked) = self
+                .tracked_consumers
+                .read()
+                .get(&reservation.consumer().id())
+            {
+                tracked.grow(additional);
+            }
+            Ok(())
+        })
+    }
+
     fn reserved(&self) -> usize {
         self.inner.reserved()
     }
@@ -1046,4 +1332,126 @@ mod tests {
             "TrackConsumersPool Display"
         );
     }
+
+    /// N threads each call `grow(STEP)` then `shrink(STEP)` once on the same
+    /// consumer. Final `reserved` must be 0; `peak` must be at least `STEP`
+    /// and at most `THREADS * STEP` — validates that `fetch_add` on `reserved`
+    /// is correct under concurrent readers of the `RwLock`-protected map.
+    #[test]
+    fn test_tracked_consumer_concurrent_grow() {
+        const THREADS: usize = 16;
+        const STEP: usize = 7;
+
+        let tracked = Arc::new(TrackConsumersPool::new(
+            UnboundedMemoryPool::default(),
+            NonZeroUsize::new(5).unwrap(),
+        ));
+        let tracked_clone = Arc::clone(&tracked);
+        let pool: Arc<dyn MemoryPool> = tracked_clone;
+        let r = Arc::new(MemoryConsumer::new("c").register(&pool));
+
+        std::thread::scope(|s| {
+            for _ in 0..THREADS {
+                let r = Arc::clone(&r);
+                s.spawn(move || {
+                    let local = r.new_empty();
+                    local.grow(STEP);
+                    local.shrink(STEP);
+                });
+            }
+        });
+
+        let metrics = tracked.metrics();
+        let entry = metrics.iter().find(|m| m.name == "c").unwrap();
+        assert_eq!(entry.reserved, 0);
+        assert!(entry.peak >= STEP);
+        assert!(entry.peak <= THREADS * STEP);
+    }
+
+    /// N threads run interleaved `grow`/`shrink` pairs on the same consumer.
+    /// Final `reserved` must be 0; `peak` must be at least `STEP` (any grow
+    /// records its own bump) and at most `THREADS * STEP`. Validates the
+    /// monotone-max CAS on `peak`, fixing today's `fetch_max(self.reserved())`
+    /// race where an intervening shrink could drop `reserved` below the value
+    /// used to bump `peak`.
+    #[test]
+    fn test_tracked_consumer_concurrent_peak_monotone() {
+        const THREADS: usize = 16;
+        const ITERS: usize = 10_000;
+        const STEP: usize = 3;
+
+        let tracked = Arc::new(TrackConsumersPool::new(
+            UnboundedMemoryPool::default(),
+            NonZeroUsize::new(5).unwrap(),
+        ));
+        let tracked_clone = Arc::clone(&tracked);
+        let pool: Arc<dyn MemoryPool> = tracked_clone;
+        let r = Arc::new(MemoryConsumer::new("c").register(&pool));
+
+        std::thread::scope(|s| {
+            for _ in 0..THREADS {
+                let r = Arc::clone(&r);
+                s.spawn(move || {
+                    let local = r.new_empty();
+                    for _ in 0..ITERS {
+                        local.grow(STEP);
+                        local.shrink(STEP);
+                    }
+                });
+            }
+        });
+
+        let entry = tracked
+            .metrics()
+            .into_iter()
+            .find(|m| m.name == "c")
+            .unwrap();
+        assert_eq!(entry.reserved, 0, "all grows undone by shrinks");
+        assert!(entry.peak >= STEP);
+        assert!(entry.peak <= THREADS * STEP);
+    }
+
+    /// One thread loops register/unregister, another loops grow/shrink on a
+    /// stable consumer. Verifies no panics or deadlocks across the `RwLock`
+    /// boundary, and that the stable consumer's accounting is preserved
+    /// when a writer briefly takes the exclusive lock.
+    #[test]
+    fn test_tracked_consumers_pool_register_grow_concurrent() {
+        const ITERS: usize = 1_000;
+
+        let tracked = Arc::new(TrackConsumersPool::new(
+            UnboundedMemoryPool::default(),
+            NonZeroUsize::new(5).unwrap(),
+        ));
+        let tracked_clone = Arc::clone(&tracked);
+        let pool: Arc<dyn MemoryPool> = tracked_clone;
+
+        let r = Arc::new(MemoryConsumer::new("stable").register(&pool));
+
+        std::thread::scope(|s| {
+            let pool_w = Arc::clone(&pool);
+            s.spawn(move || {
+                for i in 0..ITERS {
+                    let _churn =
+                        MemoryConsumer::new(format!("churn-{i}")).register(&pool_w);
+                }
+            });
+
+            let r_inner = Arc::clone(&r);
+            s.spawn(move || {
+                let local = r_inner.new_empty();
+                for _ in 0..ITERS {
+                    local.grow(5);
+                    local.shrink(5);
+                }
+            });
+        });
+
+        let metrics = tracked.metrics();
+        let stable = metrics.iter().find(|m| m.name == "stable").unwrap();
+        assert_eq!(stable.reserved, 0);
+        assert!(stable.peak >= 5);
+        assert!(metrics.iter().all(|m| !m.name.starts_with("churn-")));
+        drop(r);
+    }
 }
diff --git a/datafusion/execution/src/memory_pool/reclaimer.rs b/datafusion/execution/src/memory_pool/reclaimer.rs
new file mode 100644
index 0000000000000..135853bff5237
--- /dev/null
+++ b/datafusion/execution/src/memory_pool/reclaimer.rs
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Operator hook used by a [`MemoryPool`] to free memory before an
+//! allocation fails.
+//!
+//! [`MemoryPool`]: super::MemoryPool
+
+use datafusion_common::Result;
+use std::fmt::Debug;
+use std::sync::Arc;
+use std::sync::atomic::AtomicU8;
+
+/// Encoded values stored in the [`reclaimer_state`] tri-state.
+///
+/// [`reclaimer_state`]: MemoryReclaimer::reclaimer_state
+pub mod reclaimer_state {
+    /// Reclaimer is idle and may be selected as a victim.
+    pub const AVAILABLE: u8 = 0;
+    /// A pool task is currently driving `reclaim` on this reclaimer.
+    pub const IN_FLIGHT: u8 = 1;
+    /// Reclaimer has been retired (e.g. operator entered a phase where
+    /// it can no longer free memory). Sticky — never returns to
+    /// `AVAILABLE`.
+    pub const DISABLED: u8 = 2;
+}
+
+/// Hook attached to a [`MemoryConsumer`] via
+/// [`MemoryConsumer::with_reclaimer`]. On
+/// [`MemoryPool::try_grow_async`] failure the pool walks registered
+/// reclaimers in descending [`Self::priority`] and asks each to free bytes.
+///
+/// Implementations MUST:
+///
+/// - Spill data **before** shrinking the reservation, so reported bytes
+///   are recoverable.
+/// - Release bytes via [`MemoryReservation::shrink`] /
+///   [`MemoryReservation::free`] and return at most `target`.
+/// - Not call `try_grow*` on the pool — risks reentrancy/deadlock.
+/// - Not hold a strong `Arc` back to the owning operator or its
+///   `MemoryReservation` (creates a cycle that blocks `unregister`);
+///   use channels or `Weak`.
+///
+/// [`MemoryConsumer`]: super::MemoryConsumer
+/// [`MemoryConsumer::with_reclaimer`]: super::MemoryConsumer::with_reclaimer
+/// [`MemoryPool::try_grow_async`]: super::MemoryPool::try_grow_async
+/// [`MemoryReservation::shrink`]: super::MemoryReservation::shrink
+/// [`MemoryReservation::free`]: super::MemoryReservation::free
+#[async_trait::async_trait]
+pub trait MemoryReclaimer: Send + Sync + Debug {
+    /// Upper bound on bytes this reclaimer can free. `None` = unknown.
+    fn reclaimable_bytes(&self) -> Option<usize> {
+        None
+    }
+
+    /// Free up to `target` bytes; return the amount actually released.
+    /// See trait-level contract.
+    async fn reclaim(&self, target: usize) -> Result<usize>;
+
+    /// Higher priorities reclaim first. Negative = last resort.
+    fn priority(&self) -> i32 {
+        0
+    }
+
+    /// Optional shared tri-state flag controlling whether the pool
+    /// currently considers this reclaimer eligible. Values are defined
+    /// in [`reclaimer_state`]. Returning `Some(arc)` lets the
+    /// reclaimer's owner flip itself to `DISABLED` once it can no
+    /// longer free memory (e.g., on entering a merge phase), which
+    /// the pool observes immediately. Returning `None` lets the pool
+    /// allocate its own private flag — used only for in-flight dedup.
+    fn reclaimer_state(&self) -> Option<Arc<AtomicU8>> {
+        None
+    }
+}
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 6c02af8dec6d3..4ece6d7dca119 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -65,7 +65,9 @@ use datafusion_common::{
     unwrap_or_internal_err,
 };
 use datafusion_execution::TaskContext;
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::memory_pool::{
+    MemoryConsumer, MemoryReclaimer, MemoryReservation, reclaimer_state,
+};
 use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_physical_expr::LexOrdering;
 use datafusion_physical_expr::PhysicalExpr;
@@ -74,6 +76,34 @@ use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
 use futures::{StreamExt, TryStreamExt};
 use log::{debug, trace};
 
+/// Reclaimer for an [`ExternalSorter`] partition. Hands a oneshot off to
+/// the partition's stream loop (the sorter's sole owner), which spills and
+/// replies with the freed byte count.
+#[derive(Debug)]
+struct ExternalSorterReclaimer {
+    tx: tokio::sync::mpsc::Sender<tokio::sync::oneshot::Sender<usize>>,
+    /// Shared with the pool's `TrackedConsumer` entry. Stream loop
+    /// flips it to `DISABLED` on merge entry so the pool stops
+    /// targeting this consumer.
+    reclaimer_state: Arc<std::sync::atomic::AtomicU8>,
+}
+
+#[async_trait::async_trait]
+impl MemoryReclaimer for ExternalSorterReclaimer {
+    async fn reclaim(&self, _target: usize) -> Result<usize> {
+        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
+        // Stream loop terminated, or response dropped: report 0.
+        if self.tx.send(resp_tx).await.is_err() {
+            return Ok(0);
+        }
+        Ok(resp_rx.await.unwrap_or(0))
+    }
+
+    fn reclaimer_state(&self) -> Option<Arc<std::sync::atomic::AtomicU8>> {
+        Some(Arc::clone(&self.reclaimer_state))
+    }
+}
+
 struct ExternalSorterMetrics {
     /// metrics
     baseline: BaselineMetrics,
@@ -279,11 +309,16 @@ impl ExternalSorter {
         spill_compression: SpillCompression,
         metrics: &ExecutionPlanMetricsSet,
         runtime: Arc<RuntimeEnv>,
+        // Reclaimer attached to this partition's `MemoryConsumer`.
+        reclaimer: Option<Arc<dyn MemoryReclaimer>>,
     ) -> Result<Self> {
         let metrics = ExternalSorterMetrics::new(metrics, partition_id);
-        let reservation = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]"))
-            .with_can_spill(true)
-            .register(&runtime.memory_pool);
+        let mut consumer = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]"))
+            .with_can_spill(true);
+        if let Some(r) = reclaimer {
+            consumer = consumer.with_reclaimer(r);
+        }
+        let reservation = consumer.register(&runtime.memory_pool);
 
         let merge_reservation =
             MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]"))
@@ -315,13 +350,20 @@ impl ExternalSorter {
     /// Appends an unsorted [`RecordBatch`] to `in_mem_batches`
     ///
-    /// Updates memory usage metrics, and possibly triggers spilling to disk
+    /// Updates memory usage metrics, and possibly triggers spilling to disk.
+    ///
+    /// The live `(false, None)` path in [`SortExec::execute`] inlines this
+    /// logic so that the inner `try_grow_async` await runs inside a
+    /// `select!` that also drains `reclaim_rx` — see the closure in
+    /// `execute`. This method is retained for tests that drive an
+    /// `ExternalSorter` directly without a reclaim channel.
+    #[cfg(test)]
     async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> {
         if input.num_rows() == 0 {
             return Ok(());
         }
 
-        self.reserve_memory_for_merge()?;
+        self.reserve_memory_for_merge();
 
         self.reserve_memory_for_batch_and_maybe_spill(&input)
             .await?;
@@ -492,6 +534,10 @@ impl ExternalSorter {
         while let Some(batch) = sorted_stream.next().await {
             let batch = batch?;
             let sorted_size = get_reserved_bytes_for_record_batch(&batch)?;
+            // Sync `try_grow`, not `try_grow_async`: we are already in the
+            // spill path (freeing memory). A recursive reclaim here can
+            // close a cycle between two sorters that are each waiting on
+            // the other's spill to complete.
            if self.reservation.try_grow(sorted_size).is_err() {
                // Although the reservation is not enough, the batch is
                // already in memory, so it's okay to combine it with previously
@@ -521,7 +567,7 @@
        );
 
        // Reserve headroom for next sort/merge
-       self.reserve_memory_for_merge()?;
+       self.reserve_memory_for_merge();
 
        Ok(())
    }
@@ -711,39 +757,48 @@
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
 
-    /// If this sort may spill, pre-allocates
-    /// `sort_spill_reservation_bytes` of memory to guarantee memory
-    /// left for the in memory sort/merge.
-    fn reserve_memory_for_merge(&mut self) -> Result<()> {
-        // Reserve headroom for next merge sort
+    /// Pre-allocates `sort_spill_reservation_bytes` of merge headroom
+    /// as a best-effort optimization. If the pool is full, the grow is
+    /// silently skipped — the merge phase will grow `merge_reservation`
+    /// lazily via `try_grow_async`. Using a fallible sync grow here is
+    /// intentional: under contention (e.g. 16 partitions all finishing
+    /// their pre-merge spill at once), an async grow from every partition
+    /// would cascade into every other partition's reclaim walk, with all
+    /// of them timing out against one another.
+    fn reserve_memory_for_merge(&mut self) {
        if self.runtime.disk_manager.tmp_files_enabled() {
-            let size = self.sort_spill_reservation_bytes;
-            if self.merge_reservation.size() != size {
-                self.merge_reservation
-                    .try_resize(size)
-                    .map_err(Self::err_with_oom_context)?;
+            let target = self.sort_spill_reservation_bytes;
+            let current = self.merge_reservation.size();
+            if current > target {
+                self.merge_reservation.shrink(current - target);
+            } else if current < target {
+                let _ = self.merge_reservation.try_grow(target - current);
            }
        }
-
-        Ok(())
    }
 
    /// Reserves memory to be able to accommodate the given batch.
    /// If memory is scarce, tries to spill current in-memory batches to disk first.
+    ///
+    /// Only called from [`Self::insert_batch`], which itself only runs
+    /// in tests now — see the comment there.
+    #[cfg(test)]
    async fn reserve_memory_for_batch_and_maybe_spill(
        &mut self,
        input: &RecordBatch,
    ) -> Result<()> {
        let size = get_reserved_bytes_for_record_batch(input)?;
 
-        match self.reservation.try_grow(size) {
+        match self.reservation.try_grow_async(size).await {
            Ok(_) => Ok(()),
            Err(e) => {
                if self.in_mem_batches.is_empty() {
                    return Err(Self::err_with_oom_context(e));
                }
 
-                // Spill and try again.
+                // Sibling reclaim was already attempted by `try_grow_async`
+                // (which skips this consumer). Spill our own buffer, retry
+                // sync — siblings won't free more on a second pass.
                self.sort_and_spill_in_mem_batches().await?;
                self.reservation
                    .try_grow(size)
@@ -1246,6 +1301,18 @@ impl ExecutionPlan for SortExec {
                )))
            }
            (false, None) => {
+                // Spill-request channel; drained by the stream loop below.
+                let (reclaim_tx, mut reclaim_rx) =
+                    tokio::sync::mpsc::channel::<tokio::sync::oneshot::Sender<usize>>(4);
+                let state = Arc::new(std::sync::atomic::AtomicU8::new(
+                    reclaimer_state::AVAILABLE,
+                ));
+                let reclaimer: Arc<dyn MemoryReclaimer> =
+                    Arc::new(ExternalSorterReclaimer {
+                        tx: reclaim_tx,
+                        reclaimer_state: Arc::clone(&state),
+                    });
+
                let mut sorter = ExternalSorter::new(
                    partition,
                    input.schema(),
@@ -1256,14 +1323,136 @@
                    context.session_config().spill_compression(),
                    &self.metrics_set,
                    context.runtime_env(),
+                    Some(reclaimer),
                )?;
 
                Ok(Box::pin(RecordBatchStreamAdapter::new(
                    self.schema(),
                    futures::stream::once(async move {
-                        while let Some(batch) = input.next().await {
-                            let batch = batch?;
-                            sorter.insert_batch(batch).await?;
+                        // State machine: spill or insert. The inner
+                        // `try_grow_async` await for an incoming batch
+                        // is wrapped in its own `select!` against
+                        // `reclaim_rx` so that a sibling targeting this
+                        // sorter for reclaim is serviced even while we
+                        // ourselves are waiting for memory — otherwise
+                        // a fat sorter sitting in `try_grow_async.await`
+                        // would never drain its inbound reclaim queue,
+                        // and a concurrent sibling's
+                        // `reclaimer.reclaim(...).await` would never
+                        // resolve. `biased` ensures reclaim wins over
+                        // insert under pressure.
+                        //
+                        // Cancellation: selecting reclaim drops the in-flight
+                        // `input.next()`. Safe for cancellation-safe inputs
+                        // (channel receivers, e.g. RepartitionExec); other
+                        // inputs could drop a batch here.
+                        'outer: loop {
+                            tokio::select! {
+                                biased;
+                                Some(resp_tx) = reclaim_rx.recv() => {
+                                    // A reclaim can be dequeued just after a
+                                    // prior spill drained `in_mem_batches`
+                                    // (sibling sent during the spill's awaits;
+                                    // pool's zero-byte filter can transiently
+                                    // miss us via split reservations). Nothing
+                                    // local to free — reply 0 and keep going.
+                                    if sorter.in_mem_batches.is_empty() {
+                                        let _ = resp_tx.send(0);
+                                        continue;
+                                    }
+                                    let before = sorter.used();
+                                    sorter.sort_and_spill_in_mem_batches().await?;
+                                    let after = sorter.used();
+                                    let _ = resp_tx
+                                        .send(before.saturating_sub(after));
+                                }
+                                next = input.next() => {
+                                    let Some(batch_result) = next else {
+                                        break 'outer;
+                                    };
+                                    let batch = batch_result?;
+                                    if batch.num_rows() == 0 {
+                                        continue;
+                                    }
+                                    sorter.reserve_memory_for_merge();
+                                    let size =
+                                        get_reserved_bytes_for_record_batch(&batch)?;
+
+                                    // Reclaim-cooperative grow: while
+                                    // awaiting `try_grow_async`, keep
+                                    // draining `reclaim_rx`. On an
+                                    // inbound reclaim, spill our
+                                    // in-memory batches; that almost
+                                    // always frees enough that a sync
+                                    // `try_grow` succeeds, so we can
+                                    // skip the next async wait.
+                                    loop {
+                                        tokio::select! {
+                                            biased;
+                                            Some(resp_tx) = reclaim_rx.recv() => {
+                                                if sorter.in_mem_batches.is_empty() {
+                                                    let _ = resp_tx.send(0);
+                                                    continue;
+                                                }
+                                                let before = sorter.used();
+                                                sorter
+                                                    .sort_and_spill_in_mem_batches()
+                                                    .await?;
+                                                let after = sorter.used();
+                                                let _ = resp_tx
+                                                    .send(before.saturating_sub(after));
+                                                if sorter
+                                                    .reservation
+                                                    .try_grow(size)
+                                                    .is_ok()
+                                                {
+                                                    break;
+                                                }
+                                            }
+                                            r = sorter
+                                                .reservation
+                                                .try_grow_async(size) => {
+                                                match r {
+                                                    Ok(()) => break,
+                                                    Err(_) if !sorter
+                                                        .in_mem_batches
+                                                        .is_empty() =>
+                                                    {
+                                                        sorter
+                                                            .sort_and_spill_in_mem_batches()
+                                                            .await?;
+                                                        sorter
+                                                            .reservation
+                                                            .try_grow(size)
+                                                            .map_err(
+                                                                ExternalSorter::err_with_oom_context,
+                                                            )?;
+                                                        break;
+                                                    }
+                                                    Err(e) => {
+                                                        return Err(
+                                                            ExternalSorter::err_with_oom_context(e),
+                                                        );
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                    sorter.in_mem_batches.push(batch);
+                                }
+                            }
                        }
+                        // Sticky-disable so concurrent `try_grow_async`
+                        // callers stop targeting this consumer once we
+                        // enter the merge phase. Set before dropping
+                        // the receiver to close any window where the
+                        // pool would observe `AVAILABLE` after the
+                        // channel is gone (and hence get `Ok(0)` from a
+                        // wasted `reclaim`).
+                        state.store(
+                            reclaimer_state::DISABLED,
+                            std::sync::atomic::Ordering::Release,
+                        );
+                        drop(reclaim_rx);
                        sorter.sort().await
                    })
                    .try_flatten(),
@@ -2766,6 +2955,7 @@ mod tests {
            SpillCompression::Uncompressed,
            &metrics_set,
            Arc::clone(&runtime),
+            None,
        )?;
 
        // Insert enough data to force spilling.
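For reviewers, a minimal usage sketch of the new reclaim hooks; it is not part of the patch. It assumes only the APIs added above (`MemoryReclaimer`, `reclaimer_state`, `MemoryConsumer::with_reclaimer`, `MemoryReservation::try_grow_async`); the `ChannelReclaimer` type, its channel wiring, and the `"BufferingOp"` consumer name are illustrative.

// Illustrative sketch only; not part of this patch. Uses the APIs added
// above, but the operator-side channel layout is hypothetical.
use std::sync::Arc;
use std::sync::atomic::AtomicU8;

use datafusion_common::Result;
use datafusion_execution::memory_pool::{
    MemoryConsumer, MemoryPool, MemoryReclaimer, MemoryReservation, reclaimer_state,
};

/// Forwards the pool's reclaim request to the operator's task over an mpsc
/// channel and reports back how many bytes that task freed.
#[derive(Debug)]
struct ChannelReclaimer {
    tx: tokio::sync::mpsc::Sender<tokio::sync::oneshot::Sender<usize>>,
    state: Arc<AtomicU8>,
}

#[async_trait::async_trait]
impl MemoryReclaimer for ChannelReclaimer {
    async fn reclaim(&self, _target: usize) -> Result<usize> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        // Operator task already gone: nothing was freed.
        if self.tx.send(resp_tx).await.is_err() {
            return Ok(0);
        }
        Ok(resp_rx.await.unwrap_or(0))
    }

    fn reclaimer_state(&self) -> Option<Arc<AtomicU8>> {
        // Share the flag so the operator can sticky-set DISABLED later.
        Some(Arc::clone(&self.state))
    }
}

/// Registration side: attach the reclaimer when creating the consumer.
fn register_buffering_op(
    pool: &Arc<dyn MemoryPool>,
) -> (
    MemoryReservation,
    tokio::sync::mpsc::Receiver<tokio::sync::oneshot::Sender<usize>>,
    Arc<AtomicU8>,
) {
    let (tx, rx) = tokio::sync::mpsc::channel(4);
    let state = Arc::new(AtomicU8::new(reclaimer_state::AVAILABLE));
    let reservation = MemoryConsumer::new("BufferingOp")
        .with_reclaimer(Arc::new(ChannelReclaimer {
            tx,
            state: Arc::clone(&state),
        }))
        .register(pool);
    (reservation, rx, state)
}

The operator's task would drain `rx` alongside its normal work, spill, and reply with the freed byte count; once it can no longer free anything it stores `reclaimer_state::DISABLED` into the shared flag. Hot-path allocations go through `reservation.try_grow_async(bytes).await` so that a reclaim-aware pool such as `TrackConsumersPool` can run the candidate walk before surfacing OOM.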