1 change: 1 addition & 0 deletions .github/workflows/pr_build_linux.yml
@@ -366,6 +366,7 @@ jobs:
org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite
org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite
org.apache.comet.objectstore.NativeConfigSuite
org.apache.spark.comet.udf.CometUdfBridgeSuite
- name: "expressions"
value: |
org.apache.comet.CometExpressionSuite
1 change: 1 addition & 0 deletions .github/workflows/pr_build_macos.yml
@@ -214,6 +214,7 @@ jobs:
org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite
org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite
org.apache.comet.objectstore.NativeConfigSuite
org.apache.spark.comet.udf.CometUdfBridgeSuite
- name: "expressions"
value: |
org.apache.comet.CometExpressionSuite
48 changes: 38 additions & 10 deletions common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -27,6 +27,8 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.spark.TaskContext;
import org.apache.spark.comet.CometTaskContextShim;

/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,45 @@ public class CometUdfBridge {
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
* @param numRows row count of the current batch. Mirrors DataFusion's {@code
* ScalarFunctionArgs.number_rows}; it is the only batch-size signal a zero-input UDF
* (e.g. a zero-arg non-deterministic ScalaUDF) ever sees.
* @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
* {@code null} outside a Spark task. Installed as the thread-local for the duration of the
* call when the current thread has none, so partition-sensitive built-ins ({@code Rand},
* {@code Uuid}, {@code MonotonicallyIncreasingID}) work from Tokio workers. Cleared in {@code
* finally} to avoid leaking across worker reuse.
*/
public static void evaluate(
String udfClassName,
long[] inputArrayPtrs,
long[] inputSchemaPtrs,
long outArrayPtr,
- long outSchemaPtr) {
+ long outSchemaPtr,
+ int numRows,
+ TaskContext taskContext) {
boolean installedTaskContext = false;
if (taskContext != null && TaskContext.get() == null) {
CometTaskContextShim.set(taskContext);
installedTaskContext = true;
}
try {
evaluateInternal(
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
} finally {
if (installedTaskContext) {
CometTaskContextShim.unset();
}
}
}

private static void evaluateInternal(
String udfClassName,
long[] inputArrayPtrs,
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr,
int numRows) {
CometUDF udf =
INSTANCES.computeIfAbsent(
udfClassName,
@@ -84,23 +118,17 @@ public static void evaluate(
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
}

- result = udf.evaluate(inputs);
+ result = udf.evaluate(inputs, numRows);
if (!(result instanceof FieldVector)) {
throw new RuntimeException(
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
}
- // Result length must match the longest input. Scalar (length-1) inputs
- // are allowed to be shorter, but a vector input bounds the output.
- int expectedLen = 0;
- for (ValueVector v : inputs) {
- expectedLen = Math.max(expectedLen, v.getValueCount());
- }
- if (result.getValueCount() != expectedLen) {
+ if (result.getValueCount() != numRows) {
throw new RuntimeException(
"CometUDF.evaluate() returned "
+ result.getValueCount()
+ " rows, expected "
- + expectedLen);
+ + numRows);
}
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
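The new check above replaces the max-input-length heuristic with an exact contract: the output must have exactly `numRows` values. A minimal sketch of a UDF that satisfies it, seen from the JVM side; the class name and the per-call `RootAllocator` are illustrative only, not part of this change, and a real implementation would reuse a long-lived allocator:

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{Float8Vector, ValueVector}

import org.apache.comet.udf.CometUDF

// Hypothetical one-argument UDF: adds a literal-folded scalar to a double column.
// inputs(0) is the data column (batch-length); inputs(1) is the folded literal,
// which arrives as a length-1 vector and is read once at index 0.
class AddScalarUdf extends CometUDF {
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val col = inputs(0).asInstanceOf[Float8Vector]
    val delta = inputs(1).asInstanceOf[Float8Vector].get(0)
    val out = new Float8Vector("out", new RootAllocator()) // sketch only
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      if (col.isNull(i)) out.setNull(i) else out.set(i, col.get(i) + delta)
      i += 1
    }
    out.setValueCount(numRows) // any other length fails the bridge's numRows check
    out
  }
}
```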
9 changes: 7 additions & 2 deletions common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
*
* - Vector arguments arrive at the row count of the current batch.
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
- * - The returned vector's length must match the longest input.
+ * - The returned vector's length must match `numRows`.
*
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
* UDFs that always have at least one batch-length input can derive length from the inputs and
* ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
* need `numRows` to know how many rows to produce.
*
* Implementations must have a public no-arg constructor and must be stateless: a single instance
* per class is cached and shared across native worker threads for the lifetime of the JVM.
*/
trait CometUDF {
- def evaluate(inputs: Array[ValueVector]): ValueVector
+ def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
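The zero-data-column case the doc calls out is where `numRows` is load-bearing, and it is also where this PR's `TaskContext` propagation pays off. A hypothetical sketch (class name and seeding scheme invented for illustration; allocator handling simplified as above):

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{Float8Vector, ValueVector}

import org.apache.spark.TaskContext

import org.apache.comet.udf.CometUDF

// Hypothetical zero-input UDF: with no input vectors there is nothing to take a
// length from, so numRows is the only sizing signal. The partition-derived seed
// only works because the bridge installs the propagated TaskContext before
// calling evaluate() on the Tokio worker thread.
class PartitionSeededRandUdf extends CometUDF {
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.isEmpty, "expected no input columns")
    val partition = Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)
    val rng = new java.util.Random(42L + partition)
    val out = new Float8Vector("rand", new RootAllocator()) // sketch only
    out.allocateNew(numRows)
    (0 until numRows).foreach(i => out.set(i, rng.nextDouble()))
    out.setValueCount(numRows)
    out
  }
}
```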
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.comet

import org.apache.spark.TaskContext

/**
* Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
*
* Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
* reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
* The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
* Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
* built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
* `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
* protected methods, and exposes plain public forwarders the bridge (which lives in
* `org.apache.comet.udf`) can use.
*/
object CometTaskContextShim {

def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)

def unset(): Unit = TaskContext.unset()
}
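A small sketch of the visibility rule the shim works around; the enclosing object and helper are hypothetical, shown only to illustrate the install/clear pattern the bridge uses:

```scala
package org.apache.comet.example // hypothetical package outside org.apache.spark

import org.apache.spark.TaskContext
import org.apache.spark.comet.CometTaskContextShim

object TaskContextScope {
  // TaskContext.setTaskContext(ctx) would not compile here: it is declared
  // protected[spark] and is only visible inside the org.apache.spark tree.
  // The shim's public forwarders are reachable from anywhere.
  def withTaskContext[T](ctx: TaskContext)(body: => T): T = {
    CometTaskContextShim.set(ctx)
    try body
    finally CometTaskContextShim.unset()
  }
}
```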
21 changes: 20 additions & 1 deletion native/core/src/execution/jni_api.rs
@@ -306,6 +306,13 @@ struct ExecutionContext {
pub tracing_memory_metric_name: String,
/// Pre-computed tracing event name for executePlan calls
pub tracing_event_name: String,
/// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
/// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
/// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
/// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
/// cheap to clone; the underlying `Global<JObject>` releases its JNI global ref on drop
/// via `jni`'s `Drop` impl.
pub task_context: Option<Arc<Global<JObject<'static>>>>,
}

/// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
task_attempt_id: jlong,
task_cpus: jlong,
key_unwrapper_obj: JObject,
task_context_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |env| {
// Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
String::new()
};

// Capture the driving Spark task's TaskContext as a JNI global reference when
// non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
// is automatic when the ExecutionContext drops.
let task_context = if !task_context_obj.is_null() {
Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
} else {
None
};

let exec_context = Box::new(ExecutionContext {
id,
task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
"thread_{rust_thread_id}_comet_memory_reserved"
),
tracing_event_name,
task_context,
});

Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner =
PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
- .with_exec_id(exec_context_id);
+ .with_exec_id(exec_context_id)
+ .with_task_context(exec_context.task_context.clone());
let (scans, shuffle_scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
26 changes: 19 additions & 7 deletions native/core/src/execution/planner.rs
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
partition: i32,
session_ctx: Arc<SessionContext>,
query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
/// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
/// propagation rationale. `None` when no driving Spark task is available.
task_context: Option<Arc<Global<JObject<'static>>>>,
}

impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@ impl PhysicalPlanner {
session_ctx,
partition,
query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
task_context: None,
}
}

- pub fn with_exec_id(self, exec_context_id: i64) -> Self {
- Self {
- exec_context_id,
- partition: self.partition,
- session_ctx: Arc::clone(&self.session_ctx),
- query_context_registry: Arc::clone(&self.query_context_registry),
- }
+ pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
+ self.exec_context_id = exec_context_id;
+ self
}

/// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
/// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
/// the thread-local on the Tokio worker driving the UDF.
pub fn with_task_context(
mut self,
task_context: Option<Arc<Global<JObject<'static>>>>,
) -> Self {
self.task_context = task_context;
self
}

/// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner {
args,
return_type,
udf.return_nullable,
self.task_context.clone(),
)))
}
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
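Design note on the `with_exec_id` rewrite above: the old body rebuilt the planner with a `Self { .. }` literal, which must name every field, so adding `task_context` would have forced that list to grow (and every future field would break it again at compile time). Mutating `self` in place keeps the builder independent of the field list.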
2 changes: 1 addition & 1 deletion native/jni-bridge/src/comet_udf_bridge.rs
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
method_evaluate: env.get_static_method_id(
JNIString::new(Self::JVM_CLASS),
jni::jni_str!("evaluate"),
- jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+ jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
)?,
method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
class,
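For reference, the updated descriptor maps one-to-one onto the new Java signature: `Ljava/lang/String;` → `udfClassName`, the two `[J` → `inputArrayPtrs` and `inputSchemaPtrs`, `JJ` → `outArrayPtr` and `outSchemaPtr`, the new `I` → `numRows`, `Lorg/apache/spark/TaskContext;` → `taskContext`, and the trailing `V` → the `void` return.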
24 changes: 22 additions & 2 deletions native/spark-expr/src/jvm_udf/mod.rs
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;

use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
use datafusion_comet_jni_bridge::JVMClasses;
- use jni::objects::{JObject, JValue};
+ use jni::objects::{Global, JObject, JValue};

/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
/// Captured at `createPlan` time and threaded here by the planner. Passed through the
/// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
/// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
/// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
/// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
/// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
/// returns in place.
task_context: Option<Arc<Global<JObject<'static>>>>,
}

impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@ impl JvmScalarUdfExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
task_context: Option<Arc<Global<JObject<'static>>>>,
) -> Self {
Self {
class_name,
args,
return_type,
return_nullable,
task_context,
}
}
}
@@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr {
.set_region(env, 0, &in_sch_ptrs)
.map_err(|e| CometError::JNI { source: e })?;

- // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+ // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
+ // leaves the worker thread's current TaskContext.get() in place. The borrow must
+ // outlive `call_static_method_unchecked`.
+ let null_task_context = JObject::null();
+ let task_context_ref: &JObject = match &self.task_context {
+ Some(gref) => gref.as_obj(),
+ None => &null_task_context,
+ };
let ret = unsafe {
env.call_static_method_unchecked(
&bridge.class,
@@ -198,6 +215,8 @@
JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
JValue::Long(out_arr_ptr).as_jni(),
JValue::Long(out_sch_ptr).as_jni(),
JValue::Int(batch.num_rows() as i32).as_jni(),
JValue::Object(task_context_ref).as_jni(),
],
)
};
@@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
children,
self.return_type.clone(),
self.return_nullable,
self.task_context.clone(),
)))
}
}
@@ -127,7 +127,10 @@ class CometExecIterator(
memoryConfig.memoryLimitPerTask,
taskAttemptId,
taskCPUs,
- keyUnwrapper)
+ keyUnwrapper,
+ // Propagated to Tokio workers running JVM UDFs so they see this Spark task's
+ // TaskContext. See CometUdfBridge.evaluate.
+ TaskContext.get())
}

private var nextBatch: Option[ColumnarBatch] = None
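Note that the capture works because `createPlan` is invoked here on the driving Spark task thread, where `TaskContext.get()` returns a non-null context; the Tokio workers that later execute the plan have no such thread-local, which is exactly the gap the install/unset dance in `CometUdfBridge.evaluate` closes.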
5 changes: 3 additions & 2 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -21,7 +21,7 @@ package org.apache.comet

import java.nio.ByteBuffer

- import org.apache.spark.CometTaskMemoryManager
+ import org.apache.spark.{CometTaskMemoryManager, TaskContext}
import org.apache.spark.sql.comet.CometMetricNode

import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
memoryLimitPerTask: Long,
taskAttemptId: Long,
taskCPUs: Long,
- keyUnwrapper: CometFileKeyUnwrapper): Long
+ keyUnwrapper: CometFileKeyUnwrapper,
+ taskContext: TaskContext): Long
// scalastyle:on

/**