diff --git a/rust/sedona/src/memory_pool.rs b/rust/sedona/src/memory_pool.rs
index bf43508b8..35c30939d 100644
--- a/rust/sedona/src/memory_pool.rs
+++ b/rust/sedona/src/memory_pool.rs
@@ -54,14 +54,14 @@ pub struct SedonaFairSpillPool {
 
 #[derive(Debug)]
 struct FairSpillPoolState {
-    /// The number of consumers that can spill
-    num_spill: usize,
-
     /// The total amount of memory reserved that can be spilled
     spillable: usize,
 
     /// The total amount of memory reserved by consumers that cannot spill
     unspillable: usize,
+
+    /// Spillable memory reserved by each active spillable consumer
+    spillable_consumers: Vec<(usize, usize)>,
 }
 
 impl SedonaFairSpillPool {
@@ -71,32 +71,66 @@ impl SedonaFairSpillPool {
             pool_size,
             unspillable_reserve_ratio,
             state: Mutex::new(FairSpillPoolState {
-                num_spill: 0,
                 spillable: 0,
                 unspillable: 0,
+                spillable_consumers: vec![],
             }),
         }
     }
 }
 
-impl MemoryPool for SedonaFairSpillPool {
-    fn register(&self, consumer: &MemoryConsumer) {
-        if consumer.can_spill() {
-            self.state.lock().num_spill += 1;
-        }
+fn grow_spillable_consumer(
+    state: &mut FairSpillPoolState,
+    consumer: &MemoryConsumer,
+    additional: usize,
+) {
+    if additional == 0 {
+        return;
     }
 
-    fn unregister(&self, consumer: &MemoryConsumer) {
-        if consumer.can_spill() {
-            let mut state = self.state.lock();
-            state.num_spill = state.num_spill.checked_sub(1).unwrap();
+    let id = consumer.id();
+    if let Some((_, size)) = state
+        .spillable_consumers
+        .iter_mut()
+        .find(|(consumer_id, _)| *consumer_id == id)
+    {
+        *size += additional;
+    } else {
+        state.spillable_consumers.push((id, additional));
+    }
+}
+
+fn shrink_spillable_consumer(
+    state: &mut FairSpillPoolState,
+    consumer: &MemoryConsumer,
+    shrink: usize,
+) {
+    if shrink == 0 {
+        return;
+    }
+
+    let id = consumer.id();
+    if let Some(pos) = state
+        .spillable_consumers
+        .iter()
+        .position(|(consumer_id, _)| *consumer_id == id)
+    {
+        if state.spillable_consumers[pos].1 > shrink {
+            state.spillable_consumers[pos].1 -= shrink;
+        } else {
+            state.spillable_consumers.remove(pos);
         }
     }
+}
 
+impl MemoryPool for SedonaFairSpillPool {
     fn grow(&self, reservation: &MemoryReservation, additional: usize) {
         let mut state = self.state.lock();
         match reservation.consumer().can_spill() {
-            true => state.spillable += additional,
+            true => {
+                state.spillable += additional;
+                grow_spillable_consumer(&mut state, reservation.consumer(), additional);
+            }
             false => state.unspillable += additional,
         }
     }
@@ -104,7 +138,10 @@ impl MemoryPool for SedonaFairSpillPool {
     fn shrink(&self, reservation: &MemoryReservation, shrink: usize) {
         let mut state = self.state.lock();
         match reservation.consumer().can_spill() {
-            true => state.spillable -= shrink,
+            true => {
+                state.spillable -= shrink;
+                shrink_spillable_consumer(&mut state, reservation.consumer(), shrink);
+            }
             false => state.unspillable -= shrink,
         }
     }
@@ -124,12 +161,23 @@ impl MemoryPool for SedonaFairSpillPool {
 
         match reservation.consumer().can_spill() {
             true => {
-                // No spiller may use more than their fraction of the memory available
+                // Apply fair shares across active spillers only. Idle spillable
+                // consumers are not counted until they reserve memory.
+                let consumer_reserved = state
+                    .spillable_consumers
+                    .iter()
+                    .find(|(id, _)| *id == reservation.consumer().id())
+                    .map(|(_, size)| *size)
+                    .unwrap_or(0);
+                let active_spill =
+                    state.spillable_consumers.len() + usize::from(consumer_reserved == 0);
                 let available = spill_available
-                    .checked_div(state.num_spill)
+                    .checked_div(active_spill)
                     .unwrap_or(spill_available);
+                let total_available = spill_available.saturating_sub(state.spillable);
+                let available = available.min(consumer_reserved + total_available);
 
-                if reservation.size() + additional > available {
+                if consumer_reserved + additional > available {
                     return Err(insufficient_capacity_err(
                         reservation,
                         additional,
@@ -139,6 +187,7 @@ impl MemoryPool for SedonaFairSpillPool {
                     ));
                 }
                 state.spillable += additional;
+                grow_spillable_consumer(&mut state, reservation.consumer(), additional);
             }
             false => {
                 let available = self
@@ -250,9 +299,9 @@ mod tests {
     }
 
     #[test]
-    fn test_fairness_among_spillers() {
-        // Pool size 100, 0% reserved.
-        let pool: Arc<dyn MemoryPool> = Arc::new(SedonaFairSpillPool::new(100, 0.0));
+    fn test_fairness_among_active_spillers() {
+        // Pool size 200, 0% reserved.
+        let pool: Arc<dyn MemoryPool> = Arc::new(SedonaFairSpillPool::new(200, 0.0));
 
         let c1 = MemoryConsumer::new("c1").with_can_spill(true);
         let mut r1 = c1.register(&pool);
@@ -260,22 +309,17 @@ mod tests {
         let c2 = MemoryConsumer::new("c2").with_can_spill(true);
         let mut r2 = c2.register(&pool);
 
-        // With 2 spillers, each gets 50.
-        r1.try_grow(50).unwrap();
+        // r2 is registered but idle, so r1 can use the full spill budget.
+        r1.try_grow(200).unwrap();
         assert!(r1.try_grow(1).is_err());
 
-        r2.try_grow(50).unwrap();
+        // Once r2 reserves memory, fairness applies across active spillers.
+        r1.shrink(150);
+        assert!(r2.try_grow(101).is_err());
+        r2.try_grow(100).unwrap();
         assert!(r2.try_grow(1).is_err());
 
-        // If one shrinks, other can't grow immediately if we strictly enforce N-way split?
-        // DataFusion FairSpillPool:
-        // let available = spill_available.checked_div(state.num_spill).unwrap_or(spill_available);
-        // Yes, it strictly enforces split.
-
-        r1.shrink(50);
-        // r1 = 0, r2 = 50.
-        // r2 tries to grow. Available per spiller = 50. r2 has 50.
-        // So r2 cannot grow even if r1 is empty. This is how FairSpillPool works.
-        assert!(r2.try_grow(1).is_err());
+        assert!(r1.try_grow(51).is_err());
+        r1.try_grow(50).unwrap();
     }
 }