From 825b26d180ab2e97fc17a33388f90cf8374c75f7 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sun, 3 May 2026 17:30:07 +0200 Subject: [PATCH 1/7] Draft: dynamic task count --- benchmarks/cdk/bin/datafusion-bench.ts | 5 + benchmarks/src/run.rs | 33 +- src/coordinator/distributed.rs | 98 ++++-- src/coordinator/mod.rs | 1 + src/coordinator/prepare_dynamic_plan.rs | 333 ++++++++++++++++++ src/coordinator/prepare_static_plan.rs | 3 +- src/coordinator/task_spawner.rs | 9 +- src/distributed_ext.rs | 32 ++ src/distributed_planner/distributed_config.rs | 2 + .../distributed_query_planner.rs | 14 +- .../inject_network_boundaries.rs | 154 ++++++-- src/distributed_planner/mod.rs | 4 + src/distributed_planner/network_boundary.rs | 21 +- .../prepare_network_boundaries.rs | 2 +- src/execution_plans/mod.rs | 2 + src/execution_plans/network_broadcast.rs | 20 +- src/execution_plans/network_coalesce.rs | 18 +- src/execution_plans/network_shuffle.rs | 20 +- src/execution_plans/sampler.rs | 322 +++++++++++++++++ src/metrics/task_metrics_rewriter.rs | 15 +- src/observability/service.rs | 2 +- src/protobuf/distributed_codec.rs | 21 +- src/worker/generated/worker.rs | 37 +- src/worker/impl_coordinator_channel.rs | 29 +- src/worker/impl_execute_task.rs | 5 +- src/worker/mod.rs | 3 +- src/worker/task_data.rs | 53 ++- src/worker/test_utils/worker_handles.rs | 3 +- src/worker/worker.proto | 26 ++ src/worker/worker_connection_pool.rs | 14 + tests/metrics_collection.rs | 32 ++ 31 files changed, 1207 insertions(+), 126 deletions(-) create mode 100644 src/coordinator/prepare_dynamic_plan.rs create mode 100644 src/execution_plans/sampler.rs diff --git a/benchmarks/cdk/bin/datafusion-bench.ts b/benchmarks/cdk/bin/datafusion-bench.ts index 77d43dd2..5aa9aa0c 100644 --- a/benchmarks/cdk/bin/datafusion-bench.ts +++ b/benchmarks/cdk/bin/datafusion-bench.ts @@ -24,6 +24,7 @@ async function main() { .option('--max-tasks-per-stage ', 'Max tasks per stage', '0') 
.option('--repartition-file-min-size ', 'repartition_file_min_size DF option', '10485760' /* upstream default */) .option('--target-partitions ', 'target_partitions DF option', '8') + .option('--dynamic ', 'Use the dynamic task count assigner', 'false') .option('--queries ', 'Specific queries to run', undefined) .option('--debug ', 'Print the generated plans to stdout') .option('--warmup ', 'Perform a warmup query before the benchmarks', 'true') @@ -46,6 +47,7 @@ async function main() { const childrenIsolatorUnions = options.childrenIsolatorUnions === 'true' || options.childrenIsolatorUnions === 1 const broadcastJoins = options.broadcastJoins === 'true' || options.broadcastJoins === 1 const partialReduce = options.partialReduce === 'true' || options.partialReduce === 1 + const dynamicTaskCount = options.dynamic === 'true' || options.dynamic === 1 const debug = options.debug === true || options.debug === 'true' || options.debug === 1 const warmup = options.warmup === true || options.warmup === 'true' || options.warmup === 1 @@ -59,6 +61,7 @@ async function main() { compression, broadcastJoins, partialReduce, + dynamicTaskCount, maxTasksPerStage, repartitionFileMinSize, targetPartitions @@ -97,6 +100,7 @@ class DataFusionRunner implements BenchmarkRunner { childrenIsolatorUnions: boolean; broadcastJoins: boolean; partialReduce: boolean; + dynamicTaskCount: boolean; maxTasksPerStage: number; repartitionFileMinSize: number; targetPartitions: number; @@ -176,6 +180,7 @@ class DataFusionRunner implements BenchmarkRunner { SET distributed.children_isolator_unions=${this.options.childrenIsolatorUnions}; SET distributed.broadcast_joins=${this.options.broadcastJoins}; SET distributed.partial_reduce=${this.options.partialReduce}; + SET distributed.dynamic_task_count=${this.options.dynamicTaskCount}; SET distributed.max_tasks_per_stage=${this.options.maxTasksPerStage}; SET datafusion.optimizer.repartition_file_min_size=${this.options.repartitionFileMinSize}; SET 
datafusion.execution.target_partitions=${this.options.targetPartitions}; diff --git a/benchmarks/src/run.rs b/benchmarks/src/run.rs index 45836f42..eee28117 100644 --- a/benchmarks/src/run.rs +++ b/benchmarks/src/run.rs @@ -24,11 +24,13 @@ use datafusion::common::utils::get_available_parallelism; use datafusion::common::{config_err, exec_err, not_impl_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::SessionStateBuilder; -use datafusion::physical_plan::display::DisplayableExecutionPlan; -use datafusion::physical_plan::{collect, displayable}; +use datafusion::physical_plan::collect; use datafusion::prelude::*; use datafusion_distributed::test_utils::localhost::LocalHostWorkerResolver; -use datafusion_distributed::{DistributedExt, NetworkBoundaryExt, SessionStateBuilderExt, Worker}; +use datafusion_distributed::{ + DistributedExt, DistributedMetricsFormat, NetworkBoundaryExt, SessionStateBuilderExt, Worker, + display_plan_ascii, rewrite_distributed_plan_with_metrics, +}; use datafusion_distributed_benchmarks::datasets::{clickbench, register_tables, tpcds, tpch}; use std::error::Error; use std::fs; @@ -112,6 +114,10 @@ pub struct RunOpt { #[structopt(short = "s", long = "batch-size")] batch_size: Option, + /// Dynamically assign tasks to stages based on runtime stats + #[structopt(long = "dynamic")] + dynamic: bool, + /// Activate debug mode to see more details #[structopt(short, long)] debug: bool, @@ -192,6 +198,7 @@ impl RunOpt { .with_distributed_broadcast_joins(self.broadcast_joins)? .with_distributed_metrics_collection(self.collect_metrics)? .with_distributed_max_tasks_per_stage(self.max_tasks_per_stage)? + .with_distributed_dynamic_task_count(self.dynamic)? 
.build(); let ctx = SessionContext::new_with_state(state); register_tables(&ctx, &self.get_path()?).await?; @@ -286,21 +293,8 @@ impl RunOpt { let plan = ctx.sql(sql).await?; let (state, plan) = plan.into_parts(); - if self.debug { - println!("=== Logical plan ===\n{plan}\n"); - } - let plan = state.optimize(&plan)?; - if self.debug { - println!("=== Optimized logical plan ===\n{plan}\n"); - } let physical_plan = state.create_physical_plan(&plan).await?; - if self.debug { - println!( - "=== Physical plan ===\n{}\n", - displayable(physical_plan.as_ref()).indent(true) - ); - } let mut n_tasks = 0; physical_plan.clone().transform_down(|node| { if let Some(node) = node.as_network_boundary() { @@ -310,9 +304,14 @@ impl RunOpt { })?; let result = collect(physical_plan.clone(), state.task_ctx()).await?; if self.debug { + let plan = rewrite_distributed_plan_with_metrics( + physical_plan, + DistributedMetricsFormat::Aggregated, + ) + .await?; println!( "=== Physical plan with metrics ===\n{}\n", - DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent(true) + display_plan_ascii(plan.as_ref(), true) ); } Ok((result, n_tasks)) diff --git a/src/coordinator/distributed.rs b/src/coordinator/distributed.rs index 1c472b9a..187fe3d4 100644 --- a/src/coordinator/distributed.rs +++ b/src/coordinator/distributed.rs @@ -1,5 +1,7 @@ +use crate::DistributedConfig; use crate::common::{require_one_child, serialize_uuid}; use crate::coordinator::metrics_store::MetricsStore; +use crate::coordinator::prepare_dynamic_plan::prepare_dynamic_plan; use crate::coordinator::prepare_static_plan::prepare_static_plan; use crate::distributed_planner::NetworkBoundaryExt; use crate::worker::generated::worker::TaskKey; @@ -27,22 +29,25 @@ use std::sync::Mutex; /// over the wire. 
#[derive(Debug)] pub struct DistributedExec { - plan: Arc, - prepared_plan: Arc>>>, + base_plan: Arc, + final_plan: Arc>>>, + head_stage: Arc>>>, metrics: ExecutionPlanMetricsSet, pub(crate) task_metrics: Option>, } pub(super) struct PreparedPlan { pub(super) head_stage: Arc, + pub(super) final_plan: Arc, pub(super) join_set: JoinSet>, } impl DistributedExec { - pub fn new(plan: Arc) -> Self { + pub fn new(base_plan: Arc) -> Self { Self { - plan, - prepared_plan: Arc::new(Mutex::new(None)), + base_plan, + final_plan: Arc::new(Mutex::new(None)), + head_stage: Arc::new(Mutex::new(None)), metrics: ExecutionPlanMetricsSet::new(), task_metrics: None, } @@ -69,7 +74,10 @@ impl DistributedExec { let Some(task_metrics) = &self.task_metrics else { return; }; - let _ = self.plan.apply(|plan| { + let Some(plan) = self.final_plan.lock().unwrap().as_ref().cloned() else { + return; + }; + let _ = plan.apply(|plan| { if let Some(boundary) = plan.as_network_boundary() { let stage = boundary.input_stage(); for i in 0..stage.task_count() { @@ -95,7 +103,7 @@ impl DistributedExec { /// It is updated on every call to `execute()`. Returns an error if `.execute()` has not been /// called. pub(crate) fn prepared_plan(&self) -> Result> { - self.prepared_plan + self.final_plan .lock() .map_err(|e| internal_datafusion_err!("Failed to lock prepared plan: {}", e))? .clone() @@ -103,6 +111,18 @@ impl DistributedExec { internal_datafusion_err!("No prepared plan found. Was execute() called?") }) } + + /// Returns the head stage that was actually executed. Unlike [`Self::prepared_plan`] (which is + /// reconstructed for visualization, with `Stage::Local` boundaries and rebuilt ancestor + /// `Arc`s), this returns the original `Arc` instances whose metrics were populated during + /// execution. + pub(crate) fn head_stage(&self) -> Result> { + self.head_stage + .lock() + .map_err(|e| internal_datafusion_err!("Failed to lock head stage: {}", e))? 
+ .clone() + .ok_or_else(|| internal_datafusion_err!("No head stage found. Was execute() called?")) + } } impl DisplayAs for DistributedExec { @@ -121,11 +141,11 @@ impl ExecutionPlan for DistributedExec { } fn properties(&self) -> &Arc { - self.plan.properties() + self.base_plan.properties() } fn children(&self) -> Vec<&Arc> { - vec![&self.plan] + vec![&self.base_plan] } fn with_new_children( @@ -133,8 +153,9 @@ impl ExecutionPlan for DistributedExec { children: Vec>, ) -> Result> { Ok(Arc::new(DistributedExec { - plan: require_one_child(&children)?, - prepared_plan: self.prepared_plan.clone(), + base_plan: require_one_child(&children)?, + final_plan: Arc::new(Mutex::new(None)), + head_stage: Arc::new(Mutex::new(None)), metrics: self.metrics.clone(), task_metrics: self.task_metrics.clone(), })) @@ -155,31 +176,56 @@ impl ExecutionPlan for DistributedExec { ); } - let PreparedPlan { - head_stage, - join_set, - } = prepare_static_plan(&self.plan, &self.metrics, &self.task_metrics, &context)?; - { - let mut guard = self - .prepared_plan - .lock() - .map_err(|e| internal_datafusion_err!("Failed to lock prepared plan: {e}"))?; - *guard = Some(head_stage.clone()); - } + let this = Self { + base_plan: Arc::clone(&self.base_plan), + final_plan: Arc::clone(&self.final_plan), + head_stage: Arc::clone(&self.head_stage), + metrics: self.metrics.clone(), + task_metrics: self.task_metrics.clone(), + }; + let mut builder = RecordBatchReceiverStreamBuilder::new(self.schema(), 1); let tx = builder.tx(); // Spawn the task that pulls data from child... 
builder.spawn(async move { + let d_cfg = DistributedConfig::from_config_options(context.session_config().options())?; + + let PreparedPlan { + head_stage, + final_plan, + join_set, + } = match d_cfg.dynamic_task_count { + false => prepare_static_plan( + &this.base_plan, + &this.metrics, + &this.task_metrics, + &context, + )?, + true => { + prepare_dynamic_plan( + &this.base_plan, + &this.metrics, + &this.task_metrics, + &context, + ) + .await? + } + }; + + this.final_plan + .lock() + .expect("poisoned lock") + .replace(final_plan); + this.head_stage + .lock() + .expect("poisoned lock") + .replace(Arc::clone(&head_stage)); let mut stream = head_stage.execute(partition, context)?; while let Some(msg) = stream.next().await { if tx.send(msg).await.is_err() { break; // channel closed } } - Ok(()) - }); - // ...in parallel to the one that feeds the plan to workers. - builder.spawn(async move { for res in join_set.join_all().await { res?; } diff --git a/src/coordinator/mod.rs b/src/coordinator/mod.rs index 2aea8442..db8c6d01 100644 --- a/src/coordinator/mod.rs +++ b/src/coordinator/mod.rs @@ -1,5 +1,6 @@ mod distributed; mod metrics_store; +mod prepare_dynamic_plan; mod prepare_static_plan; mod task_spawner; diff --git a/src/coordinator/prepare_dynamic_plan.rs b/src/coordinator/prepare_dynamic_plan.rs new file mode 100644 index 00000000..2a210b24 --- /dev/null +++ b/src/coordinator/prepare_dynamic_plan.rs @@ -0,0 +1,333 @@ +use crate::coordinator::MetricsStore; +use crate::coordinator::distributed::PreparedPlan; +use crate::coordinator::task_spawner::{ + CoordinatorToWorkerMetrics, CoordinatorToWorkerTaskSpawner, +}; +use crate::distributed_planner::{ + NetworkBoundaryBuilderResult, get_distributed_task_estimator, inject_network_boundaries, + network_boundary_inject_sampler, +}; +use crate::stage::{LocalStage, RemoteStage}; +use crate::worker::generated::worker as pb; +use crate::{ + DistributedCodec, NetworkBoundary, NetworkBoundaryExt, NetworkCoalesceExec, Stage, + 
TaskCountAnnotation, TaskRoutingContext, get_distributed_worker_resolver, +}; +use dashmap::DashMap; +use datafusion::common::instant::Instant; +use datafusion::common::runtime::JoinSet; +use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion::common::{Result, exec_err}; +use datafusion::config::ConfigOptions; +use datafusion::execution::TaskContext; +use datafusion::physical_expr_common::metrics::ExecutionPlanMetricsSet; +use datafusion::physical_plan::ExecutionPlan; +use futures::{Stream, StreamExt}; +use rand::Rng; +use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio_stream::wrappers::UnboundedReceiverStream; +use url::Url; + +pub(super) async fn prepare_dynamic_plan( + base_plan: &Arc, + metrics: &ExecutionPlanMetricsSet, + task_metrics: &Option>, + ctx: &Arc, +) -> Result { + let metrics = CoordinatorToWorkerMetrics::new(metrics); + + let worker_idx = AtomicUsize::new(rand::rng().random_range(0..100)); // TODO + let plans_for_viz = PlanReconstructor::default(); + let outer_join_set = Mutex::new(JoinSet::new()); + + let head_stage = inject_network_boundaries( + Arc::clone(base_plan), + |nb: Arc, _cfg: &ConfigOptions| { + let worker_resolver = get_distributed_worker_resolver(ctx.session_config())?; + let codec = DistributedCodec::new_combined_with_user(ctx.session_config()); + let task_estimator = get_distributed_task_estimator(ctx.session_config())?; + let mut join_set = JoinSet::new(); + let Stage::Local(input_stage) = nb.input_stage() else { + return exec_err!("NetworkBoundary's input stage was in remote mode."); + }; + let mut input_stage = input_stage.clone(); + input_stage.plan = network_boundary_inject_sampler(input_stage.plan)?; + let mut spawner = CoordinatorToWorkerTaskSpawner::new( + &input_stage, + &metrics, + task_metrics, + &codec, + &mut join_set, + )?; + + let urls = 
worker_resolver.get_urls()?; + let next_url = || urls[(worker_idx.fetch_add(1, SeqCst)) % urls.len()].clone(); + + let routed_urls = match task_estimator.route_tasks(&TaskRoutingContext { + task_ctx: Arc::clone(ctx), + plan: &input_stage.plan, + task_count: input_stage.tasks, + available_urls: &urls, + }) { + Ok(Some(routed_urls)) => routed_urls, + // If the user has not defined custom routing with a `route_tasks` implementation, we + // default to round-robin task assignment from a randomized starting point. + Ok(None) => (0..input_stage.tasks).map(|_| next_url()).collect(), + Err(e) => return exec_err!("error routing tasks to workers: {e}"), + }; + + if routed_urls.len() != input_stage.tasks { + return exec_err!( + "number of tasks ({}) was not equal to number of urls ({}) at execution time", + input_stage.tasks, + routed_urls.len() + ); + } + + let mut workers = Vec::with_capacity(input_stage.tasks); + let mut load_info_rxs = Vec::with_capacity(input_stage.tasks); + + let mut url = if input_stage.tasks == 1 { + get_child_stages_urls(&input_stage.plan)? + .iter() + .find_map(|v| match v.len() == 1 { + true => Some(v.first().cloned()), + false => None, + }) + .flatten() + .unwrap_or_else(next_url) + } else { + next_url() + }; + + for i in 0..input_stage.tasks { + workers.push(url.clone()); + // Spawns the task that feeds this subplan to this worker. There will be as + // many of these spawned tasks as there are workers. + let (tx, worker_rx) = spawner.send_plan_task(Arc::clone(ctx), i, url)?; + load_info_rxs.push({ + let rx = spawner.load_info_and_metrics_collection_task(i, worker_rx); + // Tag each LoadInfoBatch with the producer task index so + // `calculate_task_count` can identify (task_idx, partition) slices + // independently — `select_all` would otherwise collapse them.
+ UnboundedReceiverStream::new(rx).map(move |batch| (i, batch)) + }); + spawner.work_unit_feed_task(Arc::clone(ctx), i, tx)?; + url = next_url(); + } + + outer_join_set + .lock() + .expect("poisoned lock") + .spawn(async move { + for result in join_set.join_all().await { + result?; + } + Ok(()) + }); + + plans_for_viz.insert(input_stage.num, Arc::clone(&input_stage.plan)); + + let nb = nb.with_input_stage(Stage::Remote(RemoteStage { + query_id: input_stage.query_id, + num: input_stage.num, + workers, + }))?; + + let load_info_stream = futures::stream::select_all(load_info_rxs); + + Ok(async move { + let task_count_above = if nb.as_any().is::() { + TaskCountAnnotation::Maximum(1) + } else { + TaskCountAnnotation::Desired(calculate_task_count(load_info_stream).await) + }; + Ok(NetworkBoundaryBuilderResult { + task_count_above, + network_boundary: nb, + }) + }) + }, + ctx.session_config().options(), + ) + .await?; + Ok(PreparedPlan { + final_plan: plans_for_viz.reconstruct(&head_stage)?, + head_stage, + join_set: std::mem::take(&mut outer_join_set.lock().unwrap()), + }) +} + +#[derive(Default)] +struct PlanReconstructor { + stage_map: DashMap>, +} + +impl PlanReconstructor { + fn insert(&self, stage: usize, plan: Arc) { + self.stage_map.insert(stage, plan); + } + + fn reconstruct(&self, head_stage: &Arc) -> Result> { + let reconstructed = Arc::clone(head_stage).transform_down(|plan| { + let Some(nb) = plan.as_network_boundary() else { + return Ok(Transformed::no(plan)); + }; + let input_stage = nb.input_stage(); + let Some(plan_for_viz) = self.stage_map.get(&input_stage.num()) else { + return exec_err!( + "Failed to retrieve plan for stage {} for visualization purposes", + input_stage.num() + ); + }; + + let nb = nb.with_input_stage(Stage::Local(LocalStage { + query_id: input_stage.query_id(), + num: input_stage.num(), + plan: Arc::clone(&plan_for_viz), + tasks: input_stage.task_count(), + }))?; + + Ok(Transformed::yes(nb)) + })?; + Ok(reconstructed.data) + } +} + +fn 
get_child_stages_urls( + plan: &Arc, +) -> Result>> { + let mut result = vec![]; + plan.apply(|plan| { + let Some(nb) = plan.as_network_boundary() else { + return Ok(TreeNodeRecursion::Continue); + }; + + match nb.input_stage() { + Stage::Local(_) => exec_err!("While gathering child stages URLs, one was in local mode. This is a bug in the dynamic task count execution logic, please report it.")?, + Stage::Remote(remote) => result.push(&remote.workers) + } + + Ok(TreeNodeRecursion::Jump) + })?; + + Ok(result) +} + +/// Estimates the next stage's task count from per-(task, partition) sampler observations. +/// +/// Each producer `(task_idx, partition)` is treated as an independent slice that observes a +/// fraction of the stage's output. The function samples for [`SAMPLING_WINDOW`] of wall-clock +/// time after the first message arrives, then has each observed slice cast a vote: the task +/// count it would assign to the consumer if every observed slice produced at this slice's +/// observed velocity. The final task count is the **median** of votes — robust to a handful +/// of skewed (very fast or very slow) producers. +/// +/// Sampling stops early once `TARGET_DONE_SLICES` slices are done (`eos`, `max_memory_reached`, +/// or `MAX_MESSAGES_PER_SLICE` messages) or the stream ends; otherwise it stops at the sampling +/// deadline. The result is never below the number of producer tasks observed, nor below 1. +async fn calculate_task_count( + mut load_info_stream: impl Stream + Unpin, +) -> usize { + /// Target sustained throughput per downstream task. The next stage is sized so each + /// task is expected to absorb roughly this many bytes per second of producer output. + const TARGET_BYTES_PER_SEC_PER_TASK: u64 = 64 * 1024 * 1024; + /// Per-`(task_idx, partition)` cap on buffered `LoadInfo` messages. Once a slice has + /// produced this many messages it is considered done — enough information has been seen + /// to estimate its velocity.
+ const MAX_MESSAGES_PER_SLICE: usize = 2; + /// Minimum number of slices that must reach the "done" state before voting. A slice is + /// done when it signals `eos`, `max_memory_reached`, or hits `MAX_MESSAGES_PER_SLICE`. + const TARGET_DONE_SLICES: usize = 2; + /// Wall-clock safety net measured from the first received `LoadInfo`. If neither + /// `TARGET_DONE_SLICES` nor full slice coverage is reached within this window, vote with + /// whatever has been observed. Prevents deadlock if a stage has fewer slices than + /// `TARGET_DONE_SLICES` but new slices are still appearing slowly. + const SAMPLING_WINDOW: Duration = Duration::from_millis(25); + + #[derive(Default)] + struct Slice { + total_bytes: u64, + max_elapsed_ns: u64, + msg_count: usize, + done: bool, + } + let mut slices: HashMap<(usize, u64), Slice> = HashMap::new(); + let mut done_count: usize = 0; + let mut deadline: Option = None; + + loop { + let next = match deadline { + None => load_info_stream.next().await, + Some(d) => match tokio::time::timeout_at(d.into(), load_info_stream.next()).await { + Ok(item) => item, + Err(_) => break, // sampling window elapsed + }, + }; + let Some((task_idx, batch)) = next else { break }; // stream terminated + + for info in batch.batch { + let entry = slices.entry((task_idx, info.partition)).or_default(); + entry.total_bytes = entry.total_bytes.saturating_add(info.byte_size); + entry.max_elapsed_ns = entry.max_elapsed_ns.max(info.time_mark_ns); + entry.msg_count += 1; + if !entry.done + && (info.eos + || info.max_memory_reached + || entry.msg_count >= MAX_MESSAGES_PER_SLICE) + { + entry.done = true; + done_count += 1; + } + } + if deadline.is_none() && !slices.is_empty() { + deadline = Some(Instant::now() + SAMPLING_WINDOW); + } + if done_count >= TARGET_DONE_SLICES { + break; + } + } + + // Each slice that observed enough data votes for a task count. 
The vote extrapolates the + // slice's observed velocity to the full producer (assumes all observed slices share the + // same velocity): + // slice_velocity = total_bytes / max_elapsed_ns (bytes/ns) + // stage_throughput = slice_velocity * num_slices_observed * 1e9 (bytes/sec) + // vote = ceil(stage_throughput / TARGET_BYTES_PER_SEC_PER_TASK) + let observed = slices.len().max(1) as u128; + let mut votes: Vec = slices + .values() + .filter_map(|s| { + if s.max_elapsed_ns == 0 || s.total_bytes == 0 { + return None; + } + let numerator = (s.total_bytes as u128) + .saturating_mul(1_000_000_000) + .saturating_mul(observed); + let denominator = + (s.max_elapsed_ns as u128).saturating_mul(TARGET_BYTES_PER_SEC_PER_TASK as u128); + Some(numerator.div_ceil(denominator).max(1)) + }) + .collect(); + + // Floor at the number of distinct producer task_idxs observed: never shrink the consumer + // stage below the producer's parallelism. Mirrors the static `CardinalityTaskCountStrategy` + // behavior where a consumer at least matches its producer. + let producer_task_floor = slices + .keys() + .map(|(t, _)| *t) + .collect::>() + .len() as u128; + + if votes.is_empty() { + return producer_task_floor.max(1) as usize; + } + + votes.sort_unstable(); + let median = votes[votes.len() / 2].max(producer_task_floor); + usize::try_from(median).unwrap_or(usize::MAX) +} diff --git a/src/coordinator/prepare_static_plan.rs b/src/coordinator/prepare_static_plan.rs index ff072694..5c1d6bbb 100644 --- a/src/coordinator/prepare_static_plan.rs +++ b/src/coordinator/prepare_static_plan.rs @@ -94,7 +94,7 @@ pub(super) fn prepare_static_plan( // Spawn a task that sends the subplan to the chosen URL. // There will be as many spawned tasks as workers. 
let (tx, worker_rx) = spawner.send_plan_task(Arc::clone(ctx), i, routed_url)?; - spawner.metrics_collection_task(i, worker_rx); + spawner.load_info_and_metrics_collection_task(i, worker_rx); spawner.work_unit_feed_task(Arc::clone(ctx), i, tx)?; } @@ -108,6 +108,7 @@ pub(super) fn prepare_static_plan( })?; Ok(PreparedPlan { head_stage: prepared.data, + final_plan: Arc::clone(base_plan), join_set, }) } diff --git a/src/coordinator/task_spawner.rs b/src/coordinator/task_spawner.rs index f960da45..8bc04753 100644 --- a/src/coordinator/task_spawner.rs +++ b/src/coordinator/task_spawner.rs @@ -235,17 +235,18 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { Ok((coordinator_to_worker_tx, worker_to_coordinator_rx)) } - pub(super) fn metrics_collection_task( + pub(super) fn load_info_and_metrics_collection_task( &mut self, task_i: usize, mut worker_to_coordinator_rx: UnboundedReceiver, - ) { + ) -> UnboundedReceiver { let task_key = TaskKey { query_id: serialize_uuid(&self.query_id), stage_id: self.stage_id as u64, task_number: task_i as u64, }; let task_metrics = self.task_metrics.clone(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); #[allow(clippy::disallowed_methods)] tokio::spawn(async move { while let Some(msg) = worker_to_coordinator_rx.recv().await { @@ -257,9 +258,13 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { task_metrics.insert(task_key.clone(), pre_order_metrics.metrics); } } + pb::worker_to_coordinator_msg::Inner::LoadInfoBatch(load_info_batch) => { + let _ = tx.send(load_info_batch); + } } } }); + rx } /// Launches the task that based on the different local [WorkUnitFeedExec] nodes, sends their diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index af1ff8ec..f0bbea41 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -582,6 +582,13 @@ pub trait DistributedExt: Sized { P: WorkUnitFeedProvider + 'static, P::WorkUnit: 'static, F: Fn(&T) -> Option<&WorkUnitFeed

> + Send + Sync + 'static; + + /// Dynamically allocates tasks to the different stages based on runtime statistics + /// collected during execution. + fn with_distributed_dynamic_task_count(self, enabled: bool) -> Result; + + /// Same as [DistributedExt::with_distributed_dynamic_task_count] but with an in-place mutation. + fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError>; } impl DistributedExt for SessionConfig { @@ -730,6 +737,12 @@ impl DistributedExt for SessionConfig { }) } + fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError> { + let d_cfg = DistributedConfig::from_config_options_mut(self.options_mut())?; + d_cfg.dynamic_task_count = enabled; + Ok(()) + } + delegate! { to self { #[call(set_distributed_option_extension)] @@ -812,6 +825,10 @@ impl DistributedExt for SessionConfig { P: WorkUnitFeedProvider + 'static, P::WorkUnit: 'static, F: Fn(&T) -> Option<&WorkUnitFeed

> + Send + Sync + 'static; + + #[call(set_distributed_dynamic_task_count)] + #[expr($?;Ok(self))] + fn with_distributed_dynamic_task_count(mut self, enabled: bool) -> Result; } } } @@ -923,6 +940,11 @@ impl DistributedExt for SessionStateBuilder { P: WorkUnitFeedProvider + 'static, P::WorkUnit: 'static, F: Fn(&T) -> Option<&WorkUnitFeed

> + Send + Sync + 'static; + + fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_dynamic_task_count)] + #[expr($?;Ok(self))] + fn with_distributed_dynamic_task_count(mut self, enabled: bool) -> Result; } } } @@ -1034,6 +1056,11 @@ impl DistributedExt for SessionState { P: WorkUnitFeedProvider + 'static, P::WorkUnit: 'static, F: Fn(&T) -> Option<&WorkUnitFeed

> + Send + Sync + 'static; + + fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_dynamic_task_count)] + #[expr($?;Ok(self))] + fn with_distributed_dynamic_task_count(mut self, enabled: bool) -> Result; } } } @@ -1145,6 +1172,11 @@ impl DistributedExt for SessionContext { P: WorkUnitFeedProvider + 'static, P::WorkUnit: 'static, F: Fn(&T) -> Option<&WorkUnitFeed

> + Send + Sync + 'static; + + fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError>; + #[call(set_distributed_dynamic_task_count)] + #[expr($?;Ok(self))] + fn with_distributed_dynamic_task_count(self, enabled: bool) -> Result; } } } diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs index 5b3500bc..0415136f 100644 --- a/src/distributed_planner/distributed_config.rs +++ b/src/distributed_planner/distributed_config.rs @@ -65,6 +65,8 @@ extensions_options! { /// budget will still be admitted (otherwise we would livelock), so the actual peak per /// connection is `worker_connection_buffer_budget_bytes + max_message_size`. pub worker_connection_buffer_budget_bytes: usize, default = 64 * 1024 * 1024 + /// When `true`, each stage's task count is decided at execution time from runtime statistics instead of at planning time. + pub dynamic_task_count: bool, default = false /// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to /// estimate how many tasks should be spawned for the [Stage] containing the leaf node.
pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default() diff --git a/src/distributed_planner/distributed_query_planner.rs b/src/distributed_planner/distributed_query_planner.rs index 9593cd6a..1c83aa72 100644 --- a/src/distributed_planner/distributed_query_planner.rs +++ b/src/distributed_planner/distributed_query_planner.rs @@ -1,4 +1,6 @@ -use crate::distributed_planner::inject_network_boundaries::inject_network_boundaries; +use crate::distributed_planner::inject_network_boundaries::{ + CardinalityTaskCountStrategy, inject_network_boundaries, +}; use crate::distributed_planner::insert_broadcast::insert_broadcast_execs; use crate::distributed_planner::partial_reduce_below_network_shuffles::partial_reduce_below_network_shuffles; use crate::distributed_planner::prepare_network_boundaries::prepare_network_boundaries; @@ -91,7 +93,15 @@ impl QueryPlanner for DistributedQueryPlanner { plan = insert_broadcast_execs(plan, cfg)?; - plan = inject_network_boundaries(plan, cfg).await?; + if d_cfg.dynamic_task_count { + // The task count will be decided dynamically at execution time. + return Ok(Arc::new( + DistributedExec::new(plan).with_metrics_collection(d_cfg.collect_metrics), + )); + } + + // Compute per-node task counts and inject `Network*Exec` nodes at the stage boundaries. + plan = inject_network_boundaries(plan, CardinalityTaskCountStrategy, cfg).await?; plan = prepare_network_boundaries(plan)?; if !plan.exists(|plan| Ok(plan.is_network_boundary()))? 
{ diff --git a/src/distributed_planner/inject_network_boundaries.rs b/src/distributed_planner/inject_network_boundaries.rs index a3af738f..d143810f 100644 --- a/src/distributed_planner/inject_network_boundaries.rs +++ b/src/distributed_planner/inject_network_boundaries.rs @@ -2,9 +2,10 @@ use crate::TaskCountAnnotation::{Desired, Maximum}; use crate::execution_plans::ChildrenIsolatorUnionExec; use crate::stage::LocalStage; use crate::{ - BroadcastExec, DistributedConfig, NetworkBoundaryExt, NetworkBroadcastExec, + BroadcastExec, DistributedConfig, NetworkBoundary, NetworkBoundaryExt, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, TaskCountAnnotation, TaskEstimator, }; +use async_trait::async_trait; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::{HashMap, Result, plan_err}; use datafusion::config::ConfigOptions; @@ -133,13 +134,15 @@ use uuid::Uuid; /// boundary injection, so the head stage is closed by running one final Phase 2 pass over /// the whole plan. This guarantees every node (including head-stage nodes that never sat /// directly above a boundary) has a task count recorded. 
-pub(super) async fn inject_network_boundaries( +pub(crate) async fn inject_network_boundaries( plan: Arc, + task_count_strategy: impl NetworkBoundaryBuilder + Send + Sync, cfg: &ConfigOptions, ) -> Result> { let ctx = Context { cfg, d_cfg: DistributedConfig::from_config_options(cfg)?, + stage_builder: &task_count_strategy, task_counts: &Mutex::new(HashMap::new()), query_id: Uuid::new_v4(), stage_id: &AtomicUsize::new(1), @@ -152,6 +155,7 @@ pub(super) async fn inject_network_boundaries( struct Context<'a> { cfg: &'a ConfigOptions, d_cfg: &'a DistributedConfig, + stage_builder: &'a (dyn NetworkBoundaryBuilder + Send + Sync), task_counts: &'a Mutex>, query_id: Uuid, stage_id: &'a AtomicUsize, @@ -206,6 +210,13 @@ impl<'a> Context<'a> { fn fetch_add_stage_id(&self) -> usize { self.stage_id.fetch_add(1, Ordering::Acquire) } + + async fn apply_stage_builder( + &self, + nb: Arc, + ) -> Result { + self.stage_builder.build(nb, self.cfg).await + } } /// Identity key for a plan node. The pointer is only used as a hash-map key, never dereferenced, @@ -289,16 +300,14 @@ async fn _inject_network_boundaries( // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let f = calculate_scale_factor(&plan, ctx); - let input_stage = LocalStage { + let plan = NetworkShuffleExec::from_stage(LocalStage { query_id: ctx.query_id, num: ctx.fetch_add_stage_id(), plan, tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkShuffleExec::from_stage(input_stage)); - let task_count = Desired((f * task_count.as_usize() as f64).ceil() as usize); - return Ok(ctx.plan_with_task_count(plan, task_count)); + }); + let result = ctx.apply_stage_builder(Arc::new(plan)).await?; + return Ok(ctx.with_task_count(result.network_boundary, result.task_count_above)); } // If the parent of the current node is either a `CoalescePartitionsExec` or a // `SortPreservingMergeExec`, a network boundary below it is necessary. 
@@ -317,31 +326,37 @@ async fn _inject_network_boundaries( // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let f = calculate_scale_factor(&plan, ctx); - let input_stage = LocalStage { + let plan = NetworkBroadcastExec::from_stage(LocalStage { query_id: ctx.query_id, num: ctx.fetch_add_stage_id(), plan, tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkBroadcastExec::from_stage(input_stage)); - let task_count = Desired((f * task_count.as_usize() as f64).ceil() as usize); - Ok(ctx.plan_with_task_count(plan, task_count)) + }); + let result = ctx.apply_stage_builder(Arc::new(plan)).await?; + return Ok(ctx.with_task_count(result.network_boundary, result.task_count_above)); } else { // The subtree below this point belongs to one stage. Propagate the chosen task // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let input_stage = LocalStage { - query_id: ctx.query_id, - num: ctx.fetch_add_stage_id(), - plan, - tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkCoalesceExec::from_stage(input_stage, 1)); + let plan = NetworkCoalesceExec::from_stage( + LocalStage { + query_id: ctx.query_id, + num: ctx.fetch_add_stage_id(), + plan, + tasks: task_count.as_usize(), + }, + 1, + ); + let result = ctx.apply_stage_builder(Arc::new(plan)).await?; + if !matches!(result.task_count_above, Maximum(1)) { + return plan_err!( + "A NetworkCoalesceExec must return exactly a Maximum(1) annotation above" + ); + } // The parent that triggered this branch is a `CoalescePartitionsExec` or // `SortPreservingMergeExec`, both of which fold all partitions into one — so the // stage above this boundary must run in exactly one task. 
- Ok(ctx.plan_with_task_count(plan, Maximum(1))) + Ok(ctx.with_task_count(result.network_boundary, Maximum(1))) }; } @@ -467,6 +482,37 @@ fn propagate_task_count_until_network_boundaries( } } +pub(crate) struct NetworkBoundaryBuilderResult { + pub(crate) task_count_above: TaskCountAnnotation, + pub(crate) network_boundary: Arc, +} + +#[async_trait] +pub(crate) trait NetworkBoundaryBuilder { + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result; +} + +#[async_trait] +impl NetworkBoundaryBuilder for T +where + T: Fn(Arc, &ConfigOptions) -> Result, + T: Send + Sync, + F: Future>, + F: Send, +{ + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result { + self(nb, cfg)?.await + } +} + /// Returns a multiplicative factor describing how the data volume changes between the bottom of /// `plan` (at a network boundary or a leaf) and `plan` itself. The walk descends into `plan`'s /// children, stops at any node that is itself a network boundary (returning `1.0` there — that @@ -500,24 +546,57 @@ fn propagate_task_count_until_network_boundaries( /// /// With `cardinality_task_count_factor = 1.5`, the example above yields `sf ≈ 0.44`. The /// boundary's recorded task count above this stage will be `ceil(T_producer × sf)`. 
-fn calculate_scale_factor(plan: &Arc, ctx: &Context) -> f64 { - if plan.is_network_boundary() { - return 1.0; - }; +pub(crate) struct CardinalityTaskCountStrategy; - let mut sf = None; - for plan in plan.children() { - sf = match sf { - None => Some(calculate_scale_factor(plan, ctx)), - Some(sf) => Some(sf.max(calculate_scale_factor(plan, ctx))), +#[async_trait] +impl NetworkBoundaryBuilder for CardinalityTaskCountStrategy { + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result { + if nb.as_any().is::() { + return Ok(NetworkBoundaryBuilderResult { + task_count_above: Maximum(1), + network_boundary: nb, + }); } - } + let d_cfg = DistributedConfig::from_config_options(cfg)?; + + fn calculate_scale_factor(plan: &Arc, d_cfg: &DistributedConfig) -> f64 { + if plan.is_network_boundary() { + return 1.0; + }; - let sf = sf.unwrap_or(1.0); - match plan.cardinality_effect() { - CardinalityEffect::LowerEqual => sf / ctx.d_cfg.cardinality_task_count_factor, - CardinalityEffect::GreaterEqual => sf * ctx.d_cfg.cardinality_task_count_factor, - _ => sf, + let mut sf = None; + for plan in plan.children() { + sf = match sf { + None => Some(calculate_scale_factor(plan, d_cfg)), + Some(sf) => Some(sf.max(calculate_scale_factor(plan, d_cfg))), + } + } + + let sf = sf.unwrap_or(1.0); + match plan.cardinality_effect() { + CardinalityEffect::LowerEqual => sf / d_cfg.cardinality_task_count_factor, + CardinalityEffect::GreaterEqual => sf * d_cfg.cardinality_task_count_factor, + _ => sf, + } + } + + let input_stage = nb.input_stage(); + let Some(input_plan) = input_stage.local_plan() else { + return plan_err!( + "input_stage plan needs to be in local mode for cardinality calculation" + ); + }; + + let f = calculate_scale_factor(input_plan, d_cfg); + + Ok(NetworkBoundaryBuilderResult { + task_count_above: Desired((f * input_stage.task_count() as f64).ceil() as usize), + network_boundary: nb, + }) } } @@ -1134,6 +1213,7 @@ mod tests { task_counts: 
&Mutex::new(HashMap::new()), query_id: Uuid::new_v4(), stage_id: &AtomicUsize::new(1), + stage_builder: &CardinalityTaskCountStrategy, }; let annotated = _inject_network_boundaries(plan, None, &ctx) diff --git a/src/distributed_planner/mod.rs b/src/distributed_planner/mod.rs index 174b9e8a..70aa6d97 100644 --- a/src/distributed_planner/mod.rs +++ b/src/distributed_planner/mod.rs @@ -9,7 +9,11 @@ mod session_state_builder_ext; mod task_estimator; pub use distributed_config::DistributedConfig; +pub(crate) use inject_network_boundaries::{ + NetworkBoundaryBuilderResult, inject_network_boundaries, +}; pub use network_boundary::{NetworkBoundary, NetworkBoundaryExt}; +pub(crate) use network_boundary::{network_boundary_inject_sampler, network_boundary_scale_input}; pub use session_state_builder_ext::SessionStateBuilderExt; pub use task_estimator::{TaskCountAnnotation, TaskEstimation, TaskEstimator, TaskRoutingContext}; pub(crate) use task_estimator::{get_distributed_task_estimator, set_distributed_task_estimator}; diff --git a/src/distributed_planner/network_boundary.rs b/src/distributed_planner/network_boundary.rs index fec385fc..98d01a41 100644 --- a/src/distributed_planner/network_boundary.rs +++ b/src/distributed_planner/network_boundary.rs @@ -11,7 +11,7 @@ pub trait NetworkBoundary: ExecutionPlan { /// information to perform any internal transformations necessary for distributed execution. /// /// Typically, [NetworkBoundary]s will use this call for transitioning from "Pending" to "ready". - fn with_input_stage(&self, input_stage: Stage) -> Result>; + fn with_input_stage(&self, input_stage: Stage) -> Result>; /// Returns the assigned input [Stage], if any. 
fn input_stage(&self) -> &Stage; @@ -76,3 +76,22 @@ pub(crate) fn network_boundary_scale_input( Ok(input) } + +pub(crate) fn network_boundary_inject_sampler( + input: Arc, +) -> Result> { + let transformed = NetworkShuffleExec::inject_sampler(Arc::clone(&input))?; + if transformed.transformed { + return Ok(transformed.data); + } + let transformed = NetworkBroadcastExec::inject_sampler(Arc::clone(&input))?; + if transformed.transformed { + return Ok(transformed.data); + } + let transformed = NetworkCoalesceExec::inject_sampler(Arc::clone(&input))?; + if transformed.transformed { + return Ok(transformed.data); + } + + Ok(input) +} diff --git a/src/distributed_planner/prepare_network_boundaries.rs b/src/distributed_planner/prepare_network_boundaries.rs index f00360ff..edb6d186 100644 --- a/src/distributed_planner/prepare_network_boundaries.rs +++ b/src/distributed_planner/prepare_network_boundaries.rs @@ -77,5 +77,5 @@ fn prepare( tasks: local_stage.tasks, })); *stage_id += 1; - nb + Ok(nb? 
as Arc) } diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index aa09d8be..63a7708f 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -6,6 +6,7 @@ mod network_broadcast; mod network_coalesce; mod network_shuffle; mod partition_isolator; +mod sampler; #[cfg(any(test, feature = "integration"))] pub mod benchmarks; @@ -17,3 +18,4 @@ pub use network_broadcast::NetworkBroadcastExec; pub use network_coalesce::NetworkCoalesceExec; pub use network_shuffle::NetworkShuffleExec; pub use partition_isolator::PartitionIsolatorExec; +pub use sampler::SamplerExec; diff --git a/src/execution_plans/network_broadcast.rs b/src/execution_plans/network_broadcast.rs index 60f1c9a1..4834a4e3 100644 --- a/src/execution_plans/network_broadcast.rs +++ b/src/execution_plans/network_broadcast.rs @@ -1,5 +1,6 @@ use crate::common::require_one_child; use crate::distributed_planner::NetworkBoundary; +use crate::execution_plans::SamplerExec; use crate::stage::{LocalStage, Stage}; use crate::worker::WorkerConnectionPool; use crate::{BroadcastExec, DistributedTaskContext}; @@ -139,6 +140,18 @@ impl NetworkBroadcastExec { )))) } + pub(crate) fn inject_sampler( + plan: Arc, + ) -> Result>> { + let Some(broadcast_exec) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let child = require_one_child(broadcast_exec.children())?; + plan.with_new_children(vec![Arc::new(SamplerExec::new(child))]) + .map(Transformed::yes) + } + pub(crate) fn from_stage(input_stage: LocalStage) -> Self { let input_partition_count = input_stage.plan.properties().partitioning.partition_count(); let properties = Arc::new( @@ -173,7 +186,7 @@ impl NetworkBroadcastExec { } impl NetworkBoundary for NetworkBroadcastExec { - fn with_input_stage(&self, input_stage: Stage) -> Result> { + fn with_input_stage(&self, input_stage: Stage) -> Result> { let mut self_clone = self.clone(); self_clone.worker_connections = 
WorkerConnectionPool::new(input_stage.task_count()); self_clone.input_stage = input_stage; @@ -248,7 +261,8 @@ impl ExecutionPlan for NetworkBroadcastExec { }; let task_context = DistributedTaskContext::from_ctx(&context); - let off = self.properties.partitioning.partition_count() * task_context.task_index; + let out_partitions = self.properties.partitioning.partition_count(); + let off = out_partitions * task_context.task_index; let mut streams = Vec::with_capacity(self.input_stage.task_count()); for input_task_index in 0..self.input_stage.task_count() { @@ -256,6 +270,8 @@ impl ExecutionPlan for NetworkBroadcastExec { remote_stage, off..(off + self.properties.partitioning.partition_count()), input_task_index, + out_partitions, + task_context.task_count, &context, )?; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index ad082f09..e42e06a2 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -91,6 +91,14 @@ impl NetworkCoalesceExec { Ok(Transformed::no(plan)) } + /// Does nothing, but it's here for explicitly stating that this network boundary does support + /// input stage sampling. + pub(crate) fn inject_sampler( + plan: Arc, + ) -> Result>> { + Ok(Transformed::no(plan)) + } + pub(crate) fn from_stage(input_stage: LocalStage, consumer_tasks: usize) -> Self { // Each output task coalesces a group of input tasks. 
We size the output partition count // per output task based on the maximum group size, returning empty streams for tasks with @@ -148,7 +156,7 @@ impl NetworkBoundary for NetworkCoalesceExec { &self.input_stage } - fn with_input_stage(&self, input_stage: Stage) -> Result> { + fn with_input_stage(&self, input_stage: Stage) -> Result> { let mut self_clone = self.clone(); self_clone.properties = scale_partitioning_props(self_clone.properties(), |p| { p * input_stage.task_count() / self_clone.input_stage.task_count().max(1) @@ -224,10 +232,8 @@ impl ExecutionPlan for NetworkCoalesceExec { ); } - let partitions_per_task = self - .properties() - .partitioning - .partition_count() + let out_partitions = self.properties().partitioning.partition_count(); + let partitions_per_task = out_partitions .checked_div( self.input_stage .task_count() @@ -273,6 +279,8 @@ impl ExecutionPlan for NetworkCoalesceExec { remote_stage, 0..partitions_per_task, target_task, + out_partitions, + task_context.task_count, &context, )?; diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index 00c02138..0b245913 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -1,4 +1,5 @@ use crate::common::require_one_child; +use crate::execution_plans::SamplerExec; use crate::execution_plans::common::scale_partitioning; use crate::stage::{LocalStage, Stage}; use crate::worker::WorkerConnectionPool; @@ -126,6 +127,18 @@ impl NetworkShuffleExec { Ok(Transformed::new(scaled, true, TreeNodeRecursion::Stop)) } + pub(crate) fn inject_sampler( + plan: Arc, + ) -> Result>> { + let Some(repartition_exec) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let child = require_one_child(repartition_exec.children())?; + plan.with_new_children(vec![Arc::new(SamplerExec::new(child))]) + .map(Transformed::yes) + } + pub(crate) fn from_stage(input_stage: LocalStage) -> Self { Self { properties: 
input_stage.plan.properties().clone(), @@ -161,7 +174,7 @@ impl NetworkBoundary for NetworkShuffleExec { &self.input_stage } - fn with_input_stage(&self, input_stage: Stage) -> Result> { + fn with_input_stage(&self, input_stage: Stage) -> Result> { let mut self_clone = self.clone(); self_clone.worker_connections = WorkerConnectionPool::new(input_stage.task_count()); self_clone.input_stage = input_stage; @@ -226,7 +239,8 @@ impl ExecutionPlan for NetworkShuffleExec { }; let task_context = DistributedTaskContext::from_ctx(&context); - let off = self.properties.partitioning.partition_count() * task_context.task_index; + let out_partitions = self.properties.partitioning.partition_count(); + let off = out_partitions * task_context.task_index; let mut streams = Vec::with_capacity(remote_stage.workers.len()); for input_task_index in 0..remote_stage.workers.len() { @@ -234,6 +248,8 @@ impl ExecutionPlan for NetworkShuffleExec { remote_stage, off..(off + self.properties.partitioning.partition_count()), input_task_index, + out_partitions, + task_context.task_count, &context, )?; diff --git a/src/execution_plans/sampler.rs b/src/execution_plans/sampler.rs new file mode 100644 index 00000000..6fcef85a --- /dev/null +++ b/src/execution_plans/sampler.rs @@ -0,0 +1,322 @@ +use crate::common::require_one_child; +use crate::worker::generated::worker as pb; +use crate::{LatencyMetricExt, MaxLatencyMetric, P50LatencyMetric}; +use datafusion::arrow::array::RecordBatch; +use datafusion::common::runtime::SpawnedTask; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion::common::{Result, exec_err}; +use datafusion::execution::memory_pool::MemoryConsumer; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr_common::metrics::{Gauge, MetricsSet}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use 
datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use futures::StreamExt; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; +use tokio::sync::Notify; +use tokio::sync::mpsc::UnboundedSender; +use tokio_stream::wrappers::UnboundedReceiverStream; + +#[derive(Debug)] +pub struct SamplerExec { + pub(crate) input: Arc, + pub(crate) metric_set: ExecutionPlanMetricsSet, + pub(crate) partition_samplers: Vec, +} + +/// Metrics that quantify how long the sampler held data in memory before the consumer +/// (real execution) attached, plus the peak buffer size reached. All metrics are shared +/// across the partition samplers; the latency metrics aggregate per-partition observations. +#[derive(Debug, Clone)] +pub(crate) struct SamplerExecMetrics { + /// Time from `kick_off()` (when the producer task starts pulling the input plan) until + /// `execute()` is invoked on this partition (when the consumer attaches). + kickoff_to_execution_p50: P50LatencyMetric, + kickoff_to_execution_max: MaxLatencyMetric, + /// Time from the first `LoadInfo` message being emitted until `execute()` is invoked. + /// Measures how long the coordinator sat between seeing the first sample and starting + /// the consumer. + first_load_info_to_execution_p50: P50LatencyMetric, + first_load_info_to_execution_max: MaxLatencyMetric, + /// Peak memory buffered by any partition sampler during the sampling phase. 
+ max_mem_used: Gauge, +} + +impl SamplerExecMetrics { + fn new(metric_set: &ExecutionPlanMetricsSet) -> Self { + Self { + kickoff_to_execution_p50: MetricBuilder::new(metric_set) + .p50_latency("kickoff_to_execution_p50"), + kickoff_to_execution_max: MetricBuilder::new(metric_set) + .max_latency("kickoff_to_execution_max"), + first_load_info_to_execution_p50: MetricBuilder::new(metric_set) + .p50_latency("first_load_info_to_execution_p50"), + first_load_info_to_execution_max: MetricBuilder::new(metric_set) + .max_latency("first_load_info_to_execution_max"), + max_mem_used: MetricBuilder::new(metric_set).global_gauge("max_mem_used"), + } + } +} + +impl SamplerExec { + pub(crate) fn new(input: Arc) -> Self { + let metric_set = ExecutionPlanMetricsSet::new(); + let metrics = SamplerExecMetrics::new(&metric_set); + let partitions = input.properties().partitioning.partition_count(); + let mut samplers = Vec::with_capacity(partitions); + for i in 0..partitions { + samplers.push(PartitionSampler { + partition_idx: i, + input: Arc::clone(&input), + stream: Mutex::new(None), + execution_flag: Arc::new(AtomicBool::new(false)), + metrics: metrics.clone(), + kick_off_at: Arc::new(OnceLock::new()), + first_load_info_at: Arc::new(OnceLock::new()), + }); + } + Self { + input, + metric_set, + partition_samplers: samplers, + } + } + + pub(crate) fn kick_off_first_sampler( + plan: Arc, + tx: &UnboundedSender, + ctx: Arc, + ) -> Result<()> { + plan.apply(|plan| { + let Some(sampler) = plan.as_any().downcast_ref::() else { + return Ok(TreeNodeRecursion::Continue); + }; + for partition_sampler in &sampler.partition_samplers { + partition_sampler.kick_off(tx.clone(), Arc::clone(&ctx))?; + } + Ok(TreeNodeRecursion::Stop) + })?; + Ok(()) + } +} + +pub(crate) struct PartitionSampler { + partition_idx: usize, + input: Arc, + stream: Mutex>, + execution_flag: Arc, + + // Metrics state. + metrics: SamplerExecMetrics, + /// Set when `kick_off` is invoked. 
Used at `execute()` time to record how long the + /// sampler buffered data before the consumer attached. + kick_off_at: Arc>, + /// Set the first time the producer task emits a `LoadInfo`. Used at `execute()` time + /// to record the gap between the first sample and the consumer starting. + first_load_info_at: Arc>, +} + +impl Debug for PartitionSampler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PartitionSampler").finish() + } +} + +impl PartitionSampler { + fn start_stream(&self) -> Option { + self.execution_flag.store(true, Ordering::SeqCst); + let now = Instant::now(); + if let Some(t) = self.kick_off_at.get() { + let delay = now.saturating_duration_since(*t); + self.metrics.kickoff_to_execution_p50.add_duration(delay); + self.metrics.kickoff_to_execution_max.add_duration(delay); + } + if let Some(t) = self.first_load_info_at.get() { + let delay = now.saturating_duration_since(*t); + self.metrics + .first_load_info_to_execution_p50 + .add_duration(delay); + self.metrics + .first_load_info_to_execution_max + .add_duration(delay); + } + self.stream.lock().unwrap().take() + } + + fn kick_off( + &self, + sampling_tx: UnboundedSender, + ctx: Arc, + ) -> Result<()> { + let _ = self.kick_off_at.set(Instant::now()); + + let input = Arc::clone(&self.input); + let partition_idx = self.partition_idx; + let schema = input.schema(); + + let memory_reservation = Arc::new( + MemoryConsumer::new(format!("PartitionSampler[{partition_idx}]")) + .register(ctx.memory_pool()), + ); + let memory_reservation_for_consumer = Arc::clone(&memory_reservation); + + // Producer pauses when the buffer exceeds the budget; consumers wake it + // via this Notify after each shrink. Without the gate, the unbounded + // queue could grow without bound while the next stage hasn't started. 
+ let mem_available_notify = Arc::new(Notify::new()); + let mem_available_notify_for_consumer = Arc::clone(&mem_available_notify); + + let execution_flag = Arc::clone(&self.execution_flag); + let max_mem_used = self.metrics.max_mem_used.clone(); + let first_load_info_at = Arc::clone(&self.first_load_info_at); + + // Execute the input synchronously so any setup error surfaces before we + // spawn the producer task. + let mut input_stream = input.execute(partition_idx, ctx)?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + + let task = SpawnedTask::spawn(async move { + let mut first_msg_ns = None; + let mut max_mem_signaled = false; + + while let Some(batch_or_err) = input_stream.next().await { + let batch = match batch_or_err { + Ok(b) => b, + Err(e) => { + let _ = tx.send(Err(e)); + return; + } + }; + let size = batch.get_array_memory_size(); + + // Backpressure: pause once the buffer exceeds the budget. The + // consumer (real execution) wakes us after each batch it drains. 
+ while memory_reservation.size() >= PARTITION_SAMPLER_BUDGET_BYTES { + mem_available_notify.notified().await; + } + + memory_reservation.grow(size); + max_mem_used.set_max(memory_reservation.size()); + + if !execution_flag.load(Ordering::Relaxed) { + let time = Instant::now(); + let time_mark_ns = match first_msg_ns { + Some(t) => time - t, + None => { + first_msg_ns = Some(time); + let _ = first_load_info_at.set(time); + Duration::default() + } + }; + let max_memory_reached = !max_mem_signaled + && memory_reservation.size() >= PARTITION_SAMPLER_BUDGET_BYTES; + if max_memory_reached { + max_mem_signaled = true; + } + let _ = sampling_tx.send(pb::LoadInfo { + partition: partition_idx as u64, + row_count: batch.num_rows() as u64, + byte_size: size as u64, + time_mark_ns: time_mark_ns.as_nanos() as u64, + eos: false, + max_memory_reached, + }); + } + + if tx.send(Ok(batch)).is_err() { + return; + } + } + + // End of input: if real execution hasn't started yet, tell the + // coordinator we observed the entire stream. + if !execution_flag.load(Ordering::Relaxed) { + let time_mark_ns = first_msg_ns.map(|t| Instant::now() - t).unwrap_or_default(); + let _ = sampling_tx.send(pb::LoadInfo { + partition: partition_idx as u64, + row_count: 0, + byte_size: 0, + time_mark_ns: time_mark_ns.as_nanos() as u64, + eos: true, + max_memory_reached: false, + }); + } + }); + + let stream = UnboundedReceiverStream::new(rx).map(move |result| { + let _ = &task; // keep the task alive as long as the stream is alive. + if let Ok(batch) = &result { + memory_reservation_for_consumer.shrink(batch.get_array_memory_size()); + mem_available_notify_for_consumer.notify_one(); + } + result + }); + let stream = RecordBatchStreamAdapter::new(schema, stream); + + self.stream + .lock() + .expect("poisoned lock") + .replace(Box::pin(stream)); + + Ok(()) + } +} + +/// Soft byte budget the partition sampler will buffer before pausing the +/// producer. 
Once exceeded, a `max_memory_reached` LoadInfo is emitted once, +/// signaling to the coordinator that the producer can sustain at least this +/// much in-flight data. +/// TODO: make this configurable via DistributedConfig. +const PARTITION_SAMPLER_BUDGET_BYTES: usize = 32 * 1024 * 1024; + +impl DisplayAs for SamplerExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "SamplerExec") + } +} + +impl ExecutionPlan for SamplerExec { + fn name(&self) -> &str { + "SamplerExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + self.input.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new(require_one_child(children)?))) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + let Some(stream) = self.partition_samplers[partition].start_stream() else { + return exec_err!("SamplerExec[{partition}] was not kicked off"); + }; + Ok(stream) + } + + fn metrics(&self) -> Option { + Some(self.metric_set.clone_inner()) + } +} diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index 4b442f0a..59853b52 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -16,7 +16,6 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::internal_err; use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet}; use std::sync::Arc; -use std::vec; /// Format to use when displaying metrics for a distributed plan. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -49,27 +48,23 @@ pub async fn rewrite_distributed_plan_with_metrics( return Ok(plan); }; - // Check that the plan was executed before waiting — if not, prepared_plan() returns an - // error immediately rather than waiting forever for metrics that will never arrive. 
- let prepared = distributed_exec.prepared_plan()?; - distributed_exec.wait_for_metrics().await; let Some(metrics_collection) = distributed_exec.task_metrics.clone() else { return Ok(plan); }; - let task_metrics = collect_plan_metrics(&prepared)?; + let head_stage = distributed_exec.head_stage()?; + let task_metrics = collect_plan_metrics(&head_stage)?; // Rewrite the DistributedExec's child plan with metrics. let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics( format.to_rewrite_ctx(0), // Task id is 0 for the DistributedExec plan - plan.children()[0].clone(), + distributed_exec.prepared_plan()?, task_metrics, )?; - let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?; - let transformed = plan.transform_down(|plan| { + let transformed = dist_exec_plan_with_metrics.transform_down(|plan| { // Transform all stages using NetworkShuffleExec and NetworkCoalesceExec as barriers. if let Some(network_boundary) = plan.as_network_boundary() { let Stage::Local(stage) = network_boundary.input_stage() else { @@ -92,7 +87,7 @@ pub async fn rewrite_distributed_plan_with_metrics( Ok(Transformed::no(plan)) })?; - Ok(transformed.data) + plan.with_new_children(vec![transformed.data]) } /// Extra information for rewriting local plans. 
diff --git a/src/observability/service.rs b/src/observability/service.rs index 06827dca..7b7c5473 100644 --- a/src/observability/service.rs +++ b/src/observability/service.rs @@ -96,7 +96,7 @@ impl ObservabilityService for ObservabilityServiceImpl { let total_partitions = task_data.total_partitions() as u64; let remaining = task_data.num_partitions_remaining() as u64; let completed_partitions = total_partitions.saturating_sub(remaining); - let output_rows = output_rows_from_plan(&task_data.plan); + let output_rows = output_rows_from_plan(&task_data.base_plan); tasks.push(TaskProgress { task_key: Some((*internal_key).clone()), diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index b529c509..6d9ae2ec 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -1,7 +1,8 @@ use super::get_distributed_user_codecs; -use crate::common::{deserialize_uuid, serialize_uuid}; +use crate::common::{deserialize_uuid, require_one_child, serialize_uuid}; use crate::execution_plans::{ BroadcastExec, ChildrenIsolatorUnionExec, NetworkBroadcastExec, NetworkCoalesceExec, + SamplerExec, }; use crate::stage::{LocalStage, RemoteStage, Stage}; use crate::worker::WorkerConnectionPool; @@ -238,6 +239,9 @@ impl PhysicalExtensionCodec for DistributedCodec { .collect(), })) } + DistributedExecNode::Sampler(SamplerExecProto {}) => { + Ok(Arc::new(SamplerExec::new(require_one_child(inputs)?))) + } } } @@ -356,6 +360,14 @@ impl PhysicalExtensionCodec for DistributedCodec { node: Some(DistributedExecNode::ChildrenIsolatorUnion(inner)), }; + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) + } else if let Some(_node) = node.as_any().downcast_ref::() { + let inner = SamplerExecProto {}; + + let wrapper = DistributedExecProto { + node: Some(DistributedExecNode::Sampler(inner)), + }; + wrapper.encode(buf).map_err(|e| proto_error(format!("{e}"))) } else { Err(proto_error(format!("Unexpected plan {}", node.name()))) @@ -387,7 
+399,7 @@ pub struct ExecutionTaskProto { #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedExecProto { - #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6")] + #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6, 7")] pub node: Option, } @@ -405,6 +417,8 @@ pub enum DistributedExecNode { NetworkBroadcast(NetworkBroadcastExecProto), #[prost(message, tag = "6")] Broadcast(BroadcastExecProto), + #[prost(message, tag = "7")] + Sampler(SamplerExecProto), } #[derive(Clone, PartialEq, ::prost::Message)] @@ -513,6 +527,9 @@ pub struct BroadcastExecProto { pub consumer_task_count: u64, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SamplerExecProto {} + fn new_network_broadcast_exec( partitioning: Partitioning, schema: SchemaRef, diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index 57aa698a..6664418d 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -21,7 +21,7 @@ pub mod coordinator_to_worker_msg { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct WorkerToCoordinatorMsg { - #[prost(oneof = "worker_to_coordinator_msg::Inner", tags = "1")] + #[prost(oneof = "worker_to_coordinator_msg::Inner", tags = "1, 2")] pub inner: ::core::option::Option, } /// Nested message and enum types in `WorkerToCoordinatorMsg`. @@ -34,6 +34,10 @@ pub mod worker_to_coordinator_msg { /// metrics\[i\] is the set of metrics for plan node i in pre-order traversal order. #[prost(message, tag = "1")] TaskMetrics(super::PreOrderTaskMetrics), + /// Load information reported by a task. This information is used for dynamically + /// sizing the number of workers involved in a query. + #[prost(message, tag = "2")] + LoadInfoBatch(super::LoadInfoBatch), } } /// Metrics for a single task's plan nodes in pre-order traversal order. 
@@ -44,6 +48,31 @@ pub struct PreOrderTaskMetrics { #[prost(message, repeated, tag = "1")] pub metrics: ::prost::alloc::vec::Vec, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct LoadInfoBatch { + #[prost(message, repeated, tag = "1")] + pub batch: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct LoadInfo { + #[prost(uint64, tag = "1")] + pub partition: u64, + #[prost(uint64, tag = "2")] + pub row_count: u64, + #[prost(uint64, tag = "3")] + pub byte_size: u64, + #[prost(bool, tag = "4")] + pub eos: bool, + #[prost(uint64, tag = "5")] + pub time_mark_ns: u64, + /// True on the first sample emitted after the partition's sampler buffer + /// exceeded its configured byte budget. Consumers (the coordinator's + /// calculate_task_count) treat this as a "saturation" signal that the + /// producer can produce at least this much data, and can decide the next + /// stage's task count without waiting for further samples. + #[prost(bool, tag = "6")] + pub max_memory_reached: bool, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct GetWorkerInfoRequest {} #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -110,6 +139,12 @@ pub struct ExecuteTaskRequest { /// The end of the partition range of the specified task that is going to be executed. #[prost(uint64, tag = "3")] pub target_partition_end: u64, + /// The amount of partitions per task that are going to consume from this task. + #[prost(uint64, tag = "4")] + pub consumer_partitions: u64, + /// The amount of tasks that are going to consume from this task. + #[prost(uint64, tag = "5")] + pub consumer_task_count: u64, } /// A key that uniquely identifies a task in a query. 
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] diff --git a/src/worker/impl_coordinator_channel.rs b/src/worker/impl_coordinator_channel.rs index dfc5da33..3d101ea0 100644 --- a/src/worker/impl_coordinator_channel.rs +++ b/src/worker/impl_coordinator_channel.rs @@ -1,4 +1,5 @@ use crate::common::deserialize_uuid; +use crate::execution_plans::SamplerExec; use crate::work_unit_feed::RemoteWorkUnitFeedRegistry; use crate::worker::LocalWorkerContext; use crate::worker::generated::worker::coordinator_to_worker_msg::Inner; @@ -17,9 +18,10 @@ use datafusion::prelude::SessionConfig; use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::{FutureExt, StreamExt}; -use std::sync::Arc; use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, OnceLock}; use tokio::sync::oneshot; +use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Request, Response, Status, Streaming}; use url::Url; @@ -54,6 +56,7 @@ impl Worker { } let (metrics_tx, metrics_rx) = oneshot::channel(); + let (load_info_tx, load_info_rx) = tokio::sync::mpsc::unbounded_channel(); let task_data = || async { let headers = grpc_headers.into_headers(); @@ -97,11 +100,17 @@ impl Worker { for hook in self.hooks.on_plan.iter() { plan = hook(plan) } + SamplerExec::kick_off_first_sampler( + Arc::clone(&plan), + &load_info_tx, + Arc::clone(&task_ctx), + )?; // Initialize partition count to the number of partitions in the stage let total_partitions = plan.properties().partitioning.partition_count(); Ok::<_, DataFusionError>(TaskData { - plan, + base_plan: plan, + scaled_up_plan: Arc::new(OnceLock::new()), task_ctx, num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)), metrics_tx: match collect_metrics { @@ -137,6 +146,15 @@ impl Worker { } }); + let load_info_stream = UnboundedReceiverStream::new(load_info_rx); + let load_info_stream = load_info_stream.map(|load_info| WorkerToCoordinatorMsg { + inner: 
Some(worker_to_coordinator_msg::Inner::LoadInfoBatch( + crate::worker::generated::worker::LoadInfoBatch { + batch: vec![load_info], + }, + )), + }); + // Stream back the metrics once the task finishes executing. // The oneshot receiver resolves when impl_execute_task sends the collected // metrics after all partitions have finished or been dropped. @@ -149,7 +167,12 @@ impl Worker { Err(_) => None, // channel dropped without sending any message } }); - Ok(Response::new(metrics_stream.map(Ok).boxed())) + + Ok(Response::new( + futures::stream::select(load_info_stream, metrics_stream) + .map(Ok) + .boxed(), + )) } } diff --git a/src/worker/impl_execute_task.rs b/src/worker/impl_execute_task.rs index 13e3a4a8..f9db2f1a 100644 --- a/src/worker/impl_execute_task.rs +++ b/src/worker/impl_execute_task.rs @@ -59,7 +59,10 @@ pub(crate) async fn execute_local_task( .map_err(|e| exec_datafusion_err!("Worker::execute_task timed-out while waiting for the plan to be set by the coordinator. ({e})"))? .map_err(DataFusionError::Shared)?; - let plan = task_data.plan; + let plan = task_data.scaled_up_plan( + body.consumer_partitions as usize, + body.consumer_task_count as usize, + )?; let task_ctx = task_data.task_ctx; let d_cfg = DistributedConfig::from_config_options(task_ctx.session_config().options())?; diff --git a/src/worker/mod.rs b/src/worker/mod.rs index fc65575f..e89921fc 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -17,6 +17,5 @@ pub use session_builder::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, WorkerQueryContext, WorkerSessionBuilder, }; -pub use worker_service::Worker; - pub use task_data::TaskData; +pub use worker_service::Worker; diff --git a/src/worker/task_data.rs b/src/worker/task_data.rs index 1454ecd9..fe052419 100644 --- a/src/worker/task_data.rs +++ b/src/worker/task_data.rs @@ -1,8 +1,11 @@ -use crate::worker::generated::worker::PreOrderTaskMetrics; +use crate::common::OnceLockResult; +use 
crate::distributed_planner::network_boundary_scale_input; +use crate::worker::generated::worker as pb; +use datafusion::error::{DataFusionError, Result}; use datafusion::execution::TaskContext; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use std::sync::Arc; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::sync::oneshot; #[derive(Clone, Debug)] @@ -11,8 +14,8 @@ use tokio::sync::oneshot; pub struct TaskData { /// Task context suitable for execute different partitions from the same task. pub(super) task_ctx: Arc, - /// Plan to be executed. - pub(crate) plan: Arc, + pub(crate) base_plan: Arc, + pub(crate) scaled_up_plan: Arc>>, /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not /// only tasks in the partition group). This is decremented for each request to the endpoint /// for this task. Once this count is zero, the task is likely complete. The task may not be @@ -22,18 +25,50 @@ pub struct TaskData { /// Sender half of the metrics channel. `impl_execute_task` takes this (via `Option::take`) /// once all partitions have finished or been dropped, sending the collected metrics back to /// the coordinator through the `CoordinatorChannel` side channel. - pub(super) metrics_tx: Arc>>>, + pub(super) metrics_tx: Arc>>>, } impl TaskData { /// Returns the number of partitions remaining to be processed. pub(crate) fn num_partitions_remaining(&self) -> usize { - self.num_partitions_remaining - .load(std::sync::atomic::Ordering::Relaxed) + self.num_partitions_remaining.load(Ordering::SeqCst) } /// Returns the total number of partitions in this task. 
pub(crate) fn total_partitions(&self) -> usize { - self.plan.properties().partitioning.partition_count() + match self.scaled_up_plan.get() { + Some(Ok(plan)) => plan.output_partitioning().partition_count(), + _ => self + .base_plan + .properties() + .output_partitioning() + .partition_count(), + } + } + + pub(crate) fn scaled_up_plan( + &self, + consumer_partitions: usize, + consumer_task_count: usize, + ) -> Result> { + let result = self.scaled_up_plan.get_or_init(|| { + let scaled_up = match network_boundary_scale_input( + Arc::clone(&self.base_plan), + consumer_partitions, + consumer_task_count, + ) { + Ok(scaled_up) => scaled_up, + Err(err) => return Err(Arc::new(err)), + }; + + let partition_count = scaled_up.output_partitioning().partition_count(); + self.num_partitions_remaining + .store(partition_count, Ordering::SeqCst); + Ok(scaled_up) + }); + match result { + Ok(plan) => Ok(Arc::clone(plan)), + Err(err) => Err(DataFusionError::Shared(Arc::clone(err))), + } } } diff --git a/src/worker/test_utils/worker_handles.rs b/src/worker/test_utils/worker_handles.rs index 83a996c0..00aba096 100644 --- a/src/worker/test_utils/worker_handles.rs +++ b/src/worker/test_utils/worker_handles.rs @@ -214,7 +214,8 @@ pub async fn register_plan_on_worker( swmr_task_data .write(Ok(TaskData { task_ctx, - plan, + base_plan: plan, + scaled_up_plan: Default::default(), num_partitions_remaining: Arc::new(AtomicUsize::new(partition_count)), metrics_tx: Arc::new(std::sync::Mutex::new(Some(metrics_tx))), })) diff --git a/src/worker/worker.proto b/src/worker/worker.proto index 6372f31b..2ebeb88d 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -31,6 +31,10 @@ message WorkerToCoordinatorMsg { // ensuring metrics are never lost due to early stream termination. // metrics[i] is the set of metrics for plan node i in pre-order traversal order. PreOrderTaskMetrics task_metrics = 1; + + // Load information reported by a task. 
This information is used for dynamically + // sizing the number of workers involved in a query. + LoadInfoBatch load_info_batch = 2; } } @@ -41,6 +45,24 @@ message PreOrderTaskMetrics { repeated MetricsSet metrics = 1; } +message LoadInfoBatch { + repeated LoadInfo batch = 1; +} + +message LoadInfo { + uint64 partition = 1; + uint64 row_count = 2; + uint64 byte_size = 3; + bool eos = 4; + uint64 time_mark_ns = 5; + // True on the first sample emitted after the partition's sampler buffer + // exceeded its configured byte budget. Consumers (the coordinator's + // calculate_task_count) treat this as a "saturation" signal that the + // producer can produce at least this much data, and can decide the next + // stage's task count without waiting for further samples. + bool max_memory_reached = 6; +} + message GetWorkerInfoRequest {} message GetWorkerInfoResponse { @@ -88,6 +110,10 @@ message ExecuteTaskRequest { uint64 target_partition_start = 2; // The end of the partition range of the specified task that is going to be executed. uint64 target_partition_end = 3; + // The amount of partitions per task that are going to consume from this task. + uint64 consumer_partitions = 4; + // The amount of tasks that are going to consume from this task. + uint64 consumer_task_count = 5; } // A key that uniquely identifies a task in a query. 
diff --git a/src/worker/worker_connection_pool.rs b/src/worker/worker_connection_pool.rs index d905383b..7afe194a 100644 --- a/src/worker/worker_connection_pool.rs +++ b/src/worker/worker_connection_pool.rs @@ -88,6 +88,8 @@ impl WorkerConnectionPool { input_stage: &RemoteStage, target_partitions: Range, target_task: usize, + consumer_partitions: usize, + consumer_task_count: usize, ctx: &Arc, ) -> Result<&(dyn WorkerConnection + Sync + Send)> { let Some(worker_connection) = self.connections.get(target_task) else { @@ -110,6 +112,8 @@ impl WorkerConnectionPool { input_stage, target_partitions, target_task, + consumer_partitions, + consumer_task_count, lw_ctx, &self.metrics, )) as Box<_>) @@ -119,6 +123,8 @@ impl WorkerConnectionPool { input_stage, target_partitions, target_task, + consumer_partitions, + consumer_task_count, ctx, &self.metrics, ) @@ -174,6 +180,8 @@ impl RemoteWorkerConnection { input_stage: &RemoteStage, target_partition_range: Range, target_task: usize, + consumer_partitions: usize, + consumer_task_count: usize, ctx: &Arc, metrics: &ExecutionPlanMetricsSet, ) -> Result { @@ -223,6 +231,8 @@ impl RemoteWorkerConnection { stage_id: input_stage.num as u64, task_number: target_task as u64, }), + consumer_partitions: consumer_partitions as u64, + consumer_task_count: consumer_task_count as u64, }, ); @@ -427,6 +437,8 @@ impl LocalWorkerConnection { input_stage: &RemoteStage, target_partition_range: Range, target_task: usize, + consumer_partitions: usize, + consumer_task_count: usize, lw_ctx: Arc, metrics: &ExecutionPlanMetricsSet, ) -> Self { @@ -443,6 +455,8 @@ impl LocalWorkerConnection { }), target_partition_start: target_partition_range.start as u64, target_partition_end: target_partition_range.end as u64, + consumer_partitions: consumer_partitions as u64, + consumer_task_count: consumer_task_count as u64, }, } } diff --git a/tests/metrics_collection.rs b/tests/metrics_collection.rs index 1bc283c3..d87b5528 100644 --- a/tests/metrics_collection.rs 
+++ b/tests/metrics_collection.rs @@ -6,6 +6,7 @@ mod tests { use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion::execution::SessionState; use datafusion::physical_plan::display::DisplayableExecutionPlan; + use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ExecutionPlan, execute_stream}; use datafusion::prelude::SessionContext; use datafusion_distributed::test_utils::localhost::start_localhost_context; @@ -256,6 +257,37 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_metrics_collection_dynamic() -> Result<(), Box> { + let (mut d_ctx, _guard, _) = start_localhost_context(3, DefaultSessionBuilder).await; + d_ctx.set_distributed_dynamic_task_count(true)?; + + let query = + r#"SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)"#; + + let s_ctx = SessionContext::default(); + let (s_physical, mut d_physical) = execute(&s_ctx, &d_ctx, query).await?; + d_physical = rewrite_with_metrics(d_physical, DistributedMetricsFormat::Aggregated).await; + println!("{}", display_plan_ascii(s_physical.as_ref(), true)); + println!("{}", display_plan_ascii(d_physical.as_ref(), true)); + + assert_metrics_equal::( + ["output_rows", "output_bytes"], + &s_physical, + &d_physical, + 0, + ); + + assert_metrics_equal::( + ["output_rows", "output_bytes"], + &s_physical, + &d_physical, + 0, + ); + + Ok(()) + } + /// Looks for an [ExecutionPlan] that matches the provided type parameter `T` in /// both root nodes and compares its metrics. /// There might be more than one, so `index` determines which one is compared. 
From 72cc30ba802badb9025b44cbc81b6773e4f797f3 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 11 May 2026 03:27:30 -0400 Subject: [PATCH 2/7] Worker-side metrics collection --- benchmarks/cdk/bin/datafusion-bench.ts | 5 + benchmarks/src/run.rs | 15 +- src/coordinator/prepare_dynamic_plan.rs | 127 +++-------- src/coordinator/task_spawner.rs | 6 +- src/distributed_ext.rs | 42 ++++ src/distributed_planner/distributed_config.rs | 2 + src/execution_plans/sampler.rs | 165 +++++++------- src/worker/generated/worker.rs | 209 +++++++++++------- src/worker/impl_coordinator_channel.rs | 6 +- src/worker/worker.proto | 18 +- 10 files changed, 299 insertions(+), 296 deletions(-) diff --git a/benchmarks/cdk/bin/datafusion-bench.ts b/benchmarks/cdk/bin/datafusion-bench.ts index 5aa9aa0c..338eed1d 100644 --- a/benchmarks/cdk/bin/datafusion-bench.ts +++ b/benchmarks/cdk/bin/datafusion-bench.ts @@ -25,6 +25,7 @@ async function main() { .option('--repartition-file-min-size ', 'repartition_file_min_size DF option', '10485760' /* upstream default */) .option('--target-partitions ', 'target_partitions DF option', '8') .option('--dynamic ', 'Use the dynamic task count assigner', 'false') + .option('--bytes-per-partition-per-second ', 'Target throughput in bytes per partition per second for the dynamic task count allocator', `${256 * 1024 * 1024}`) .option('--queries ', 'Specific queries to run', undefined) .option('--debug ', 'Print the generated plans to stdout') .option('--warmup ', 'Perform a warmup query before the benchmarks', 'true') @@ -48,6 +49,7 @@ async function main() { const broadcastJoins = options.broadcastJoins === 'true' || options.broadcastJoins === 1 const partialReduce = options.partialReduce === 'true' || options.partialReduce === 1 const dynamicTaskCount = options.dynamic === 'true' || options.dynamic === 1 + const bytesPerPartitionPerSecond = parseInt(options.bytesPerPartitionPerSecond) const debug = options.debug === true || options.debug === 'true' 
|| options.debug === 1 const warmup = options.warmup === true || options.warmup === 'true' || options.warmup === 1 @@ -62,6 +64,7 @@ async function main() { broadcastJoins, partialReduce, dynamicTaskCount, + bytesPerPartitionPerSecond, maxTasksPerStage, repartitionFileMinSize, targetPartitions @@ -101,6 +104,7 @@ class DataFusionRunner implements BenchmarkRunner { broadcastJoins: boolean; partialReduce: boolean; dynamicTaskCount: boolean; + bytesPerPartitionPerSecond: number; maxTasksPerStage: number; repartitionFileMinSize: number; targetPartitions: number; @@ -181,6 +185,7 @@ class DataFusionRunner implements BenchmarkRunner { SET distributed.broadcast_joins=${this.options.broadcastJoins}; SET distributed.partial_reduce=${this.options.partialReduce}; SET distributed.dynamic_task_count=${this.options.dynamicTaskCount}; + SET distributed.bytes_per_partition_per_second=${this.options.bytesPerPartitionPerSecond}; SET distributed.max_tasks_per_stage=${this.options.maxTasksPerStage}; SET datafusion.optimizer.repartition_file_min_size=${this.options.repartitionFileMinSize}; SET datafusion.execution.target_partitions=${this.options.targetPartitions}; diff --git a/benchmarks/src/run.rs b/benchmarks/src/run.rs index eee28117..aefbbd63 100644 --- a/benchmarks/src/run.rs +++ b/benchmarks/src/run.rs @@ -118,6 +118,10 @@ pub struct RunOpt { #[structopt(long = "dynamic")] dynamic: bool, + /// Amount of bytes per second each partition is expected to handle during dynamic execution + #[structopt(long = "bps")] + bytes_per_partition_per_second: Option, + /// Activate debug mode to see more details #[structopt(short, long)] debug: bool, @@ -177,7 +181,7 @@ impl RunOpt { } async fn run_local(self) -> Result<()> { - let state = SessionStateBuilder::new() + let mut builder = SessionStateBuilder::new() .with_default_features() .with_config(self.config()?) 
.with_distributed_worker_resolver(LocalHostWorkerResolver::new(self.workers.clone())) @@ -198,9 +202,12 @@ impl RunOpt { .with_distributed_broadcast_joins(self.broadcast_joins)? .with_distributed_metrics_collection(self.collect_metrics)? .with_distributed_max_tasks_per_stage(self.max_tasks_per_stage)? - .with_distributed_dynamic_task_count(self.dynamic)? - .build(); - let ctx = SessionContext::new_with_state(state); + .with_distributed_dynamic_task_count(self.dynamic)?; + if let Some(v) = self.bytes_per_partition_per_second { + builder.set_distributed_bytes_per_partition_per_second(v)?; + } + + let ctx = SessionContext::new_with_state(builder.build()); register_tables(&ctx, &self.get_path()?).await?; println!("Running benchmarks with the following options: {self:?}"); diff --git a/src/coordinator/prepare_dynamic_plan.rs b/src/coordinator/prepare_dynamic_plan.rs index 2a210b24..395c5222 100644 --- a/src/coordinator/prepare_dynamic_plan.rs +++ b/src/coordinator/prepare_dynamic_plan.rs @@ -10,11 +10,10 @@ use crate::distributed_planner::{ use crate::stage::{LocalStage, RemoteStage}; use crate::worker::generated::worker as pb; use crate::{ - DistributedCodec, NetworkBoundary, NetworkBoundaryExt, NetworkCoalesceExec, Stage, - TaskCountAnnotation, TaskRoutingContext, get_distributed_worker_resolver, + DistributedCodec, DistributedConfig, NetworkBoundary, NetworkBoundaryExt, NetworkCoalesceExec, + Stage, TaskCountAnnotation, TaskRoutingContext, get_distributed_worker_resolver, }; use dashmap::DashMap; -use datafusion::common::instant::Instant; use datafusion::common::runtime::JoinSet; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion::common::{Result, exec_err}; @@ -24,11 +23,9 @@ use datafusion::physical_expr_common::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::ExecutionPlan; use futures::{Stream, StreamExt}; use rand::Rng; -use std::collections::HashMap; use std::sync::atomic::AtomicUsize; use 
std::sync::atomic::Ordering::SeqCst; use std::sync::{Arc, Mutex}; -use std::time::Duration; use tokio_stream::wrappers::UnboundedReceiverStream; use url::Url; @@ -46,7 +43,8 @@ pub(super) async fn prepare_dynamic_plan( let head_stage = inject_network_boundaries( Arc::clone(base_plan), - |nb: Arc, _cfg: &ConfigOptions| { + |nb: Arc, cfg: &ConfigOptions| { + let d_cfg = DistributedConfig::from_config_options(cfg)?; let worker_resolver = get_distributed_worker_resolver(ctx.session_config())?; let codec = DistributedCodec::new_combined_with_user(ctx.session_config()); let task_estimator = get_distributed_task_estimator(ctx.session_config())?; @@ -139,12 +137,22 @@ pub(super) async fn prepare_dynamic_plan( }))?; let load_info_stream = futures::stream::select_all(load_info_rxs); + let partitions_per_task = nb.properties().partitioning.partition_count(); + let partitions_remaining = partitions_per_task * input_stage.tasks; + let bytes_per_partition_per_second = d_cfg.bytes_per_partition_per_second; Ok(async move { let task_count_above = if nb.as_any().is::() { TaskCountAnnotation::Maximum(1) } else { - TaskCountAnnotation::Desired(calculate_task_count(load_info_stream).await) + let necessary_partitions = calculate_necessary_partitions( + load_info_stream, + partitions_remaining, + bytes_per_partition_per_second, + ) + .await; + let expected_tasks = necessary_partitions.div_ceil(partitions_per_task); + TaskCountAnnotation::Desired(expected_tasks) }; Ok(NetworkBoundaryBuilderResult { task_count_above, @@ -230,104 +238,31 @@ fn get_child_stages_urls( /// Returns early as soon as every observed slice has emitted a terminating signal /// (`max_memory_reached` or `eos`); otherwise exits at the sampling deadline. Returns 1 if /// no slice could compute a usable velocity. 
-async fn calculate_task_count( - mut load_info_stream: impl Stream + Unpin, +async fn calculate_necessary_partitions( + mut load_info_stream: impl Stream + Unpin, + mut partitions_remaining: usize, + bytes_per_partition_per_second: usize, ) -> usize { - /// Target sustained throughput per downstream task. The next stage is sized so each - /// task is expected to absorb roughly this many bytes per second of producer output. - const TARGET_BYTES_PER_SEC_PER_TASK: u64 = 64 * 1024 * 1024; - /// Per-`(task_idx, partition)` cap on buffered `LoadInfo` messages. Once a slice has - /// produced this many messages it is considered done — enough information has been seen - /// to estimate its velocity. - const MAX_MESSAGES_PER_SLICE: usize = 2; /// Minimum number of slices that must reach the "done" state before voting. A slice is /// done when it signals `eos`, `max_memory_reached`, or hits `MAX_MESSAGES_PER_SLICE`. - const TARGET_DONE_SLICES: usize = 2; - /// Wall-clock safety net measured from the first received `LoadInfo`. If neither - /// `TARGET_DONE_SLICES` nor full slice coverage is reached within this window, vote with - /// whatever has been observed. Prevents deadlock if a stage has fewer slices than - /// `TARGET_DONE_SLICES` but new slices are still appearing slowly. 
- const SAMPLING_WINDOW: Duration = Duration::from_millis(25); + const TARGET_DONE_PARTITIONS: usize = 3; - #[derive(Default)] - struct Slice { - total_bytes: u64, - max_elapsed_ns: u64, - msg_count: usize, - done: bool, - } - let mut slices: HashMap<(usize, u64), Slice> = HashMap::new(); - let mut done_count: usize = 0; - let mut deadline: Option = None; - - loop { - let next = match deadline { - None => load_info_stream.next().await, - Some(d) => match tokio::time::timeout_at(d.into(), load_info_stream.next()).await { - Ok(item) => item, - Err(_) => break, // sampling window elapsed - }, - }; - let Some((task_idx, batch)) = next else { break }; // stream terminated + let mut votes = vec![]; - for info in batch.batch { - let entry = slices.entry((task_idx, info.partition)).or_default(); - entry.total_bytes = entry.total_bytes.saturating_add(info.byte_size); - entry.max_elapsed_ns = entry.max_elapsed_ns.max(info.time_mark_ns); - entry.msg_count += 1; - if !entry.done - && (info.eos - || info.max_memory_reached - || entry.msg_count >= MAX_MESSAGES_PER_SLICE) - { - entry.done = true; - done_count += 1; - } - } - if deadline.is_none() && !slices.is_empty() { - deadline = Some(Instant::now() + SAMPLING_WINDOW); + while let Some((_task_idx, load_info)) = load_info_stream.next().await { + partitions_remaining -= 1; + if load_info.bytes_per_second == 0 { + // The partition reporting this load info was empty. + continue; } - if done_count >= TARGET_DONE_SLICES { + votes.push(load_info); + if votes.len() >= TARGET_DONE_PARTITIONS || partitions_remaining == 0 { break; } } - // Each slice that observed enough data votes for a task count. 
The vote extrapolates the - // slice's observed velocity to the full producer (assumes all observed slices share the - // same velocity): - // slice_velocity = total_bytes / max_elapsed_ns (bytes/ns) - // stage_throughput = slice_velocity * num_slices_observed * 1e9 (bytes/sec) - // vote = ceil(stage_throughput / TARGET_BYTES_PER_SEC_PER_TASK) - let observed = slices.len().max(1) as u128; - let mut votes: Vec = slices - .values() - .filter_map(|s| { - if s.max_elapsed_ns == 0 || s.total_bytes == 0 { - return None; - } - let numerator = (s.total_bytes as u128) - .saturating_mul(1_000_000_000) - .saturating_mul(observed); - let denominator = - (s.max_elapsed_ns as u128).saturating_mul(TARGET_BYTES_PER_SEC_PER_TASK as u128); - Some(numerator.div_ceil(denominator).max(1)) - }) - .collect(); - - // Floor at the number of distinct producer task_idxs observed: never shrink the consumer - // stage below the producer's parallelism. Mirrors the static `CardinalityTaskCountStrategy` - // behavior where a consumer at least matches its producer. 
- let producer_task_floor = slices - .keys() - .map(|(t, _)| *t) - .collect::>() - .len() as u128; - - if votes.is_empty() { - return producer_task_floor.max(1) as usize; - } + let avg_bytes_per_second_per_partition = + votes.iter().map(|v| v.bytes_per_second).sum::() / votes.len() as u64; - votes.sort_unstable(); - let median = votes[votes.len() / 2].max(producer_task_floor); - usize::try_from(median).unwrap_or(usize::MAX) + (avg_bytes_per_second_per_partition as usize).div_ceil(bytes_per_partition_per_second) } diff --git a/src/coordinator/task_spawner.rs b/src/coordinator/task_spawner.rs index 8bc04753..e7e5b223 100644 --- a/src/coordinator/task_spawner.rs +++ b/src/coordinator/task_spawner.rs @@ -239,7 +239,7 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { &mut self, task_i: usize, mut worker_to_coordinator_rx: UnboundedReceiver, - ) -> UnboundedReceiver { + ) -> UnboundedReceiver { let task_key = TaskKey { query_id: serialize_uuid(&self.query_id), stage_id: self.stage_id as u64, @@ -258,8 +258,8 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { task_metrics.insert(task_key.clone(), pre_order_metrics.metrics); } } - pb::worker_to_coordinator_msg::Inner::LoadInfoBatch(load_info_batch) => { - let _ = tx.send(load_info_batch); + pb::worker_to_coordinator_msg::Inner::LoadInfo(load_info) => { + let _ = tx.send(load_info); } } } diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index f0bbea41..a2d08da4 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -589,6 +589,20 @@ pub trait DistributedExt: Sized { /// Same as [DistributedExt::with_distributed_dynamic_task_count] but with an in-place mutation. fn set_distributed_dynamic_task_count(&mut self, enabled: bool) -> Result<(), DataFusionError>; + + /// Target throughput in bytes per partition per second used by the dynamic task count + /// allocator to decide how many tasks to assign to each stage based on runtime statistics. 
+ fn with_distributed_bytes_per_partition_per_second( + self, + bytes_per_partition_per_second: usize, + ) -> Result; + + /// Same as [DistributedExt::with_distributed_bytes_per_partition_per_second] but with an + /// in-place mutation. + fn set_distributed_bytes_per_partition_per_second( + &mut self, + bytes_per_partition_per_second: usize, + ) -> Result<(), DataFusionError>; } impl DistributedExt for SessionConfig { @@ -743,6 +757,15 @@ impl DistributedExt for SessionConfig { Ok(()) } + fn set_distributed_bytes_per_partition_per_second( + &mut self, + bytes_per_partition_per_second: usize, + ) -> Result<(), DataFusionError> { + let d_cfg = DistributedConfig::from_config_options_mut(self.options_mut())?; + d_cfg.bytes_per_partition_per_second = bytes_per_partition_per_second; + Ok(()) + } + delegate! { to self { #[call(set_distributed_option_extension)] @@ -829,6 +852,10 @@ impl DistributedExt for SessionConfig { #[call(set_distributed_dynamic_task_count)] #[expr($?;Ok(self))] fn with_distributed_dynamic_task_count(mut self, enabled: bool) -> Result; + + #[call(set_distributed_bytes_per_partition_per_second)] + #[expr($?;Ok(self))] + fn with_distributed_bytes_per_partition_per_second(mut self, bytes_per_partition_per_second: usize) -> Result; } } } @@ -945,6 +972,11 @@ impl DistributedExt for SessionStateBuilder { #[call(set_distributed_dynamic_task_count)] #[expr($?;Ok(self))] fn with_distributed_dynamic_task_count(mut self, enabled: bool) -> Result; + + fn set_distributed_bytes_per_partition_per_second(&mut self, bytes_per_partition_per_second: usize) -> Result<(), DataFusionError>; + #[call(set_distributed_bytes_per_partition_per_second)] + #[expr($?;Ok(self))] + fn with_distributed_bytes_per_partition_per_second(mut self, bytes_per_partition_per_second: usize) -> Result; } } } @@ -1061,6 +1093,11 @@ impl DistributedExt for SessionState { #[call(set_distributed_dynamic_task_count)] #[expr($?;Ok(self))] fn with_distributed_dynamic_task_count(mut self, enabled: 
bool) -> Result; + + fn set_distributed_bytes_per_partition_per_second(&mut self, bytes_per_partition_per_second: usize) -> Result<(), DataFusionError>; + #[call(set_distributed_bytes_per_partition_per_second)] + #[expr($?;Ok(self))] + fn with_distributed_bytes_per_partition_per_second(mut self, bytes_per_partition_per_second: usize) -> Result; } } } @@ -1177,6 +1214,11 @@ impl DistributedExt for SessionContext { #[call(set_distributed_dynamic_task_count)] #[expr($?;Ok(self))] fn with_distributed_dynamic_task_count(self, enabled: bool) -> Result; + + fn set_distributed_bytes_per_partition_per_second(&mut self, bytes_per_partition_per_second: usize) -> Result<(), DataFusionError>; + #[call(set_distributed_bytes_per_partition_per_second)] + #[expr($?;Ok(self))] + fn with_distributed_bytes_per_partition_per_second(self, bytes_per_partition_per_second: usize) -> Result; } } } diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs index 0415136f..059ef4a6 100644 --- a/src/distributed_planner/distributed_config.rs +++ b/src/distributed_planner/distributed_config.rs @@ -67,6 +67,8 @@ extensions_options! { pub worker_connection_buffer_budget_bytes: usize, default = 64 * 1024 * 1024 /// TODO pub dynamic_task_count: bool, default = false + /// TODO + pub bytes_per_partition_per_second: usize, default = 256 * 1024 * 1024 /// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to /// estimate how many tasks should be spawned for the [Stage] containing the leaf node. 
pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default() diff --git a/src/execution_plans/sampler.rs b/src/execution_plans/sampler.rs index 6fcef85a..756bcf80 100644 --- a/src/execution_plans/sampler.rs +++ b/src/execution_plans/sampler.rs @@ -1,15 +1,17 @@ use crate::common::require_one_child; use crate::worker::generated::worker as pb; -use crate::{LatencyMetricExt, MaxLatencyMetric, P50LatencyMetric}; -use datafusion::arrow::array::RecordBatch; -use datafusion::common::runtime::SpawnedTask; +use crate::{ + BytesCounterMetric, BytesMetricExt, LatencyMetricExt, MaxLatencyMetric, P50LatencyMetric, +}; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::{Result, exec_err}; use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr_common::metrics::{Gauge, MetricsSet}; use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::stream::{ + RecordBatchReceiverStreamBuilder, RecordBatchStreamAdapter, +}; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use futures::StreamExt; use std::any::Any; @@ -18,9 +20,8 @@ use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; -use tokio::sync::Notify; use tokio::sync::mpsc::UnboundedSender; -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio::sync::mpsc::error::TrySendError; #[derive(Debug)] pub struct SamplerExec { @@ -41,10 +42,12 @@ pub(crate) struct SamplerExecMetrics { /// Time from the first `LoadInfo` message being emitted until `execute()` is invoked. /// Measures how long the coordinator sat between seeing the first sample and starting /// the consumer. 
- first_load_info_to_execution_p50: P50LatencyMetric, - first_load_info_to_execution_max: MaxLatencyMetric, + first_batch_to_execution_p50: P50LatencyMetric, + first_batch_to_execution_max: MaxLatencyMetric, /// Peak memory buffered by any partition sampler during the sampling phase. max_mem_used: Gauge, + /// Bytes per second flowing through the sampler node. + bytes_per_sec: BytesCounterMetric, } impl SamplerExecMetrics { @@ -54,11 +57,12 @@ impl SamplerExecMetrics { .p50_latency("kickoff_to_execution_p50"), kickoff_to_execution_max: MetricBuilder::new(metric_set) .max_latency("kickoff_to_execution_max"), - first_load_info_to_execution_p50: MetricBuilder::new(metric_set) - .p50_latency("first_load_info_to_execution_p50"), - first_load_info_to_execution_max: MetricBuilder::new(metric_set) - .max_latency("first_load_info_to_execution_max"), + first_batch_to_execution_p50: MetricBuilder::new(metric_set) + .p50_latency("first_batch_to_execution_p50"), + first_batch_to_execution_max: MetricBuilder::new(metric_set) + .max_latency("first_batch_to_execution_max"), max_mem_used: MetricBuilder::new(metric_set).global_gauge("max_mem_used"), + bytes_per_sec: MetricBuilder::new(metric_set).bytes_counter("bytes_per_sec"), } } } @@ -77,7 +81,7 @@ impl SamplerExec { execution_flag: Arc::new(AtomicBool::new(false)), metrics: metrics.clone(), kick_off_at: Arc::new(OnceLock::new()), - first_load_info_at: Arc::new(OnceLock::new()), + first_batch_at: Arc::new(OnceLock::new()), }); } Self { @@ -118,7 +122,7 @@ pub(crate) struct PartitionSampler { kick_off_at: Arc>, /// Set the first time the producer task emits a `LoadInfo`. Used at `execute()` time /// to record the gap between the first sample and the consumer starting. 
- first_load_info_at: Arc>, + first_batch_at: Arc>, } impl Debug for PartitionSampler { @@ -136,13 +140,13 @@ impl PartitionSampler { self.metrics.kickoff_to_execution_p50.add_duration(delay); self.metrics.kickoff_to_execution_max.add_duration(delay); } - if let Some(t) = self.first_load_info_at.get() { + if let Some(t) = self.first_batch_at.get() { let delay = now.saturating_duration_since(*t); self.metrics - .first_load_info_to_execution_p50 + .first_batch_to_execution_p50 .add_duration(delay); self.metrics - .first_load_info_to_execution_max + .first_batch_to_execution_max .add_duration(delay); } self.stream.lock().unwrap().take() @@ -165,115 +169,102 @@ impl PartitionSampler { ); let memory_reservation_for_consumer = Arc::clone(&memory_reservation); - // Producer pauses when the buffer exceeds the budget; consumers wake it - // via this Notify after each shrink. Without the gate, the unbounded - // queue could grow without bound while the next stage hasn't started. - let mem_available_notify = Arc::new(Notify::new()); - let mem_available_notify_for_consumer = Arc::clone(&mem_available_notify); - let execution_flag = Arc::clone(&self.execution_flag); - let max_mem_used = self.metrics.max_mem_used.clone(); - let first_load_info_at = Arc::clone(&self.first_load_info_at); + let max_mem_used_metric = self.metrics.max_mem_used.clone(); + let bytes_per_sec_metric = self.metrics.bytes_per_sec.clone(); + let first_batch_at = Arc::clone(&self.first_batch_at); // Execute the input synchronously so any setup error surfaces before we // spawn the producer task. 
let mut input_stream = input.execute(partition_idx, ctx)?; - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + let mut builder = RecordBatchReceiverStreamBuilder::new(self.input.schema(), 2); + let tx = builder.tx(); - let task = SpawnedTask::spawn(async move { - let mut first_msg_ns = None; - let mut max_mem_signaled = false; + let mut first_msg = None; + let mut sampling_tx = Some(sampling_tx); + builder.spawn(async move { while let Some(batch_or_err) = input_stream.next().await { - let batch = match batch_or_err { - Ok(b) => b, - Err(e) => { - let _ = tx.send(Err(e)); - return; - } - }; + let batch = batch_or_err?; let size = batch.get_array_memory_size(); - // Backpressure: pause once the buffer exceeds the budget. The - // consumer (real execution) wakes us after each batch it drains. - while memory_reservation.size() >= PARTITION_SAMPLER_BUDGET_BYTES { - mem_available_notify.notified().await; - } + let now = Instant::now(); + let first_msg = first_msg.get_or_insert_with(|| { + let _ = first_batch_at.set(now); + now + }); + let time_since_first_msg = now - *first_msg; memory_reservation.grow(size); - max_mem_used.set_max(memory_reservation.size()); - - if !execution_flag.load(Ordering::Relaxed) { - let time = Instant::now(); - let time_mark_ns = match first_msg_ns { - Some(t) => time - t, - None => { - first_msg_ns = Some(time); - let _ = first_load_info_at.set(time); - Duration::default() - } + let memory_used = memory_reservation.size(); + max_mem_used_metric.set_max(memory_used); + + if sampling_tx.is_none() || execution_flag.load(Ordering::Relaxed) { + if tx.send(Ok(batch)).await.is_err() { + return Ok(()); // channel closed }; - let max_memory_reached = !max_mem_signaled - && memory_reservation.size() >= PARTITION_SAMPLER_BUDGET_BYTES; - if max_memory_reached { - max_mem_signaled = true; - } - let _ = sampling_tx.send(pb::LoadInfo { - partition: partition_idx as u64, - row_count: batch.num_rows() as u64, - byte_size: size as u64, - time_mark_ns: 
time_mark_ns.as_nanos() as u64, - eos: false, - max_memory_reached, - }); + continue; } - if tx.send(Ok(batch)).is_err() { - return; + match tx.try_send(Ok(batch)) { + Ok(_) => {} + Err(TrySendError::Full(batch_or_err)) => { + let bytes_per_second = bytes_per_second(memory_used, time_since_first_msg); + bytes_per_sec_metric.add_bytes(bytes_per_second as usize); + if let Some(sampling_tx) = sampling_tx.take() { + let _ = sampling_tx.send(pb::LoadInfo { + partition: partition_idx as u64, + bytes_per_second, + eos: false, + }); + } + if tx.send(batch_or_err).await.is_err() { + return Ok(()); // channel closed + }; + } + Err(TrySendError::Closed(_)) => return Ok(()), } } - // End of input: if real execution hasn't started yet, tell the - // coordinator we observed the entire stream. - if !execution_flag.load(Ordering::Relaxed) { - let time_mark_ns = first_msg_ns.map(|t| Instant::now() - t).unwrap_or_default(); + if let Some(sampling_tx) = sampling_tx.take() { let _ = sampling_tx.send(pb::LoadInfo { partition: partition_idx as u64, - row_count: 0, - byte_size: 0, - time_mark_ns: time_mark_ns.as_nanos() as u64, + bytes_per_second: match first_batch_at.get() { + Some(v) => { + bytes_per_second(max_mem_used_metric.value(), Instant::now() - *v) + } + None => 0, + }, eos: true, - max_memory_reached: false, }); } + + Ok(()) }); - let stream = UnboundedReceiverStream::new(rx).map(move |result| { - let _ = &task; // keep the task alive as long as the stream is alive. 
+ let stream = builder.build().inspect(move |result| { if let Ok(batch) = &result { memory_reservation_for_consumer.shrink(batch.get_array_memory_size()); - mem_available_notify_for_consumer.notify_one(); } - result }); - let stream = RecordBatchStreamAdapter::new(schema, stream); self.stream .lock() .expect("poisoned lock") - .replace(Box::pin(stream)); + .replace(Box::pin(RecordBatchStreamAdapter::new(schema, stream))); Ok(()) } } -/// Soft byte budget the partition sampler will buffer before pausing the -/// producer. Once exceeded, a `max_memory_reached` LoadInfo is emitted once, -/// signaling to the coordinator that the producer can sustain at least this -/// much in-flight data. -/// TODO: make this configurable via DistributedConfig. -const PARTITION_SAMPLER_BUDGET_BYTES: usize = 32 * 1024 * 1024; +fn bytes_per_second(bytes: usize, time: Duration) -> u64 { + let secs = time.as_secs_f32(); + if secs == 0.0 { + return 0; + } + ((bytes as f32) / secs) as u64 +} impl DisplayAs for SamplerExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index 6664418d..5a604583 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -37,7 +37,7 @@ pub mod worker_to_coordinator_msg { /// Load information reported by a task. This information is used for dynamically /// sizing the number of workers involved in a query. #[prost(message, tag = "2")] - LoadInfoBatch(super::LoadInfoBatch), + LoadInfo(super::LoadInfo), } } /// Metrics for a single task's plan nodes in pre-order traversal order. 
@@ -48,30 +48,14 @@ pub struct PreOrderTaskMetrics { #[prost(message, repeated, tag = "1")] pub metrics: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct LoadInfoBatch { - #[prost(message, repeated, tag = "1")] - pub batch: ::prost::alloc::vec::Vec, -} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct LoadInfo { #[prost(uint64, tag = "1")] pub partition: u64, #[prost(uint64, tag = "2")] - pub row_count: u64, - #[prost(uint64, tag = "3")] - pub byte_size: u64, - #[prost(bool, tag = "4")] + pub bytes_per_second: u64, + #[prost(bool, tag = "3")] pub eos: bool, - #[prost(uint64, tag = "5")] - pub time_mark_ns: u64, - /// True on the first sample emitted after the partition's sampler buffer - /// exceeded its configured byte budget. Consumers (the coordinator's - /// calculate_task_count) treat this as a "saturation" signal that the - /// producer can produce at least this much data, and can decide the next - /// stage's task count without waiting for further samples. - #[prost(bool, tag = "6")] - pub max_memory_reached: bool, } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct GetWorkerInfoRequest {} @@ -97,8 +81,9 @@ pub struct SetPlanRequest { /// /// If no WorkUnitFeedExec nodes are present in the plan, this should be empty. #[prost(message, repeated, tag = "4")] - pub work_unit_feed_declarations: - ::prost::alloc::vec::Vec, + pub work_unit_feed_declarations: ::prost::alloc::vec::Vec< + set_plan_request::WorkUnitFeedDeclaration, + >, /// The worker URL to which this message will go. The receiving worker will use this information to identify /// itself, and avoid further gRPC calls in case it needs to call itself for executing remote tasks. 
#[prost(string, tag = "5")] @@ -390,10 +375,10 @@ pub mod worker_service_client { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value + clippy::let_unit_value, )] - use tonic::codegen::http::Uri; use tonic::codegen::*; + use tonic::codegen::http::Uri; #[derive(Debug, Clone)] pub struct WorkerServiceClient { inner: tonic::client::Grpc, @@ -432,13 +417,14 @@ pub mod worker_service_client { F: tonic::service::Interceptor, T::ResponseBody: Default, T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, + http::Request, + Response = http::Response< + >::ResponseBody, >, - >>::Error: - Into + std::marker::Send + std::marker::Sync, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, { WorkerServiceClient::new(InterceptedService::new(inner, interceptor)) } @@ -478,22 +464,28 @@ pub mod worker_service_client { /// per task. pub async fn coordinator_channel( &mut self, - request: impl tonic::IntoStreamingRequest, + request: impl tonic::IntoStreamingRequest< + Message = super::CoordinatorToWorkerMsg, + >, ) -> std::result::Result< tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::unknown(format!("Service was not ready: {}", e.into())) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/worker.WorkerService/CoordinatorChannel"); + let path = http::uri::PathAndQuery::from_static( + "/worker.WorkerService/CoordinatorChannel", + ); let mut req = request.into_streaming_request(); - req.extensions_mut().insert(GrpcMethod::new( - "worker.WorkerService", - "CoordinatorChannel", - )); + req.extensions_mut() + .insert(GrpcMethod::new("worker.WorkerService", "CoordinatorChannel")); self.inner.streaming(req, path, codec).await } /// Executes the requested partition range 
of a subplan previously sent by the coordinator channel. @@ -504,11 +496,18 @@ pub mod worker_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::unknown(format!("Service was not ready: {}", e.into())) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/worker.WorkerService/ExecuteTask"); + let path = http::uri::PathAndQuery::from_static( + "/worker.WorkerService/ExecuteTask", + ); let mut req = request.into_request(); req.extensions_mut() .insert(GrpcMethod::new("worker.WorkerService", "ExecuteTask")); @@ -518,13 +517,22 @@ pub mod worker_service_client { pub async fn get_worker_info( &mut self, request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> - { - self.inner.ready().await.map_err(|e| { - tonic::Status::unknown(format!("Service was not ready: {}", e.into())) - })?; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/worker.WorkerService/GetWorkerInfo"); + let path = http::uri::PathAndQuery::from_static( + "/worker.WorkerService/GetWorkerInfo", + ); let mut req = request.into_request(); req.extensions_mut() .insert(GrpcMethod::new("worker.WorkerService", "GetWorkerInfo")); @@ -539,7 +547,7 @@ pub mod worker_service_server { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value + clippy::let_unit_value, )] use tonic::codegen::*; /// Generated trait containing gRPC methods that should be implemented for use with WorkerServiceServer. 
@@ -548,7 +556,8 @@ pub mod worker_service_server { /// Server streaming response type for the CoordinatorChannel method. type CoordinatorChannelStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, - > + std::marker::Send + > + + std::marker::Send + 'static; /// Establishes a bidirectional message stream between a coordinator and a worker, over which messages /// will be exchanged at any time during a query's lifetime. It's expected to be one coordinator channel @@ -556,22 +565,32 @@ pub mod worker_service_server { async fn coordinator_channel( &self, request: tonic::Request>, - ) -> std::result::Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; /// Server streaming response type for the ExecuteTask method. type ExecuteTaskStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result<::arrow_flight::FlightData, tonic::Status>, - > + std::marker::Send + > + + std::marker::Send + 'static; /// Executes the requested partition range of a subplan previously sent by the coordinator channel. async fn execute_task( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; /// Returns metadata about a worker. Currently only used for worker versioning. 
async fn get_worker_info( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; } #[derive(Debug)] pub struct WorkerServiceServer { @@ -594,7 +613,10 @@ pub mod worker_service_server { max_encoding_message_size: None, } } - pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService where F: tonic::service::Interceptor, { @@ -649,14 +671,16 @@ pub mod worker_service_server { "/worker.WorkerService/CoordinatorChannel" => { #[allow(non_camel_case_types)] struct CoordinatorChannelSvc(pub Arc); - impl - tonic::server::StreamingService - for CoordinatorChannelSvc - { + impl< + T: WorkerService, + > tonic::server::StreamingService + for CoordinatorChannelSvc { type Response = super::WorkerToCoordinatorMsg; type ResponseStream = T::CoordinatorChannelStream; - type Future = - BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request< @@ -665,7 +689,8 @@ pub mod worker_service_server { ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - ::coordinator_channel(&inner, request).await + ::coordinator_channel(&inner, request) + .await }; Box::pin(fut) } @@ -695,14 +720,16 @@ pub mod worker_service_server { "/worker.WorkerService/ExecuteTask" => { #[allow(non_camel_case_types)] struct ExecuteTaskSvc(pub Arc); - impl - tonic::server::ServerStreamingService - for ExecuteTaskSvc - { + impl< + T: WorkerService, + > tonic::server::ServerStreamingService + for ExecuteTaskSvc { type Response = ::arrow_flight::FlightData; type ResponseStream = T::ExecuteTaskStream; - type Future = - BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, @@ -739,11 +766,15 @@ pub mod worker_service_server { 
"/worker.WorkerService/GetWorkerInfo" => { #[allow(non_camel_case_types)] struct GetWorkerInfoSvc(pub Arc); - impl tonic::server::UnaryService - for GetWorkerInfoSvc - { + impl< + T: WorkerService, + > tonic::server::UnaryService + for GetWorkerInfoSvc { type Response = super::GetWorkerInfoResponse; - type Future = BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, @@ -777,19 +808,25 @@ pub mod worker_service_server { }; Box::pin(fut) } - _ => Box::pin(async move { - let mut response = http::Response::new(tonic::body::Body::default()); - let headers = response.headers_mut(); - headers.insert( - tonic::Status::GRPC_STATUS, - (tonic::Code::Unimplemented as i32).into(), - ); - headers.insert( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ); - Ok(response) - }), + _ => { + Box::pin(async move { + let mut response = http::Response::new( + tonic::body::Body::default(), + ); + let headers = response.headers_mut(); + headers + .insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers + .insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }) + } } } } diff --git a/src/worker/impl_coordinator_channel.rs b/src/worker/impl_coordinator_channel.rs index 3d101ea0..c029e1fe 100644 --- a/src/worker/impl_coordinator_channel.rs +++ b/src/worker/impl_coordinator_channel.rs @@ -148,11 +148,7 @@ impl Worker { let load_info_stream = UnboundedReceiverStream::new(load_info_rx); let load_info_stream = load_info_stream.map(|load_info| WorkerToCoordinatorMsg { - inner: Some(worker_to_coordinator_msg::Inner::LoadInfoBatch( - crate::worker::generated::worker::LoadInfoBatch { - batch: vec![load_info], - }, - )), + inner: Some(worker_to_coordinator_msg::Inner::LoadInfo(load_info)), }); // Stream back the metrics once the task finishes executing. 
diff --git a/src/worker/worker.proto b/src/worker/worker.proto index 2ebeb88d..af88168f 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -34,7 +34,7 @@ message WorkerToCoordinatorMsg { // Load information reported by a task. This information is used for dynamically // sizing the number of workers involved in a query. - LoadInfoBatch load_info_batch = 2; + LoadInfo load_info = 2; } } @@ -45,22 +45,10 @@ message PreOrderTaskMetrics { repeated MetricsSet metrics = 1; } -message LoadInfoBatch { - repeated LoadInfo batch = 1; -} - message LoadInfo { uint64 partition = 1; - uint64 row_count = 2; - uint64 byte_size = 3; - bool eos = 4; - uint64 time_mark_ns = 5; - // True on the first sample emitted after the partition's sampler buffer - // exceeded its configured byte budget. Consumers (the coordinator's - // calculate_task_count) treat this as a "saturation" signal that the - // producer can produce at least this much data, and can decide the next - // stage's task count without waiting for further samples. 
- bool max_memory_reached = 6; + uint64 bytes_per_second = 2; + bool eos = 3; } message GetWorkerInfoRequest {} From 1b650389db1c4b0bbfd9377a99ce9c7d0695c6ab Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Wed, 13 May 2026 10:33:01 -0400 Subject: [PATCH 3/7] Fix races --- src/coordinator/prepare_dynamic_plan.rs | 3 + src/coordinator/task_spawner.rs | 12 ++- src/execution_plans/sampler.rs | 118 ++++++++++-------------- src/worker/generated/worker.rs | 4 +- src/worker/impl_coordinator_channel.rs | 14 ++- src/worker/worker.proto | 2 + 6 files changed, 75 insertions(+), 78 deletions(-) diff --git a/src/coordinator/prepare_dynamic_plan.rs b/src/coordinator/prepare_dynamic_plan.rs index 395c5222..f19af41f 100644 --- a/src/coordinator/prepare_dynamic_plan.rs +++ b/src/coordinator/prepare_dynamic_plan.rs @@ -260,6 +260,9 @@ async fn calculate_necessary_partitions( break; } } + if votes.is_empty() { + return 1; + } let avg_bytes_per_second_per_partition = votes.iter().map(|v| v.bytes_per_second).sum::() / votes.len() as u64; diff --git a/src/coordinator/task_spawner.rs b/src/coordinator/task_spawner.rs index e7e5b223..0c18d8ea 100644 --- a/src/coordinator/task_spawner.rs +++ b/src/coordinator/task_spawner.rs @@ -246,7 +246,8 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { task_number: task_i as u64, }; let task_metrics = self.task_metrics.clone(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (load_info_tx, load_info_rx) = tokio::sync::mpsc::unbounded_channel(); + let mut load_info_tx_opt = Some(load_info_tx); #[allow(clippy::disallowed_methods)] tokio::spawn(async move { while let Some(msg) = worker_to_coordinator_rx.recv().await { @@ -259,12 +260,17 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { } } pb::worker_to_coordinator_msg::Inner::LoadInfo(load_info) => { - let _ = tx.send(load_info); + if let Some(tx) = &load_info_tx_opt { + let _ = tx.send(load_info); + } + } + pb::worker_to_coordinator_msg::Inner::LoadInfoEos(_) => { + let _ = 
load_info_tx_opt.take(); } } } }); - rx + load_info_rx } /// Launches the task that based on the different local [WorkUnitFeedExec] nodes, sends their diff --git a/src/execution_plans/sampler.rs b/src/execution_plans/sampler.rs index 756bcf80..65e93ff3 100644 --- a/src/execution_plans/sampler.rs +++ b/src/execution_plans/sampler.rs @@ -3,17 +3,17 @@ use crate::worker::generated::worker as pb; use crate::{ BytesCounterMetric, BytesMetricExt, LatencyMetricExt, MaxLatencyMetric, P50LatencyMetric, }; +use datafusion::common::runtime::SpawnedTask; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion::common::{Result, exec_err}; +use datafusion::common::{DataFusionError, Result, exec_err}; use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr_common::metrics::{Gauge, MetricsSet}; +use datafusion::physical_plan::buffer::SizedMessage; use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use datafusion::physical_plan::stream::{ - RecordBatchReceiverStreamBuilder, RecordBatchStreamAdapter, -}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; -use futures::StreamExt; +use futures::{StreamExt, TryFutureExt}; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::OnceLock; @@ -21,7 +21,6 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::TrySendError; #[derive(Debug)] pub struct SamplerExec { @@ -167,9 +166,7 @@ impl PartitionSampler { MemoryConsumer::new(format!("PartitionSampler[{partition_idx}]")) .register(ctx.memory_pool()), ); - let memory_reservation_for_consumer = Arc::clone(&memory_reservation); - let execution_flag = Arc::clone(&self.execution_flag); let 
max_mem_used_metric = self.metrics.max_mem_used.clone(); let bytes_per_sec_metric = self.metrics.bytes_per_sec.clone(); let first_batch_at = Arc::clone(&self.first_batch_at); @@ -178,76 +175,57 @@ impl PartitionSampler { // spawn the producer task. let mut input_stream = input.execute(partition_idx, ctx)?; - let mut builder = RecordBatchReceiverStreamBuilder::new(self.input.schema(), 2); - let tx = builder.tx(); - - let mut first_msg = None; - let mut sampling_tx = Some(sampling_tx); - - builder.spawn(async move { - while let Some(batch_or_err) = input_stream.next().await { - let batch = batch_or_err?; - let size = batch.get_array_memory_size(); - - let now = Instant::now(); - let first_msg = first_msg.get_or_insert_with(|| { - let _ = first_batch_at.set(now); - now + let task = SpawnedTask::spawn(async move { + let Some(first_batch) = input_stream.next().await else { + let _ = sampling_tx.send(pb::LoadInfo { + partition: partition_idx as u64, + eos: true, + ..Default::default() }); - let time_since_first_msg = now - *first_msg; - - memory_reservation.grow(size); - let memory_used = memory_reservation.size(); - max_mem_used_metric.set_max(memory_used); - - if sampling_tx.is_none() || execution_flag.load(Ordering::Relaxed) { - if tx.send(Ok(batch)).await.is_err() { - return Ok(()); // channel closed - }; - continue; - } - - match tx.try_send(Ok(batch)) { - Ok(_) => {} - Err(TrySendError::Full(batch_or_err)) => { - let bytes_per_second = bytes_per_second(memory_used, time_since_first_msg); - bytes_per_sec_metric.add_bytes(bytes_per_second as usize); - if let Some(sampling_tx) = sampling_tx.take() { - let _ = sampling_tx.send(pb::LoadInfo { - partition: partition_idx as u64, - bytes_per_second, - eos: false, - }); - } - if tx.send(batch_or_err).await.is_err() { - return Ok(()); // channel closed - }; - } - Err(TrySendError::Closed(_)) => return Ok(()), - } - } - - if let Some(sampling_tx) = sampling_tx.take() { + return Ok(futures::stream::empty().boxed()); + }; + let 
first_batch = first_batch?; + let first_batch_time = Instant::now(); + let first_batch_size = first_batch.size(); + let _ = first_batch_at.set(first_batch_time); + max_mem_used_metric.add(first_batch_size); + memory_reservation.grow(first_batch_size); + + let Some(second_batch) = input_stream.next().await else { let _ = sampling_tx.send(pb::LoadInfo { partition: partition_idx as u64, - bytes_per_second: match first_batch_at.get() { - Some(v) => { - bytes_per_second(max_mem_used_metric.value(), Instant::now() - *v) - } - None => 0, - }, eos: true, + ..Default::default() }); - } + return Ok(futures::stream::iter([Ok(first_batch)]).boxed()); + }; + let second_batch = second_batch?; + let second_batch_time = Instant::now(); + let second_batch_size = second_batch.size(); + max_mem_used_metric.add(second_batch_size); + memory_reservation.grow(second_batch_size); + + let bytes_per_second = bytes_per_second( + first_batch.get_array_memory_size() + second_batch.get_array_memory_size(), + second_batch_time - first_batch_time, + ); + bytes_per_sec_metric.add_bytes(bytes_per_second as usize); + let _ = sampling_tx.send(pb::LoadInfo { + partition: partition_idx as u64, + bytes_per_second, + eos: false, + }); - Ok(()) + Ok(futures::stream::iter([Ok(first_batch), Ok(second_batch)]) + .chain(input_stream) + .boxed()) }); - let stream = builder.build().inspect(move |result| { - if let Ok(batch) = &result { - memory_reservation_for_consumer.shrink(batch.get_array_memory_size()); - } - }); + let stream = async move { + task.await + .map_err(|err| DataFusionError::Internal(err.to_string()))? 
+ } + .try_flatten_stream(); self.stream .lock() diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index 5a604583..b62af68c 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -21,7 +21,7 @@ pub mod coordinator_to_worker_msg { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct WorkerToCoordinatorMsg { - #[prost(oneof = "worker_to_coordinator_msg::Inner", tags = "1, 2")] + #[prost(oneof = "worker_to_coordinator_msg::Inner", tags = "1, 2, 3")] pub inner: ::core::option::Option, } /// Nested message and enum types in `WorkerToCoordinatorMsg`. @@ -38,6 +38,8 @@ pub mod worker_to_coordinator_msg { /// sizing the number of workers involved in a query. #[prost(message, tag = "2")] LoadInfo(super::LoadInfo), + #[prost(bool, tag = "3")] + LoadInfoEos(bool), } } /// Metrics for a single task's plan nodes in pre-order traversal order. diff --git a/src/worker/impl_coordinator_channel.rs b/src/worker/impl_coordinator_channel.rs index c029e1fe..cd9de5ea 100644 --- a/src/worker/impl_coordinator_channel.rs +++ b/src/worker/impl_coordinator_channel.rs @@ -146,10 +146,16 @@ impl Worker { } }); - let load_info_stream = UnboundedReceiverStream::new(load_info_rx); - let load_info_stream = load_info_stream.map(|load_info| WorkerToCoordinatorMsg { - inner: Some(worker_to_coordinator_msg::Inner::LoadInfo(load_info)), - }); + let load_info_stream = UnboundedReceiverStream::new(load_info_rx) + .map(|load_info| WorkerToCoordinatorMsg { + inner: Some(worker_to_coordinator_msg::Inner::LoadInfo(load_info)), + }) + .chain(futures::stream::once(async move { + WorkerToCoordinatorMsg { + inner: Some(worker_to_coordinator_msg::Inner::LoadInfoEos(true)), + } + })) + .boxed(); // Stream back the metrics once the task finishes executing. 
// The oneshot receiver resolves when impl_execute_task sends the collected diff --git a/src/worker/worker.proto b/src/worker/worker.proto index af88168f..edcbd091 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -35,6 +35,8 @@ message WorkerToCoordinatorMsg { // Load information reported by a task. This information is used for dynamically // sizing the number of workers involved in a query. LoadInfo load_info = 2; + + bool load_info_eos = 3; } } From 8926d402bd944cb4c4a12227d9e7ac296df3b25d Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Fri, 15 May 2026 17:04:15 +0200 Subject: [PATCH 4/7] Solve conflicts --- src/distributed_planner/inject_network_boundaries.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distributed_planner/inject_network_boundaries.rs b/src/distributed_planner/inject_network_boundaries.rs index d143810f..71226ef8 100644 --- a/src/distributed_planner/inject_network_boundaries.rs +++ b/src/distributed_planner/inject_network_boundaries.rs @@ -307,7 +307,7 @@ async fn _inject_network_boundaries( tasks: task_count.as_usize(), }); let result = ctx.apply_stage_builder(Arc::new(plan)).await?; - return Ok(ctx.with_task_count(result.network_boundary, result.task_count_above)); + return Ok(ctx.plan_with_task_count(result.network_boundary, result.task_count_above)); } // If the parent of the current node is either a `CoalescePartitionsExec` or a // `SortPreservingMergeExec`, a network boundary below it is necessary. @@ -333,7 +333,7 @@ async fn _inject_network_boundaries( tasks: task_count.as_usize(), }); let result = ctx.apply_stage_builder(Arc::new(plan)).await?; - return Ok(ctx.with_task_count(result.network_boundary, result.task_count_above)); + return Ok(ctx.plan_with_task_count(result.network_boundary, result.task_count_above)); } else { // The subtree below this point belongs to one stage. Propagate the chosen task // count down so every node in that stage has it recorded. 
@@ -356,7 +356,7 @@ async fn _inject_network_boundaries( // The parent that triggered this branch is a `CoalescePartitionsExec` or // `SortPreservingMergeExec`, both of which fold all partitions into one — so the // stage above this boundary must run in exactly one task. - Ok(ctx.with_task_count(result.network_boundary, Maximum(1))) + Ok(ctx.plan_with_task_count(result.network_boundary, Maximum(1))) }; } From a3410a41bf3bb49fcdd3419d2e8bd895700940e9 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sun, 17 May 2026 21:28:57 +0200 Subject: [PATCH 5/7] Remove eos field from proto --- src/execution_plans/sampler.rs | 7 ++----- src/worker/generated/worker.rs | 2 -- src/worker/worker.proto | 1 - 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/execution_plans/sampler.rs b/src/execution_plans/sampler.rs index 65e93ff3..50708c24 100644 --- a/src/execution_plans/sampler.rs +++ b/src/execution_plans/sampler.rs @@ -179,8 +179,7 @@ impl PartitionSampler { let Some(first_batch) = input_stream.next().await else { let _ = sampling_tx.send(pb::LoadInfo { partition: partition_idx as u64, - eos: true, - ..Default::default() + bytes_per_second: 0, }); return Ok(futures::stream::empty().boxed()); }; @@ -194,8 +193,7 @@ impl PartitionSampler { let Some(second_batch) = input_stream.next().await else { let _ = sampling_tx.send(pb::LoadInfo { partition: partition_idx as u64, - eos: true, - ..Default::default() + bytes_per_second: 0, }); return Ok(futures::stream::iter([Ok(first_batch)]).boxed()); }; @@ -213,7 +211,6 @@ impl PartitionSampler { let _ = sampling_tx.send(pb::LoadInfo { partition: partition_idx as u64, bytes_per_second, - eos: false, }); Ok(futures::stream::iter([Ok(first_batch), Ok(second_batch)]) diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index b62af68c..137fbe22 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -56,8 +56,6 @@ pub struct LoadInfo { pub partition: u64, 
#[prost(uint64, tag = "2")] pub bytes_per_second: u64, - #[prost(bool, tag = "3")] - pub eos: bool, } #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct GetWorkerInfoRequest {} diff --git a/src/worker/worker.proto b/src/worker/worker.proto index edcbd091..b9193e44 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -50,7 +50,6 @@ message PreOrderTaskMetrics { message LoadInfo { uint64 partition = 1; uint64 bytes_per_second = 2; - bool eos = 3; } message GetWorkerInfoRequest {} From 689731d7259dcb6f846e7b2f5f95dea6f597f0fd Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 18 May 2026 12:18:49 +0200 Subject: [PATCH 6/7] Add `load_info_sent_at` metric to SamplerExec --- src/execution_plans/sampler.rs | 40 ++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/src/execution_plans/sampler.rs b/src/execution_plans/sampler.rs index 50708c24..f03c7f20 100644 --- a/src/execution_plans/sampler.rs +++ b/src/execution_plans/sampler.rs @@ -38,11 +38,16 @@ pub(crate) struct SamplerExecMetrics { /// `execute()` is invoked on this partition (when the consumer attaches). kickoff_to_execution_p50: P50LatencyMetric, kickoff_to_execution_max: MaxLatencyMetric, - /// Time from the first `LoadInfo` message being emitted until `execute()` is invoked. + /// Time from when the first batch is received until `execute()` is invoked. /// Measures how long the coordinator sat between seeing the first sample and starting /// the consumer. first_batch_to_execution_p50: P50LatencyMetric, first_batch_to_execution_max: MaxLatencyMetric, + /// Time from the first `LoadInfo` message being emitted until `execute()` is invoked. + /// Measures how long the coordinator sat between seeing the first sample and starting + /// the consumer. + load_info_sent_to_execution_p50: P50LatencyMetric, + load_info_sent_to_execution_max: MaxLatencyMetric, /// Peak memory buffered by any partition sampler during the sampling phase.
max_mem_used: Gauge, /// Bytes per second flowing through the sampler node. @@ -51,15 +56,14 @@ pub(crate) struct SamplerExecMetrics { impl SamplerExecMetrics { fn new(metric_set: &ExecutionPlanMetricsSet) -> Self { + let bdr = || MetricBuilder::new(metric_set); Self { - kickoff_to_execution_p50: MetricBuilder::new(metric_set) - .p50_latency("kickoff_to_execution_p50"), - kickoff_to_execution_max: MetricBuilder::new(metric_set) - .max_latency("kickoff_to_execution_max"), - first_batch_to_execution_p50: MetricBuilder::new(metric_set) - .p50_latency("first_batch_to_execution_p50"), - first_batch_to_execution_max: MetricBuilder::new(metric_set) - .max_latency("first_batch_to_execution_max"), + kickoff_to_execution_p50: bdr().p50_latency("kickoff_to_execution_p50"), + kickoff_to_execution_max: bdr().max_latency("kickoff_to_execution_max"), + first_batch_to_execution_p50: bdr().p50_latency("first_batch_to_execution_p50"), + first_batch_to_execution_max: bdr().max_latency("first_batch_to_execution_max"), + load_info_sent_to_execution_p50: bdr().p50_latency("load_info_sent_to_execution_p50"), + load_info_sent_to_execution_max: bdr().max_latency("load_info_sent_to_execution_max"), max_mem_used: MetricBuilder::new(metric_set).global_gauge("max_mem_used"), bytes_per_sec: MetricBuilder::new(metric_set).bytes_counter("bytes_per_sec"), } @@ -81,6 +85,7 @@ impl SamplerExec { metrics: metrics.clone(), kick_off_at: Arc::new(OnceLock::new()), first_batch_at: Arc::new(OnceLock::new()), + load_info_sent_at: Arc::new(OnceLock::new()), }); } Self { @@ -122,6 +127,10 @@ pub(crate) struct PartitionSampler { /// Set the first time the producer task emits a `LoadInfo`. Used at `execute()` time /// to record the gap between the first sample and the consumer starting. first_batch_at: Arc>, + /// Set immediately after `sampling_tx.send()` returns, whether or not the send succeeded.
Used to measure the full + /// round-trip: LoadInfo sent → coordinator collects votes → downstream plan dispatched + /// → consumer calls execute(). + load_info_sent_at: Arc>, } impl Debug for PartitionSampler { @@ -148,6 +157,15 @@ impl PartitionSampler { .first_batch_to_execution_max .add_duration(delay); } + if let Some(t) = self.load_info_sent_at.get() { + let delay = now.saturating_duration_since(*t); + self.metrics + .load_info_sent_to_execution_p50 + .add_duration(delay); + self.metrics + .load_info_sent_to_execution_max + .add_duration(delay); + } self.stream.lock().unwrap().take() } @@ -170,6 +188,7 @@ impl PartitionSampler { let max_mem_used_metric = self.metrics.max_mem_used.clone(); let bytes_per_sec_metric = self.metrics.bytes_per_sec.clone(); let first_batch_at = Arc::clone(&self.first_batch_at); + let load_info_sent_at = Arc::clone(&self.load_info_sent_at); // Execute the input synchronously so any setup error surfaces before we // spawn the producer task. @@ -181,6 +200,7 @@ impl PartitionSampler { partition: partition_idx as u64, bytes_per_second: 0, }); + let _ = load_info_sent_at.set(Instant::now()); return Ok(futures::stream::empty().boxed()); }; let first_batch = first_batch?; @@ -195,6 +215,7 @@ impl PartitionSampler { partition: partition_idx as u64, bytes_per_second: 0, }); + let _ = load_info_sent_at.set(Instant::now()); return Ok(futures::stream::iter([Ok(first_batch)]).boxed()); }; let second_batch = second_batch?; @@ -212,6 +233,7 @@ impl PartitionSampler { partition: partition_idx as u64, bytes_per_second, }); + let _ = load_info_sent_at.set(Instant::now()); Ok(futures::stream::iter([Ok(first_batch), Ok(second_batch)]) .chain(input_stream) From 6be38b4e6fb085605016beab0e25ab6c9423a2c5 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Mon, 18 May 2026 12:19:03 +0200 Subject: [PATCH 7/7] Improve partitions remaining calculation --- src/coordinator/prepare_dynamic_plan.rs | 32 ++-- src/worker/generated/worker.rs | 187 
+++++++++--------------- 2 files changed, 88 insertions(+), 131 deletions(-) diff --git a/src/coordinator/prepare_dynamic_plan.rs b/src/coordinator/prepare_dynamic_plan.rs index f19af41f..e37c9b37 100644 --- a/src/coordinator/prepare_dynamic_plan.rs +++ b/src/coordinator/prepare_dynamic_plan.rs @@ -138,7 +138,7 @@ pub(super) async fn prepare_dynamic_plan( let load_info_stream = futures::stream::select_all(load_info_rxs); let partitions_per_task = nb.properties().partitioning.partition_count(); - let partitions_remaining = partitions_per_task * input_stage.tasks; + let partitions_remaining = vec![partitions_per_task; input_stage.tasks]; let bytes_per_partition_per_second = d_cfg.bytes_per_partition_per_second; Ok(async move { @@ -240,32 +240,42 @@ fn get_child_stages_urls( /// no slice could compute a usable velocity. async fn calculate_necessary_partitions( mut load_info_stream: impl Stream + Unpin, - mut partitions_remaining: usize, + mut partitions_remaining: Vec, bytes_per_partition_per_second: usize, ) -> usize { /// Minimum number of slices that must reach the "done" state before voting. A slice is /// done when it signals `eos`, `max_memory_reached`, or hits `MAX_MESSAGES_PER_SLICE`. const TARGET_DONE_PARTITIONS: usize = 3; + let mut partitions_done = vec![0; partitions_remaining.len()]; let mut votes = vec![]; - while let Some((_task_idx, load_info)) = load_info_stream.next().await { - partitions_remaining -= 1; - if load_info.bytes_per_second == 0 { - // The partition reporting this load info was empty. 
- continue; - } + while let Some((task_idx, load_info)) = load_info_stream.next().await { + partitions_remaining[task_idx] -= 1; + partitions_done[task_idx] += 1; votes.push(load_info); - if votes.len() >= TARGET_DONE_PARTITIONS || partitions_remaining == 0 { + + let mut finished = true; + for p in &partitions_done { + if *p < TARGET_DONE_PARTITIONS { + finished = false; + break; + } + } + + if finished || partitions_remaining.iter().all(|p| *p == 0) { break; } } - if votes.is_empty() { + // Only non-zero votes carry throughput signal; zero-rate partitions (single-batch or + // empty) count toward the early-exit threshold above but must not dilute the average. + let rated_votes: Vec<_> = votes.iter().filter(|v| v.bytes_per_second > 0).collect(); + if rated_votes.is_empty() { return 1; } let avg_bytes_per_second_per_partition = - votes.iter().map(|v| v.bytes_per_second).sum::() / votes.len() as u64; + rated_votes.iter().map(|v| v.bytes_per_second).sum::() / rated_votes.len() as u64; (avg_bytes_per_second_per_partition as usize).div_ceil(bytes_per_partition_per_second) } diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index 137fbe22..300b62ea 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -81,9 +81,8 @@ pub struct SetPlanRequest { /// /// If no WorkUnitFeedExec nodes are present in the plan, this should be empty. #[prost(message, repeated, tag = "4")] - pub work_unit_feed_declarations: ::prost::alloc::vec::Vec< - set_plan_request::WorkUnitFeedDeclaration, - >, + pub work_unit_feed_declarations: + ::prost::alloc::vec::Vec, /// The worker URL to which this message will go. The receiving worker will use this information to identify /// itself, and avoid further gRPC calls in case it needs to call itself for executing remote tasks. 
#[prost(string, tag = "5")] @@ -375,10 +374,10 @@ pub mod worker_service_client { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value, + clippy::let_unit_value )] - use tonic::codegen::*; use tonic::codegen::http::Uri; + use tonic::codegen::*; #[derive(Debug, Clone)] pub struct WorkerServiceClient { inner: tonic::client::Grpc, @@ -417,14 +416,13 @@ pub mod worker_service_client { F: tonic::service::Interceptor, T::ResponseBody: Default, T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, + http::Request, + Response = http::Response< + >::ResponseBody, + >, >, - >, - , - >>::Error: Into + std::marker::Send + std::marker::Sync, + >>::Error: + Into + std::marker::Send + std::marker::Sync, { WorkerServiceClient::new(InterceptedService::new(inner, interceptor)) } @@ -464,28 +462,22 @@ pub mod worker_service_client { /// per task. pub async fn coordinator_channel( &mut self, - request: impl tonic::IntoStreamingRequest< - Message = super::CoordinatorToWorkerMsg, - >, + request: impl tonic::IntoStreamingRequest, ) -> std::result::Result< tonic::Response>, tonic::Status, > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/worker.WorkerService/CoordinatorChannel", - ); + let path = + http::uri::PathAndQuery::from_static("/worker.WorkerService/CoordinatorChannel"); let mut req = request.into_streaming_request(); - req.extensions_mut() - .insert(GrpcMethod::new("worker.WorkerService", "CoordinatorChannel")); + req.extensions_mut().insert(GrpcMethod::new( + "worker.WorkerService", + "CoordinatorChannel", + )); self.inner.streaming(req, path, codec).await } /// Executes the requested partition range of a 
subplan previously sent by the coordinator channel. @@ -496,18 +488,11 @@ pub mod worker_service_client { tonic::Response>, tonic::Status, > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/worker.WorkerService/ExecuteTask", - ); + let path = http::uri::PathAndQuery::from_static("/worker.WorkerService/ExecuteTask"); let mut req = request.into_request(); req.extensions_mut() .insert(GrpcMethod::new("worker.WorkerService", "ExecuteTask")); @@ -517,22 +502,13 @@ pub mod worker_service_client { pub async fn get_worker_info( &mut self, request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; let codec = tonic_prost::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/worker.WorkerService/GetWorkerInfo", - ); + let path = http::uri::PathAndQuery::from_static("/worker.WorkerService/GetWorkerInfo"); let mut req = request.into_request(); req.extensions_mut() .insert(GrpcMethod::new("worker.WorkerService", "GetWorkerInfo")); @@ -547,7 +523,7 @@ pub mod worker_service_server { dead_code, missing_docs, clippy::wildcard_imports, - clippy::let_unit_value, + clippy::let_unit_value )] use tonic::codegen::*; /// Generated trait containing gRPC methods that should be implemented for use with WorkerServiceServer. 
@@ -556,8 +532,7 @@ pub mod worker_service_server { /// Server streaming response type for the CoordinatorChannel method. type CoordinatorChannelStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, - > - + std::marker::Send + > + std::marker::Send + 'static; /// Establishes a bidirectional message stream between a coordinator and a worker, over which messages /// will be exchanged at any time during a query's lifetime. It's expected to be one coordinator channel @@ -565,32 +540,22 @@ pub mod worker_service_server { async fn coordinator_channel( &self, request: tonic::Request>, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; + ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the ExecuteTask method. type ExecuteTaskStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result<::arrow_flight::FlightData, tonic::Status>, - > - + std::marker::Send + > + std::marker::Send + 'static; /// Executes the requested partition range of a subplan previously sent by the coordinator channel. async fn execute_task( &self, request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; + ) -> std::result::Result, tonic::Status>; /// Returns metadata about a worker. Currently only used for worker versioning. 
async fn get_worker_info( &self, request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; + ) -> std::result::Result, tonic::Status>; } #[derive(Debug)] pub struct WorkerServiceServer { @@ -613,10 +578,7 @@ pub mod worker_service_server { max_encoding_message_size: None, } } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService where F: tonic::service::Interceptor, { @@ -671,16 +633,14 @@ pub mod worker_service_server { "/worker.WorkerService/CoordinatorChannel" => { #[allow(non_camel_case_types)] struct CoordinatorChannelSvc(pub Arc); - impl< - T: WorkerService, - > tonic::server::StreamingService - for CoordinatorChannelSvc { + impl + tonic::server::StreamingService + for CoordinatorChannelSvc + { type Response = super::WorkerToCoordinatorMsg; type ResponseStream = T::CoordinatorChannelStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = + BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request< @@ -689,8 +649,7 @@ pub mod worker_service_server { ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - ::coordinator_channel(&inner, request) - .await + ::coordinator_channel(&inner, request).await }; Box::pin(fut) } @@ -720,16 +679,14 @@ pub mod worker_service_server { "/worker.WorkerService/ExecuteTask" => { #[allow(non_camel_case_types)] struct ExecuteTaskSvc(pub Arc); - impl< - T: WorkerService, - > tonic::server::ServerStreamingService - for ExecuteTaskSvc { + impl + tonic::server::ServerStreamingService + for ExecuteTaskSvc + { type Response = ::arrow_flight::FlightData; type ResponseStream = T::ExecuteTaskStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = + BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -766,15 +723,11 @@ pub mod worker_service_server { 
"/worker.WorkerService/GetWorkerInfo" => { #[allow(non_camel_case_types)] struct GetWorkerInfoSvc(pub Arc); - impl< - T: WorkerService, - > tonic::server::UnaryService - for GetWorkerInfoSvc { + impl tonic::server::UnaryService + for GetWorkerInfoSvc + { type Response = super::GetWorkerInfoResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; + type Future = BoxFuture, tonic::Status>; fn call( &mut self, request: tonic::Request, @@ -808,25 +761,19 @@ pub mod worker_service_server { }; Box::pin(fut) } - _ => { - Box::pin(async move { - let mut response = http::Response::new( - tonic::body::Body::default(), - ); - let headers = response.headers_mut(); - headers - .insert( - tonic::Status::GRPC_STATUS, - (tonic::Code::Unimplemented as i32).into(), - ); - headers - .insert( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ); - Ok(response) - }) - } + _ => Box::pin(async move { + let mut response = http::Response::new(tonic::body::Body::default()); + let headers = response.headers_mut(); + headers.insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers.insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }), } } }