Draft
44 commits
1746bcc
feat: Arrow-direct codegen dispatcher for Spark expressions and Scala…
mbutrovich May 8, 2026
08d6b78
prettier, add new suites to CI checks.
mbutrovich May 8, 2026
557752e
make format, fix shims for 4.0+
mbutrovich May 8, 2026
896f61f
make format, fix shims for 4.0+
mbutrovich May 8, 2026
a82e160
Merge branch 'main' into codegen_scala_udf
mbutrovich May 8, 2026
2a158f4
strengthen tests for composed expressions
mbutrovich May 8, 2026
654bbad
make format, again.
mbutrovich May 8, 2026
10df7e0
fix pr_benchmark_check.yml
mbutrovich May 8, 2026
7afe69f
fix arrow shading issue in CI.
mbutrovich May 8, 2026
0dc5855
fix Spark 4.0 collation expression shim
mbutrovich May 8, 2026
43a7b0c
apply common subexpression elimination, add tests for subqueries in UDFs
mbutrovich May 8, 2026
9640897
make format
mbutrovich May 8, 2026
f0c8296
decimal fast path. document 64KB limitation right now
mbutrovich May 9, 2026
2173f40
pass through task context to get around tokio worker pool calling ove…
mbutrovich May 9, 2026
2f9585b
fix compilation on scala 2.12, fix format issue
mbutrovich May 9, 2026
582cd17
Merge branch 'main' into codegen_scala_udf
mbutrovich May 9, 2026
22f3256
decimal output, utf8 output, non-nullable output optimizations
mbutrovich May 9, 2026
7666715
optimization menu
mbutrovich May 9, 2026
0a34636
estimate binaryview and binary size
mbutrovich May 9, 2026
e94b6db
fix "CSE collapses a repeated subtree to one evaluation in the genera…
mbutrovich May 9, 2026
d0f1f27
Merge remote-tracking branch 'origin/codegen_scala_udf' into codegen_…
mbutrovich May 9, 2026
07e37ea
add some complex type support, remove #4239 code. update docs.
mbutrovich May 9, 2026
ebf77c4
split codegen input and output, basic struct WIP
mbutrovich May 9, 2026
6836c30
split massive codegen file, handle recursive nested types
mbutrovich May 9, 2026
5d91a8f
map input
mbutrovich May 9, 2026
2a28aaf
more struct support
mbutrovich May 9, 2026
0c6586a
revert some benchmark changes
mbutrovich May 9, 2026
8497fe7
cleanup part 1
mbutrovich May 10, 2026
8d703c3
cleanup part 2
mbutrovich May 10, 2026
5ec0e3f
cleanup part 3
mbutrovich May 10, 2026
a22051e
remove view support, it's dead code right now
mbutrovich May 10, 2026
421c60c
use cometplainvector part 1
mbutrovich May 10, 2026
0705dff
use cometplainvector part 2
mbutrovich May 10, 2026
9a00874
make generated class final
mbutrovich May 10, 2026
d7b43fc
clean up test names
mbutrovich May 10, 2026
034e1f5
fix format
mbutrovich May 11, 2026
317feaf
Merge branch 'main' into codegen_scala_udf
mbutrovich May 11, 2026
db1f1f2
Merge branch 'main' into codegen_scala_udf
mbutrovich May 12, 2026
caffed9
fix 2.12 mapvalues usage
mbutrovich May 12, 2026
4be8144
Remove code related to #4239.
mbutrovich May 12, 2026
6fcd81c
Merge remote-tracking branch 'apache/main' into codegen_scala_udf
mbutrovich May 14, 2026
9f8aa07
fix after merging in upstream/main.
mbutrovich May 14, 2026
17b2714
switch to taskid-keyed state for CometUDFs.
mbutrovich May 14, 2026
ff8ee79
Merge branch 'main' into codegen_scala_udf
mbutrovich May 14, 2026
8 changes: 3 additions & 5 deletions .github/workflows/pr_benchmark_check.yml
@@ -84,9 +84,7 @@ jobs:
${{ runner.os }}-benchmark-maven-

- name: Check Scala compilation and linting
# Pin to spark-4.0 (Scala 2.13.16) because the default profile is now
# spark-4.1 / Scala 2.13.17, and semanticdb-scalac_2.13.17 is not yet
# published, which breaks `-Psemanticdb`. See pr_build_linux.yml for
# the same exclusion in the main lint matrix.
# Pinned to spark-4.0 because semanticdb-scalac_2.13.17 (spark-4.1 default)
# is not yet published, which breaks the -Psemanticdb scalafix lint.
run: |
./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -Pspark-4.0 -DskipTests
./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Pspark-4.0 -Psemanticdb -DskipTests
3 changes: 3 additions & 0 deletions .github/workflows/pr_build_linux.yml
@@ -309,6 +309,7 @@ jobs:
org.apache.comet.CometFuzzAggregateSuite
org.apache.comet.CometFuzzIcebergSuite
org.apache.comet.CometFuzzMathSuite
org.apache.comet.CometCodegenDispatchFuzzSuite
org.apache.comet.DataGeneratorSuite
- name: "shuffle"
value: |
@@ -387,6 +388,8 @@
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
org.apache.comet.CometCodegenDispatchSmokeSuite
org.apache.comet.CometCodegenSourceSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
3 changes: 3 additions & 0 deletions .github/workflows/pr_build_macos.yml
@@ -157,6 +157,7 @@ jobs:
org.apache.comet.CometFuzzAggregateSuite
org.apache.comet.CometFuzzIcebergSuite
org.apache.comet.CometFuzzMathSuite
org.apache.comet.CometCodegenDispatchFuzzSuite
org.apache.comet.DataGeneratorSuite
- name: "shuffle"
value: |
@@ -234,6 +235,8 @@
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
org.apache.comet.CometCodegenDispatchSmokeSuite
org.apache.comet.CometCodegenSourceSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
68 changes: 68 additions & 0 deletions common/src/main/java/org/apache/comet/udf/CometBatchKernel.java
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf;

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;

/**
* Abstract base extended by the Janino-compiled batch kernel emitted by {@code
* CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's
* {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries
* typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow
* read/write fuse into one method per expression tree.
*
* <p>Input scope: any {@code ValueVector[]}; the generated subclass casts each slot to the concrete
* Arrow type the compile-time schema specified. Output is a generic {@code FieldVector}; the
* generated subclass casts to the concrete type matching the bound expression's {@code dataType}.
* Widen input support by adding vector classes to the getter switch in {@code
* CometBatchKernelCodegen.emitTypedGetters}; widen output support by adding cases in {@code
* CometBatchKernelCodegen.allocateOutput} and {@code emitOutputWriter}.
*/
public abstract class CometBatchKernel extends CometInternalRow {

protected final Object[] references;

protected CometBatchKernel(Object[] references) {
this.references = references;
}

/**
* Process one batch.
*
* @param inputs Arrow input vectors; length and concrete classes must match the schema the kernel
* was compiled against
* @param output Arrow output vector; caller allocates to the expression's {@code dataType}
* @param numRows number of rows in this batch
*/
public abstract void process(ValueVector[] inputs, FieldVector output, int numRows);

/**
* Run partition-dependent initialization. The generated subclass overrides this to execute
* statements collected via {@code CodegenContext.addPartitionInitializationStatement}, for
* example reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}.
* Deterministic expressions leave this as a no-op.
*
* <p>The caller must invoke this before the first {@code process} call of each partition. The
* generated subclass is not thread-safe across concurrent {@code process} calls, so kernels are
* allocated per dispatcher invocation and init is run once on the fresh instance.
*/
public void init(int partitionIndex) {}
Member comment: nit: I think moving init before process helps with reading this

}
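To make the contract concrete, here is a minimal hand-written sketch of the shape a generated subclass could take. The class name, the single `IntVector` input, and the `col0 + 1` expression are all hypothetical stand-ins for what `CometBatchKernelCodegen` actually emits, and the `CometInternalRow` accessor overrides the real codegen also produces are elided:

```java
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;

import org.apache.comet.udf.CometBatchKernel;

// Hypothetical analogue of a Janino-generated kernel for the expression
// `col0 + 1` over one INT column; illustrative only.
final class ExampleGeneratedKernel extends CometBatchKernel {

  ExampleGeneratedKernel(Object[] references) {
    super(references);
  }

  @Override
  public void init(int partitionIndex) {
    // Deterministic expression: nothing to reseed, so this stays a no-op.
    // A kernel containing Rand would reseed its XORShiftRandom here from
    // seed + partitionIndex, as described in the Javadoc above.
  }

  @Override
  public void process(ValueVector[] inputs, FieldVector output, int numRows) {
    // Cast each slot to the concrete Arrow type fixed by the compile-time schema.
    IntVector in0 = (IntVector) inputs[0];
    IntVector out = (IntVector) output;
    for (int i = 0; i < numRows; i++) {
      if (in0.isNull(i)) {
        out.setNull(i);
      } else {
        // Expression evaluation fused with the Arrow read and write.
        out.set(i, in0.get(i) + 1);
      }
    }
    out.setValueCount(numRows);
  }
}
```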
130 changes: 103 additions & 27 deletions common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -29,18 +29,48 @@
import org.apache.arrow.vector.ValueVector;
import org.apache.spark.TaskContext;
import org.apache.spark.comet.CometTaskContextShim;
import org.apache.spark.util.TaskCompletionListener;

/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
* pattern used by CometScalarSubquery so the native side can dispatch via
* call_static_method_unchecked.
*
* <p>Cache invariants:
*
* <ol>
* <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
* name.
* <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
* it. Two task attempts observing the same class name receive distinct instances.
* <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
* taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
* polling one future per worker at a time.
* <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
* the first cache miss for that task. No cache entry outlives its task.
* <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
* key {@code -1L} is used; that bucket is never evicted because no task-completion event will
* fire.
* </ol>
*
* <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
* work-stealing: on the scan-free execution path the same Spark task can be polled by different
* Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
* task attempt ID is stable for the life of the task regardless of which worker is polling.
*/
public class CometUdfBridge {

// Process-wide cache of UDF instances keyed by class name. CometUDF
// implementations are required to be stateless (see CometUDF), so a
// single shared instance per class is safe across native worker threads.
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
/**
* Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
* {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
* class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
* registered on the first cache miss per task.
*/
private static final ConcurrentHashMap<Long, ConcurrentHashMap<String, CometUDF>> INSTANCES =
new ConcurrentHashMap<>();

/** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
private static final long NO_TASK_ID = -1L;

/**
* Called from native via JNI.
@@ -50,15 +80,21 @@ public class CometUdfBridge {
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
* @param numRows row count of the current batch. Mirrors DataFusion's {@code
* ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
* zero-arg non-deterministic ScalaUDF) ever sees.
* @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
* {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
* left on a worker by a previous task.
* @param numRows number of rows in the current batch. Mirrors DataFusion's {@code
* ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for cases
* where no input arg is a batch-length array (e.g. a zero-arg non-deterministic ScalaUDF).
* UDFs that already read size from their input vectors can ignore it.
* @param taskContext Spark {@link TaskContext} captured on the driving Spark task thread and
* passed through from native. May be {@code null} when the bridge is invoked outside a Spark
* task (unit tests, direct native driver runs). When non-null and the current thread has no
* {@code TaskContext} of its own, the bridge installs it as the thread-local for the duration
* of the UDF call so the UDF body (including partition-sensitive built-ins like {@code Rand}
* / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
* TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
* is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
* invocations. The task attempt ID drawn from this context also keys the UDF-instance cache,
* so a UDF holding per-task state in fields sees a consistent instance for every call within
* the task regardless of which Tokio worker is polling.
*/
public static void evaluate(
String udfClassName,
@@ -68,23 +104,34 @@ public static void evaluate(
long outSchemaPtr,
int numRows,
TaskContext taskContext) {
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
// truth for this call. Any value already on the thread is either (a) the same object on a
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
TaskContext prior = TaskContext.get();
if (taskContext != null) {
assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty";
assert inputArrayPtrs != null && inputSchemaPtrs != null
: "input pointer arrays must be non-null";
assert inputArrayPtrs.length == inputSchemaPtrs.length
: "input array pointer count must equal schema pointer count";
assert numRows >= 0 : "numRows must be non-negative";
assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer";
assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer";

boolean installedTaskContext = false;
if (taskContext != null && TaskContext.get() == null) {
CometTaskContextShim.set(taskContext);
installedTaskContext = true;
assert TaskContext.get() == taskContext
: "TaskContext install did not take effect on this thread";
}
try {
evaluateInternal(
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
udfClassName,
inputArrayPtrs,
inputSchemaPtrs,
outArrayPtr,
outSchemaPtr,
numRows,
taskContext);
} finally {
if (taskContext != null) {
if (prior != null) {
CometTaskContextShim.set(prior);
} else {
CometTaskContextShim.unset();
}
if (installedTaskContext) {
CometTaskContextShim.unset();
}
}
}
Expand All @@ -95,9 +142,34 @@ private static void evaluateInternal(
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr,
int numRows) {
CometUDF udf =
int numRows,
TaskContext taskContext) {
long taskAttemptId = (taskContext != null) ? taskContext.taskAttemptId() : NO_TASK_ID;

ConcurrentHashMap<String, CometUDF> perTask =
INSTANCES.computeIfAbsent(
taskAttemptId,
id -> {
ConcurrentHashMap<String, CometUDF> fresh = new ConcurrentHashMap<>();
if (taskContext != null) {
// computeIfAbsent runs this lambda at most once per key, so the listener is
// registered exactly once per task attempt.
taskContext.addTaskCompletionListener(
(TaskCompletionListener)
ctx -> {
ConcurrentHashMap<String, CometUDF> removed = INSTANCES.remove(id);
assert removed != null
: "task-completion listener fired but cache already removed "
+ "entry for task "
+ id;
});
}
return fresh;
});
assert perTask != null : "per-task cache must be non-null after computeIfAbsent";

CometUDF udf =
perTask.computeIfAbsent(
udfClassName,
name -> {
try {
@@ -113,6 +185,7 @@ private static void evaluateInternal(
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
}
});
assert udf != null : "reflective instantiation returned null for " + udfClassName;

BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();

@@ -126,6 +199,9 @@ private static void evaluateInternal(
}

result = udf.evaluate(inputs, numRows);
assert result instanceof FieldVector
: "CometUDF implementations must return FieldVector; got "
+ (result == null ? "null" : result.getClass().getName());
if (!(result instanceof FieldVector)) {
throw new RuntimeException(
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
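For a sense of what sits behind `perTask.computeIfAbsent`, here is a hedged sketch of a `CometUDF` implementation matching the calls above. The `evaluate(FieldVector[], int)` signature and the no-arg constructor are inferred from this bridge code rather than quoted from the `CometUDF` interface, so treat the shape as illustrative:

```java
import java.nio.charset.StandardCharsets;
import java.util.Locale;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VarCharVector;

// Hypothetical UDF: uppercase one UTF-8 column. Instantiated reflectively by
// the bridge, so it needs a public no-arg constructor; per-task state held in
// fields is safe because the cache hands each task attempt its own instance.
public class ExampleUpperUdf /* implements CometUDF */ {

  public FieldVector evaluate(FieldVector[] inputs, int numRows) {
    VarCharVector in = (VarCharVector) inputs[0];
    BufferAllocator allocator = in.getAllocator();
    VarCharVector out = new VarCharVector("upper", allocator);
    out.allocateNew(numRows);
    for (int i = 0; i < numRows; i++) {
      if (in.isNull(i)) {
        out.setNull(i);
      } else {
        // numRows, not the input's value count, is the batch-size signal a
        // zero-input UDF would rely on; here either would work.
        String s = new String(in.get(i), StandardCharsets.UTF_8);
        out.setSafe(i, s.toUpperCase(Locale.ROOT).getBytes(StandardCharsets.UTF_8));
      }
    }
    out.setValueCount(numRows);
    // Must be a FieldVector, or the bridge's type check above throws.
    return out;
  }
}
```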
40 changes: 40 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -380,6 +380,46 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val REGEXP_ENGINE_RUST = "rust"
val REGEXP_ENGINE_JAVA = "java"

val COMET_REGEXP_ENGINE: ConfigEntry[String] =
Member comment: Using the regexp work to test the new framework makes sense, but I think we should split this work out into a follow-on PR

conf("spark.comet.exec.regexp.engine")
.category(CATEGORY_EXEC)
.doc(
"Experimental. Selects the engine used to evaluate supported regular-expression " +
s"expressions. `$REGEXP_ENGINE_RUST` uses the native DataFusion regexp engine. " +
s"`$REGEXP_ENGINE_JAVA` routes through a JVM-side UDF (java.util.regex.Pattern) for " +
"Spark-compatible semantics, at the cost of JNI roundtrips per batch. Expressions " +
"routed when set to java: rlike, regexp_extract, regexp_extract_all, regexp_replace, " +
"regexp_instr, and split.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
.checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA))
.createWithDefault(REGEXP_ENGINE_JAVA)

val CODEGEN_DISPATCH_AUTO = "auto"
val CODEGEN_DISPATCH_DISABLED = "disabled"
val CODEGEN_DISPATCH_FORCE = "force"

val COMET_CODEGEN_DISPATCH_MODE: ConfigEntry[String] =
conf("spark.comet.exec.codegenDispatch.mode")
.category(CATEGORY_EXEC)
.doc("Controls whether Comet routes eligible scalar expressions through the Arrow-direct " +
"codegen dispatcher (`CometCodegenDispatchUDF`) rather than through a native " +
s"DataFusion implementation or falling back to Spark. `$CODEGEN_DISPATCH_AUTO` lets " +
"each expression's serde decide its preferred path based on measured evidence " +
"(e.g. for regex, codegen is preferred when " +
s"spark.comet.exec.regexp.engine=$REGEXP_ENGINE_JAVA). " +
s"`$CODEGEN_DISPATCH_DISABLED` never uses codegen dispatch. `$CODEGEN_DISPATCH_FORCE` " +
"inverts the chain: every serde tries codegen first and falls through to its next " +
"preferred path only when `canHandle` rejects the expression. Useful for debugging " +
"and benchmarking.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
.checkValues(Set(CODEGEN_DISPATCH_AUTO, CODEGEN_DISPATCH_DISABLED, CODEGEN_DISPATCH_FORCE))
.createWithDefault(CODEGEN_DISPATCH_AUTO)

val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
.category(CATEGORY_SHUFFLE)
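Assuming the two config entries land as defined above, a hedged sketch of exercising them from a Java Spark application (the app name and query are placeholders; the keys and values come from the `conf(...)` definitions and their `checkValues` sets):

```java
import org.apache.spark.sql.SparkSession;

public final class CodegenDispatchDemo {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().appName("codegen-dispatch-demo").getOrCreate();
    // Keep regex on the JVM-side java.util.regex path (the default above)...
    spark.conf().set("spark.comet.exec.regexp.engine", "java");
    // ...and force every serde to try codegen dispatch first, which the doc
    // suggests for debugging and benchmarking the new path.
    spark.conf().set("spark.comet.exec.codegenDispatch.mode", "force");
    spark.sql("SELECT regexp_replace(s, 'a+', 'b') FROM VALUES ('aaa') t(s)").show();
    spark.stop();
  }
}
```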