diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
index 5e76819810..9e97ef2226 100644
--- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -29,18 +29,48 @@
import org.apache.arrow.vector.ValueVector;
import org.apache.spark.TaskContext;
import org.apache.spark.comet.CometTaskContextShim;
+import org.apache.spark.util.TaskCompletionListener;
/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
* pattern used by CometScalarSubquery so the native side can dispatch via
* call_static_method_unchecked.
+ *
+ *
Cache invariants:
+ *
+ *
+ * - For each live Spark task attempt there is at most one {@link CometUDF} instance per class
+ * name.
+ *
- A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
+ * it. Two task attempts observing the same class name receive distinct instances.
+ *
- At any instant at most one thread is inside {@code evaluate()} for a given {@code
+ * taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
+ * polling one future per worker at a time.
+ *
- All instances for a task are dropped by the {@link TaskCompletionListener} registered on
+ * the first cache miss for that task. No cache entry outlives its task.
+ *
- When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
+ * key {@code -1L} is used; that bucket is never evicted because no task-completion event will
+ * fire.
+ *
+ *
+ * Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
+ * work-stealing: on the scan-free execution path the same Spark task can be polled by different
+ * Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
+ * task attempt ID is stable for the life of the task regardless of which worker is polling.
*/
public class CometUdfBridge {
- // Process-wide cache of UDF instances keyed by class name. CometUDF
- // implementations are required to be stateless (see CometUDF), so a
- // single shared instance per class is safe across native worker threads.
- private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>();
+ /**
+ * Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
+ * {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
+ * class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
+ * registered on the first cache miss per task.
+ */
+ private static final ConcurrentHashMap> INSTANCES =
+ new ConcurrentHashMap<>();
+
+ /** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
+ private static final long NO_TASK_ID = -1L;
/**
* Called from native via JNI.
@@ -58,7 +88,9 @@ public class CometUdfBridge {
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
- * left on a worker by a previous task.
+ * left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache,
+ * so a UDF holding per-task state in fields sees a consistent instance for every call within
+ * the task regardless of which Tokio worker is polling.
*/
public static void evaluate(
String udfClassName,
@@ -68,16 +100,33 @@ public static void evaluate(
long outSchemaPtr,
int numRows,
TaskContext taskContext) {
- // Save-and-restore rather than only-install-if-null: the propagated context is the ground
- // truth for this call. Any value already on the thread is either (a) the same object on a
- // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
+ assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty";
+ assert inputArrayPtrs != null && inputSchemaPtrs != null
+ : "input pointer arrays must be non-null";
+ assert inputArrayPtrs.length == inputSchemaPtrs.length
+ : "input array pointer count must equal schema pointer count";
+ assert numRows >= 0 : "numRows must be non-negative";
+ assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer";
+ assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer";
+
+ // Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the
+ // ground truth for this call. Any value already on the thread is either (a) the same object
+ // on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
TaskContext prior = TaskContext.get();
if (taskContext != null) {
CometTaskContextShim.set(taskContext);
+ assert TaskContext.get() == taskContext
+ : "TaskContext install did not take effect on this thread";
}
try {
evaluateInternal(
- udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
+ udfClassName,
+ inputArrayPtrs,
+ inputSchemaPtrs,
+ outArrayPtr,
+ outSchemaPtr,
+ numRows,
+ taskContext);
} finally {
if (taskContext != null) {
if (prior != null) {
@@ -95,9 +144,34 @@ private static void evaluateInternal(
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr,
- int numRows) {
- CometUDF udf =
+ int numRows,
+ TaskContext taskContext) {
+ long taskAttemptId = (taskContext != null) ? taskContext.taskAttemptId() : NO_TASK_ID;
+
+ ConcurrentHashMap perTask =
INSTANCES.computeIfAbsent(
+ taskAttemptId,
+ id -> {
+ ConcurrentHashMap fresh = new ConcurrentHashMap<>();
+ if (taskContext != null) {
+ // computeIfAbsent runs this lambda at most once per key, so the listener is
+ // registered exactly once per task attempt.
+ taskContext.addTaskCompletionListener(
+ (TaskCompletionListener)
+ ctx -> {
+ ConcurrentHashMap removed = INSTANCES.remove(id);
+ assert removed != null
+ : "task-completion listener fired but cache already removed "
+ + "entry for task "
+ + id;
+ });
+ }
+ return fresh;
+ });
+ assert perTask != null : "per-task cache must be non-null after computeIfAbsent";
+
+ CometUDF udf =
+ perTask.computeIfAbsent(
udfClassName,
name -> {
try {
@@ -113,6 +187,7 @@ private static void evaluateInternal(
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
}
});
+ assert udf != null : "reflective instantiation returned null for " + udfClassName;
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
index 5b6652d90a..6b435c4064 100644
--- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
+++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -30,12 +30,18 @@ import org.apache.arrow.vector.ValueVector
* - The returned vector's length must match `numRows`.
*
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
- * UDFs that always have at least one batch-length input can derive length from the inputs and
- * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
- * need `numRows` to know how many rows to produce.
+ * UDFs that always have at least one batch-length input can read length from it and ignore
+ * `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the
+ * codegen dispatcher) need `numRows` to know how many rows to produce.
*
- * Implementations must have a public no-arg constructor and must be stateless: a single instance
- * per class is cached and shared across native worker threads for the lifetime of the JVM.
+ * Implementations must have a public no-arg constructor. A fresh instance is created per Spark
+ * task attempt per class and reused for every call within that task. Instances may hold per-task
+ * state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task
+ * completion. Do not hold state that must persist across tasks.
+ *
+ * At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future
+ * per partition and Tokio polls one future per worker, so the per-task instance is never touched
+ * concurrently even if the task's future migrates between Tokio workers across batches.
*/
trait CometUDF {
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector