diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index 5e76819810..9e97ef2226 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -29,18 +29,48 @@ import org.apache.arrow.vector.ValueVector; import org.apache.spark.TaskContext; import org.apache.spark.comet.CometTaskContextShim; +import org.apache.spark.util.TaskCompletionListener; /** * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method * pattern used by CometScalarSubquery so the native side can dispatch via * call_static_method_unchecked. + * + *

Cache invariants: + * + *

    + *
  1. For each live Spark task attempt there is at most one {@link CometUDF} instance per class + * name. + *
  2. A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated + * it. Two task attempts observing the same class name receive distinct instances. + *
  3. At any instant at most one thread is inside {@code evaluate()} for a given {@code + * taskAttemptId}. This follows from Spark executing one native future per partition and Tokio + * polling any given future on at most one worker at a time. + *
  4. All instances for a task are dropped by the {@link TaskCompletionListener} registered on + * the first cache miss for that task. No cache entry keyed by a real task attempt ID outlives + * its task; the fallback bucket described in item 5 is the sole exception. + *
  5. When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback + * key {@code -1L} is used; that bucket is never evicted because no task-completion event will + * fire. Every context-free call therefore shares that bucket's instances. + *
+ * + *

Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio + * work-stealing: on the scan-free execution path the same Spark task can be polled by different + * Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The + * task attempt ID is stable for the life of the task regardless of which worker is polling. */ public class CometUdfBridge { - // Process-wide cache of UDF instances keyed by class name. CometUDF - // implementations are required to be stateless (see CometUDF), so a - // single shared instance per class is safe across native worker threads. - private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>(); + /** + * Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or + * {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF + * class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener} + * registered on the first cache miss per task. + */ + private static final ConcurrentHashMap> INSTANCES = + new ConcurrentHashMap<>(); + + /** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */ + private static final long NO_TASK_ID = -1L; /** * Called from native via JNI. @@ -58,7 +88,9 @@ public class CometUdfBridge { * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}. * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext - * left on a worker by a previous task. + * left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache, + * so a UDF holding per-task state in fields sees a consistent instance for every call within + * the task regardless of which Tokio worker is polling. 
*/ public static void evaluate( String udfClassName, @@ -68,16 +100,33 @@ public static void evaluate( long outSchemaPtr, int numRows, TaskContext taskContext) { - // Save-and-restore rather than only-install-if-null: the propagated context is the ground - // truth for this call. Any value already on the thread is either (a) the same object on a - // Spark task thread, or (b) stale from a prior task on a reused Tokio worker. + assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty"; + assert inputArrayPtrs != null && inputSchemaPtrs != null + : "input pointer arrays must be non-null"; + assert inputArrayPtrs.length == inputSchemaPtrs.length + : "input array pointer count must equal schema pointer count"; + assert numRows >= 0 : "numRows must be non-negative"; + assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer"; + assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer"; + + // Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the + // ground truth for this call. Any value already on the thread is either (a) the same object + // on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker. TaskContext prior = TaskContext.get(); if (taskContext != null) { CometTaskContextShim.set(taskContext); + assert TaskContext.get() == taskContext + : "TaskContext install did not take effect on this thread"; } try { evaluateInternal( - udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows); + udfClassName, + inputArrayPtrs, + inputSchemaPtrs, + outArrayPtr, + outSchemaPtr, + numRows, + taskContext); } finally { if (taskContext != null) { if (prior != null) { @@ -95,9 +144,34 @@ private static void evaluateInternal( long[] inputSchemaPtrs, long outArrayPtr, long outSchemaPtr, - int numRows) { - CometUDF udf = + int numRows, + TaskContext taskContext) { + long taskAttemptId = (taskContext != null) ? 
taskContext.taskAttemptId() : NO_TASK_ID; + + ConcurrentHashMap perTask = INSTANCES.computeIfAbsent( + taskAttemptId, + id -> { + ConcurrentHashMap fresh = new ConcurrentHashMap<>(); + if (taskContext != null) { + // computeIfAbsent runs this lambda at most once per key, so the listener is + // registered exactly once per task attempt. + taskContext.addTaskCompletionListener( + (TaskCompletionListener) + ctx -> { + ConcurrentHashMap removed = INSTANCES.remove(id); + assert removed != null + : "task-completion listener fired but cache already removed " + + "entry for task " + + id; + }); + } + return fresh; + }); + assert perTask != null : "per-task cache must be non-null after computeIfAbsent"; + + CometUDF udf = + perTask.computeIfAbsent( udfClassName, name -> { try { @@ -113,6 +187,7 @@ private static void evaluateInternal( throw new RuntimeException("Failed to instantiate CometUDF: " + name, e); } }); + assert udf != null : "reflective instantiation returned null for " + udfClassName; BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index 5b6652d90a..6b435c4064 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -30,12 +30,18 @@ import org.apache.arrow.vector.ValueVector * - The returned vector's length must match `numRows`. * * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count. - * UDFs that always have at least one batch-length input can derive length from the inputs and - * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF) - * need `numRows` to know how many rows to produce. 
+ * UDFs that always have at least one batch-length input can read length from it and ignore + * `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the + * codegen dispatcher) need `numRows` to know how many rows to produce. * - * Implementations must have a public no-arg constructor and must be stateless: a single instance - * per class is cached and shared across native worker threads for the lifetime of the JVM. + * Implementations must have a public no-arg constructor. A fresh instance is created per Spark + * task attempt per class and reused for every call within that task. Instances may hold per-task + * state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task + * completion. Do not hold state that must persist across tasks. + * + * At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future + * per partition and Tokio polls one future per worker, so the per-task instance is never touched + * concurrently even if the task's future migrates between Tokio workers across batches. */ trait CometUDF { def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector