97 changes: 86 additions & 11 deletions common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -29,18 +29,48 @@
import org.apache.arrow.vector.ValueVector;
import org.apache.spark.TaskContext;
import org.apache.spark.comet.CometTaskContextShim;
import org.apache.spark.util.TaskCompletionListener;

/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
* pattern used by CometScalarSubquery so the native side can dispatch via
* call_static_method_unchecked.
*
* <p>Cache invariants:
*
* <ol>
* <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
* name.
* <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
* it. Two task attempts observing the same class name receive distinct instances.
* <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
* taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
* polling one future per worker at a time.
* <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
* the first cache miss for that task. No cache entry outlives its task.
* <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
* key {@code -1L} is used; that bucket is never evicted because no task-completion event will
* fire.
* </ol>
*
* <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
* work-stealing: on the scan-free execution path the same Spark task can be polled by different
* Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
* task attempt ID is stable for the life of the task regardless of which worker is polling.
*/
public class CometUdfBridge {

// Process-wide cache of UDF instances keyed by class name. CometUDF
// implementations are required to be stateless (see CometUDF), so a
// single shared instance per class is safe across native worker threads.
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
/**
* Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
* {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
* class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
* registered on the first cache miss per task.
*/
private static final ConcurrentHashMap<Long, ConcurrentHashMap<String, CometUDF>> INSTANCES =
new ConcurrentHashMap<>();

/** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
private static final long NO_TASK_ID = -1L;

/**
* Called from native via JNI.
@@ -58,7 +88,9 @@ public class CometUdfBridge {
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
* left on a worker by a previous task.
* left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache,
* so a UDF holding per-task state in fields sees a consistent instance for every call within
* the task regardless of which Tokio worker is polling.
*/
public static void evaluate(
String udfClassName,
@@ -68,16 +100,33 @@ public static void evaluate(
long outSchemaPtr,
int numRows,
TaskContext taskContext) {
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
// truth for this call. Any value already on the thread is either (a) the same object on a
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty";
assert inputArrayPtrs != null && inputSchemaPtrs != null
: "input pointer arrays must be non-null";
assert inputArrayPtrs.length == inputSchemaPtrs.length
: "input array pointer count must equal schema pointer count";
assert numRows >= 0 : "numRows must be non-negative";
assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer";
assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer";

// Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the
// ground truth for this call. Any value already on the thread is either (a) the same object
// on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
TaskContext prior = TaskContext.get();
if (taskContext != null) {
CometTaskContextShim.set(taskContext);
assert TaskContext.get() == taskContext
: "TaskContext install did not take effect on this thread";
}
try {
evaluateInternal(
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
udfClassName,
inputArrayPtrs,
inputSchemaPtrs,
outArrayPtr,
outSchemaPtr,
numRows,
taskContext);
} finally {
if (taskContext != null) {
if (prior != null) {
@@ -95,9 +144,34 @@ private static void evaluateInternal(
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr,
int numRows) {
CometUDF udf =
int numRows,
TaskContext taskContext) {
long taskAttemptId = (taskContext != null) ? taskContext.taskAttemptId() : NO_TASK_ID;

ConcurrentHashMap<String, CometUDF> perTask =
INSTANCES.computeIfAbsent(
taskAttemptId,
id -> {
ConcurrentHashMap<String, CometUDF> fresh = new ConcurrentHashMap<>();
if (taskContext != null) {
// computeIfAbsent runs this lambda at most once per key, so the listener is
// registered exactly once per task attempt.
taskContext.addTaskCompletionListener(
(TaskCompletionListener)
ctx -> {
ConcurrentHashMap<String, CometUDF> removed = INSTANCES.remove(id);
assert removed != null
: "task-completion listener fired but cache already removed "
+ "entry for task "
+ id;
});
}
return fresh;
});
assert perTask != null : "per-task cache must be non-null after computeIfAbsent";

CometUDF udf =
perTask.computeIfAbsent(
udfClassName,
name -> {
try {
@@ -113,6 +187,7 @@ private static void evaluateInternal(
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
}
});
assert udf != null : "reflective instantiation returned null for " + udfClassName;

BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();

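For readers skimming the diff, here is a minimal self-contained sketch of the two-level caching pattern the hunks above implement. `Task`, `TaskScopedCacheSketch`, and `instanceFor` are hypothetical stand-ins introduced only so the example compiles without Spark on the classpath; the PR itself uses Spark's real `TaskContext` and `TaskCompletionListener`. Only the shape is taken from the PR: an outer map keyed by task attempt ID, an inner map keyed by class name, and an eviction hook registered inside the outer `computeIfAbsent` so it runs at most once per task.

import java.util.concurrent.ConcurrentHashMap;

final class TaskScopedCacheSketch {

  /** Hypothetical stand-in for Spark's TaskContext: an id plus a completion hook. */
  interface Task {
    long taskAttemptId();
    void onCompletion(Runnable listener);
  }

  /** Sentinel for calls that carry no task, mirroring the bridge's NO_TASK_ID. */
  private static final long NO_TASK_ID = -1L;

  // Outer key: task attempt id (or NO_TASK_ID). Inner key: instance class name.
  private static final ConcurrentHashMap<Long, ConcurrentHashMap<String, Object>> INSTANCES =
      new ConcurrentHashMap<>();

  static Object instanceFor(String className, Task task) {
    long taskId = (task != null) ? task.taskAttemptId() : NO_TASK_ID;
    ConcurrentHashMap<String, Object> perTask =
        INSTANCES.computeIfAbsent(
            taskId,
            id -> {
              if (task != null) {
                // computeIfAbsent evaluates this initializer at most once per
                // key, so the eviction hook is registered exactly once per task.
                task.onCompletion(() -> INSTANCES.remove(id));
              }
              return new ConcurrentHashMap<>();
            });
    // Reflective instantiation happens at most once per (task, class) pair.
    return perTask.computeIfAbsent(
        className,
        name -> {
          try {
            return Class.forName(name).getDeclaredConstructor().newInstance();
          } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Failed to instantiate " + name, e);
          }
        });
  }
}

Note the design choice the sketch preserves: eviction removes the whole per-task bucket in one `INSTANCES.remove(id)`, so no per-class bookkeeping is needed at task completion, and the `NO_TASK_ID` bucket is simply never evicted because no completion hook is registered for it.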
16 changes: 11 additions & 5 deletions common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -30,12 +30,18 @@ import org.apache.arrow.vector.ValueVector
* - The returned vector's length must match `numRows`.
*
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
* UDFs that always have at least one batch-length input can derive length from the inputs and
* ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
* need `numRows` to know how many rows to produce.
* UDFs that always have at least one batch-length input can read length from it and ignore
* `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the
* codegen dispatcher) need `numRows` to know how many rows to produce.
*
* Implementations must have a public no-arg constructor and must be stateless: a single instance
* per class is cached and shared across native worker threads for the lifetime of the JVM.
* Implementations must have a public no-arg constructor. A fresh instance is created per Spark
* task attempt per class and reused for every call within that task. Instances may hold per-task
* state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task
* completion. Do not hold state that must persist across tasks.
*
* At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future
* per partition and Tokio polls one future per worker, so the per-task instance is never touched
* concurrently even if the task's future migrates between Tokio workers across batches.
*/
trait CometUDF {
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
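The relaxed contract documented above permits per-task state in fields. As an illustration, here is a hypothetical zero-argument UDF that keeps a running row count; `RowCounterUdf` is not part of this PR, and its private `RootAllocator` is a simplification for the sketch (production code should obtain its allocator from Comet rather than creating one per instance). It also shows why `numRows` exists: with zero data columns there is no input vector to read the batch length from.

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, ValueVector}

// Hypothetical zero-argument UDF exercising the per-task contract: it keeps a
// running row count in a field. This is safe because the bridge creates one
// instance per task attempt and at most one thread calls evaluate() on it at
// a time; the instance (and its counter) is dropped at task completion.
class RowCounterUdf extends CometUDF {

  // Per-task state: the number of rows seen so far by this task attempt.
  private var rowsSeen: Long = 0L

  // Simplification for the sketch; real code should use Comet's allocator.
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    // Zero data columns, so numRows is the only source of the batch length.
    val out = new BigIntVector("row_counter", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      out.setSafe(i, rowsSeen + i) // row index, offset by rows from prior batches
      i += 1
    }
    out.setValueCount(numRows)
    rowsSeen += numRows
    out
  }
}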