diff --git a/.github/workflows/pr_benchmark_check.yml b/.github/workflows/pr_benchmark_check.yml index b07cc03c34..a879493a7f 100644 --- a/.github/workflows/pr_benchmark_check.yml +++ b/.github/workflows/pr_benchmark_check.yml @@ -84,9 +84,7 @@ jobs: ${{ runner.os }}-benchmark-maven- - name: Check Scala compilation and linting - # Pin to spark-4.0 (Scala 2.13.16) because the default profile is now - # spark-4.1 / Scala 2.13.17, and semanticdb-scalac_2.13.17 is not yet - # published, which breaks `-Psemanticdb`. See pr_build_linux.yml for - # the same exclusion in the main lint matrix. + # Pinned to spark-4.0 because semanticdb-scalac_2.13.17 (spark-4.1 default) + # is not yet published, which breaks the -Psemanticdb scalafix lint. run: | - ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -Pspark-4.0 -DskipTests + ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Pspark-4.0 -Psemanticdb -DskipTests diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index dd5377ef54..313af77deb 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -309,6 +309,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenDispatchFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -387,6 +388,8 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenDispatchSmokeSuite + org.apache.comet.CometCodegenSourceSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 8abaa1c776..b2af6e43ab 100644 --- a/.github/workflows/pr_build_macos.yml +++ 
b/.github/workflows/pr_build_macos.yml @@ -157,6 +157,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenDispatchFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -234,6 +235,8 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenDispatchSmokeSuite + org.apache.comet.CometCodegenSourceSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java b/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java new file mode 100644 index 0000000000..f9fbb775a0 --- /dev/null +++ b/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.codegen; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Abstract base extended by the Janino-compiled batch kernel emitted by {@code + * CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's + * {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries + * typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow + * read/write fuse into one method per expression tree. + * + *

Input scope: any {@code ValueVector[]}; the generated subclass casts each slot to the concrete + * Arrow type the compile-time schema specified. Output is a generic {@code FieldVector}; the + * generated subclass casts to the concrete type matching the bound expression's {@code dataType}. + * Widen input support by adding vector classes to the getter switch in {@code + * CometBatchKernelCodegen.emitTypedGetters}; widen output support by adding cases in {@code + * CometBatchKernelCodegen.allocateOutput} and {@code emitOutputWriter}. + */ +public abstract class CometBatchKernel extends CometInternalRow { + + protected final Object[] references; + + protected CometBatchKernel(Object[] references) { + this.references = references; + } + + /** + * Process one batch. + * + * @param inputs Arrow input vectors; length and concrete classes must match the schema the kernel + * was compiled against + * @param output Arrow output vector; caller allocates to the expression's {@code dataType} + * @param numRows number of rows in this batch + */ + public abstract void process(ValueVector[] inputs, FieldVector output, int numRows); + + /** + * Run partition-dependent initialization. The generated subclass overrides this to execute + * statements collected via {@code CodegenContext.addPartitionInitializationStatement}, for + * example reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}. + * Deterministic expressions leave this as a no-op. + * + *

The caller must invoke this before the first {@code process} call of each partition. The + * generated subclass is not thread-safe across concurrent {@code process} calls, so kernels are + * allocated per dispatcher invocation and init is run once on the fresh instance. + */ + public void init(int partitionIndex) {} +} diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index 5e76819810..9e97ef2226 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -29,18 +29,48 @@ import org.apache.arrow.vector.ValueVector; import org.apache.spark.TaskContext; import org.apache.spark.comet.CometTaskContextShim; +import org.apache.spark.util.TaskCompletionListener; /** * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method * pattern used by CometScalarSubquery so the native side can dispatch via * call_static_method_unchecked. + * + *

Cache invariants: + * + *

    + *
  1. For each live Spark task attempt there is at most one {@link CometUDF} instance per class + * name. + *
  2. A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated + * it. Two task attempts observing the same class name receive distinct instances. + *
  3. At any instant at most one thread is inside {@code evaluate()} for a given {@code + * taskAttemptId}. This follows from Spark executing one native future per partition and Tokio + * polling one future per worker at a time. + *
  4. All instances for a task are dropped by the {@link TaskCompletionListener} registered on + * the first cache miss for that task. No cache entry outlives its task. + *
  5. When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback + * key {@code -1L} is used; that bucket is never evicted because no task-completion event will + * fire. + *
+ * + *

Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio + * work-stealing: on the scan-free execution path the same Spark task can be polled by different + * Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The + * task attempt ID is stable for the life of the task regardless of which worker is polling. */ public class CometUdfBridge { - // Process-wide cache of UDF instances keyed by class name. CometUDF - // implementations are required to be stateless (see CometUDF), so a - // single shared instance per class is safe across native worker threads. - private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>(); + /** + * Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or + * {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF + * class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener} + * registered on the first cache miss per task. + */ + private static final ConcurrentHashMap> INSTANCES = + new ConcurrentHashMap<>(); + + /** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */ + private static final long NO_TASK_ID = -1L; /** * Called from native via JNI. @@ -58,7 +88,9 @@ public class CometUdfBridge { * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}. * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext - * left on a worker by a previous task. + * left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache, + * so a UDF holding per-task state in fields sees a consistent instance for every call within + * the task regardless of which Tokio worker is polling. 
*/ public static void evaluate( String udfClassName, @@ -68,16 +100,33 @@ public static void evaluate( long outSchemaPtr, int numRows, TaskContext taskContext) { - // Save-and-restore rather than only-install-if-null: the propagated context is the ground - // truth for this call. Any value already on the thread is either (a) the same object on a - // Spark task thread, or (b) stale from a prior task on a reused Tokio worker. + assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty"; + assert inputArrayPtrs != null && inputSchemaPtrs != null + : "input pointer arrays must be non-null"; + assert inputArrayPtrs.length == inputSchemaPtrs.length + : "input array pointer count must equal schema pointer count"; + assert numRows >= 0 : "numRows must be non-negative"; + assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer"; + assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer"; + + // Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the + // ground truth for this call. Any value already on the thread is either (a) the same object + // on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker. TaskContext prior = TaskContext.get(); if (taskContext != null) { CometTaskContextShim.set(taskContext); + assert TaskContext.get() == taskContext + : "TaskContext install did not take effect on this thread"; } try { evaluateInternal( - udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows); + udfClassName, + inputArrayPtrs, + inputSchemaPtrs, + outArrayPtr, + outSchemaPtr, + numRows, + taskContext); } finally { if (taskContext != null) { if (prior != null) { @@ -95,9 +144,34 @@ private static void evaluateInternal( long[] inputSchemaPtrs, long outArrayPtr, long outSchemaPtr, - int numRows) { - CometUDF udf = + int numRows, + TaskContext taskContext) { + long taskAttemptId = (taskContext != null) ? 
taskContext.taskAttemptId() : NO_TASK_ID; + + ConcurrentHashMap perTask = INSTANCES.computeIfAbsent( + taskAttemptId, + id -> { + ConcurrentHashMap fresh = new ConcurrentHashMap<>(); + if (taskContext != null) { + // computeIfAbsent runs this lambda at most once per key, so the listener is + // registered exactly once per task attempt. + taskContext.addTaskCompletionListener( + (TaskCompletionListener) + ctx -> { + ConcurrentHashMap removed = INSTANCES.remove(id); + assert removed != null + : "task-completion listener fired but cache already removed " + + "entry for task " + + id; + }); + } + return fresh; + }); + assert perTask != null : "per-task cache must be non-null after computeIfAbsent"; + + CometUDF udf = + perTask.computeIfAbsent( udfClassName, name -> { try { @@ -113,6 +187,7 @@ private static void evaluateInternal( throw new RuntimeException("Failed to instantiate CometUDF: " + name, e); } }); + assert udf != null : "reflective instantiation returned null for " + udfClassName; BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 9b376837f7..feb4129ac5 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -380,6 +380,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.scalaUDF.codegen.enabled") + .category(CATEGORY_EXEC) + .doc( + "Whether to route Spark `ScalaUDF` expressions through Comet's Arrow-direct codegen " + + "dispatcher. When enabled, a supported ScalaUDF is compiled into a per-batch kernel " + + "that reads and writes Arrow vectors directly from native execution. 
When disabled, " + + "plans containing a ScalaUDF fall back to Spark for the enclosing operator.") + .booleanConf + .createWithDefault(true) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala b/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala new file mode 100644 index 0000000000..1696c466a3 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Throwing-default base for [[ArrayData]] in the Arrow-direct codegen kernel. Subclasses override + * only the getters their element type needs (e.g. `numElements`, `isNullAt`, `getUTF8String` for + * an `ArrayType(StringType)` input). 
+ * + * Consumer: `InputArray_${path}` nested classes the input emitter generates per `ArrayType` input + * column. They back `getArray(ord)` plus the recursion for `Array>` and array-typed + * map keys / struct fields. + * + * `ArrayData` and [[CometInternalRow]]'s [[InternalRow]] are sibling abstract classes in Spark + * (both extend `SpecializedGetters`, neither inherits the other), so a base aimed at one cannot + * serve the other. The dispatch body that '''is''' shared between them lives in + * [[CometSpecializedGettersDispatch]]. The third sibling, [[CometMapData]], backs `InputMap_*` + * and routes `keyArray()` / `valueArray()` through `CometArrayData` instances. + * + * Mixes in [[CometInternalRowShim]] for the same reason `CometInternalRow` does: Spark 4.x adds + * abstract `SpecializedGetters` methods (`getVariant`, `getGeography`, `getGeometry`) that both + * `InternalRow` and `ArrayData` inherit; the per-profile shim provides throwing defaults. + */ +abstract class CometArrayData extends ArrayData with CometInternalRowShim { + + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + + override def get(ordinal: Int, dataType: DataType): AnyRef = + CometSpecializedGettersDispatch.get(this, ordinal, dataType) + + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + + override def getByte(ordinal: Int): Byte = unsupported("getByte") + + override def getShort(ordinal: Int): Short = unsupported("getShort") + + override def getInt(ordinal: Int): Int = unsupported("getInt") + + override def getLong(ordinal: Int): Long = unsupported("getLong") + + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + + override def 
getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + + override def getMap(ordinal: Int): MapData = unsupported("getMap") + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this array shape") + + override def update(i: Int, value: Any): Unit = unsupported("update") + + override def copy(): ArrayData = unsupported("copy") + + override def array: Array[Any] = unsupported("array") + + override def toString(): String = { + val n = + try numElements().toString + catch { + case _: Throwable => "?" + } + s"${getClass.getSimpleName}(numElements=$n)" + } + + override def numElements(): Int = unsupported("numElements") +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala new file mode 100644 index 0000000000..bf5a9eaa4b --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala @@ -0,0 +1,535 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.shims.CometExprTraitShim + +/** + * Compiles a bound [[Expression]] plus an input schema into a [[CometBatchKernel]] that fuses + * Arrow input reads, expression evaluation, and Arrow output writes into one Janino-compiled + * method per (expression, schema) pair. + * + * The kernel is generic over Catalyst expressions; it does not know or assume that the bound tree + * came from a `ScalaUDF`. Today's only consumer is + * [[org.apache.comet.udf.codegen.CometScalaUDFCodegen]], but a future consumer (Spark + * `WholeStageCodegenExec` integration, a non-UDF batch evaluator) can drive this class directly. + * + * Constraints: single output vector per kernel (whole projections need a multi-output extension); + * per-row scalar evaluation only (aggregation, window, generator rejected by [[canHandle]]). + * + * Input- and output-side emission live in [[CometBatchKernelCodegenInput]] and + * [[CometBatchKernelCodegenOutput]]. 
This file owns the [[ArrowColumnSpec]] vocabulary, the + * [[canHandle]] / [[allocateOutput]] / [[compile]] / [[generateSource]] entry points, and + * cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant). + * + * The generated kernel '''is''' the `InternalRow` that Spark's `BoundReference.genCode` reads + * from: `ctx.INPUT_ROW = "row"` plus `InternalRow row = this;` inside `process` routes + * `row.getUTF8String(ord)` to the kernel's own typed getter (final method, constant ordinal; JIT + * devirtualizes and folds). `row` rather than `this` because Spark's `splitExpressions` passes + * INPUT_ROW as a helper-method parameter name and `this` is a reserved Java keyword. + */ +object CometBatchKernelCodegen extends Logging with CometExprTraitShim { + + /** + * Resolve an Arrow vector class by its simple name, using the same classloader the codegen uses + * internally. Intended for tests: the `common` module shades `org.apache.arrow` to + * `org.apache.comet.shaded.arrow`, so `classOf[VarCharVector]` at a call site in an unshaded + * module refers to a different [[Class]] object than the one the codegen compares against. + * Callers pass a simple name and get back the class the production code actually uses. 
+ */ + def vectorClassBySimpleName(name: String): Class[_ <: ValueVector] = name match { + case "BitVector" => classOf[BitVector] + case "TinyIntVector" => classOf[TinyIntVector] + case "SmallIntVector" => classOf[SmallIntVector] + case "IntVector" => classOf[IntVector] + case "BigIntVector" => classOf[BigIntVector] + case "Float4Vector" => classOf[Float4Vector] + case "Float8Vector" => classOf[Float8Vector] + case "DecimalVector" => classOf[DecimalVector] + case "DateDayVector" => classOf[DateDayVector] + case "TimeStampMicroVector" => classOf[TimeStampMicroVector] + case "TimeStampMicroTZVector" => classOf[TimeStampMicroTZVector] + case "VarCharVector" => classOf[VarCharVector] + case "VarBinaryVector" => classOf[VarBinaryVector] + case other => throw new IllegalArgumentException(s"unknown Arrow vector class: $other") + } + + /** + * Type surface the kernel covers, on both the input getter side and the output writer side. + * Recursive: `ArrayType` / `StructType` / `MapType` are supported when their children are. + * Input and output use a single predicate today; if they ever need to diverge, split this back + * into per-direction methods. + */ + def isSupportedDataType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case _: StringType | _: BinaryType => true + case DateType | TimestampType | TimestampNTZType => true + case ArrayType(inner, _) => isSupportedDataType(inner) + case st: StructType => st.fields.forall(f => isSupportedDataType(f.dataType)) + case mt: MapType => isSupportedDataType(mt.keyType) && isSupportedDataType(mt.valueType) + case _ => false + } + + /** + * Plan-time predicate: can the codegen dispatcher handle this bound expression end to end? + * `None` greenlights the serde to emit the codegen proto; `Some(reason)` forces a Spark + * fallback (typically `withInfo(...) 
+ None`) rather than crashing the Janino compile at + * execute time. + * + * Checks every `BoundReference`'s data type and the root `expr.dataType` against + * [[isSupportedDataType]], and rejects aggregates / generators. Intermediate nodes are not + * checked: only leaves (row reads) and the root (output write) touch Arrow. + */ + def canHandle(boundExpr: Expression): Option[String] = { + if (!isSupportedDataType(boundExpr.dataType)) { + return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}") + } + // Reject expressions that can't be safely compiled or cached: + // - AggregateFunction / Generator: non-scalar bridge shape. + // - CodegenFallback: opts out of `doGenCode`, which our compile path assumes works. + // Passing one in would emit interpreted-eval glue that our kernel can't splice cleanly. + // - Unevaluable: unresolved plan markers. Shouldn't reach a serde, but cheap to guard. + // `isCodegenInertUnevaluable` lets the shim exclude version-specific leaves that are + // `Unevaluable` but never touched by codegen (e.g. Spark 4.0's `ResolvedCollation`, which + // lives in `Collate.collation` as a type marker; `Collate.genCode` delegates to its child). + // + // TODO(hof-lambdas): the `CodegenFallback` rule rejects `NamedLambdaVariable`, which flags + // every higher-order function (`ArrayTransform`, `ArrayAggregate`, `ArrayExists`, + // `ArrayFilter`, `ZipWith`, `MapFilter`, etc.) as unsupported. The variable is + // `CodegenFallback` only in isolation; the surrounding HOF binds its `value` field inline + // as part of its own + // `doGenCode`, and the resulting Java compiles fine. Loosening this would unlock + // element-iteration over `Array` / `Array` which today have no fuzz path + // (`array_max` doesn't apply to non-comparable elements, generators are blocked above). 
Plan: + // allow `NamedLambdaVariable` / `LambdaFunction` in the rejection scan; verify the kernel + // splices the HOF's emitted loop without ctx.references collisions on the lambda holder. + // + // Nondeterministic / stateful expressions are accepted: per-partition kernel allocation + // (`CometScalaUDFCodegen.ensureKernel`) plus a single `init(partitionIndex)` call at + // partition entry give `Rand` / `MonotonicallyIncreasingID` / etc. correct state across + // batches and a clean reset across partitions. + // + // `ExecSubqueryExpression` (`ScalarSubquery`, `InSubqueryExec`) is accepted via a chain: + // the surrounding Comet operator's inherited `SparkPlan.waitForSubqueries` populates the + // subquery's mutable `result` field before evaluation; the closure serializer captures that + // populated value into the arg-0 bytes; the dispatcher keys its compile cache on those + // exact bytes, so distinct subquery results produce distinct cache entries with no + // cross-query staleness. Refactors to the cache-key derivation, the transport, or any + // Comet operator that bypasses `waitForSubqueries` would break this; preserve it. + boundExpr.find { + case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true + case _: org.apache.spark.sql.catalyst.expressions.Generator => true + case _: CodegenFallback => true + case u: Unevaluable if isCodegenInertUnevaluable(u) => false + case _: Unevaluable => true + case _ => false + } match { + case Some(bad) => + return Some( + s"codegen dispatch: expression ${bad.getClass.getSimpleName} not supported " + + "(aggregate, generator, codegen-fallback, or unevaluable)") + case None => + } + val badRef = boundExpr.collectFirst { + case b: BoundReference if !isSupportedDataType(b.dataType) => + b + } + badRef.map(b => + s"codegen dispatch: unsupported input type ${b.dataType} at ordinal ${b.ordinal}") + } + + /** + * Allocate an Arrow output vector matching the expression's `dataType`. 
Thin forwarder to + * [[CometBatchKernelCodegenOutput.allocateOutput]]. Kept on this object as part of the public + * API so external callers (`CometScalaUDFCodegen`) do not have to know about the internal + * split. + */ + def allocateOutput( + dataType: DataType, + name: String, + numRows: Int, + estimatedBytes: Int = -1): FieldVector = + CometBatchKernelCodegenOutput.allocateOutput(dataType, name, numRows, estimatedBytes) + + /** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. */ + def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector = + CometBatchKernelCodegenOutput.allocateOutput(field, numRows, estimatedBytes) + + def compile(boundExpr: Expression, inputSchema: Seq[ArrowColumnSpec]): CompiledKernel = { + val src = generateSource(boundExpr, inputSchema) + val (clazz, _) = + try { + CodeGenerator.compile(src.code) + } catch { + case t: Throwable => + logError( + s"CometBatchKernelCodegen: compile failed for ${boundExpr.getClass.getSimpleName}. " + + s"Generated source follows:\n${src.body}", + t) + throw t + } + logInfo( + s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " + + s"-> ${boundExpr.dataType} inputs=" + + inputSchema + .map(s => s"${s.vectorClass.getSimpleName}${if (s.nullable) "?" else ""}") + .mkString(",")) + // Freshen references per kernel allocation. See the `CompiledKernel` scaladoc for why. + // `generateSource` is pure with respect to its inputs (no hidden state) and produces a + // layout-compatible references array each time because the expression and schema are + // fixed. + val freshReferences: () => Array[Any] = () => + generateSource(boundExpr, inputSchema).references + CompiledKernel(clazz, freshReferences) + } + + /** + * Generate the Java source for a kernel without compiling it. Factored out of [[compile]] so + * tests can assert on the emitted source (null short-circuit present, non-nullable `isNullAt` + * returns literal `false`, etc.) 
without paying for Janino. + */ + def generateSource( + boundExpr: Expression, + inputSchema: Seq[ArrowColumnSpec]): GeneratedSource = { + val ctx = new CodegenContext + // `BoundReference.genCode` emits `${ctx.INPUT_ROW}.getUTF8String(ord)`. We alias a local + // `row` to `this` at the top of `process` so those reads resolve to the kernel's own typed + // getters (virtual dispatch on a concrete final class, JIT devirtualizes + folds the + // switch). `row` rather than `this` because Spark's `splitExpressions` uses INPUT_ROW as the + // parameter name of any helper method it emits; `this` is a reserved keyword, so using it + // as a parameter name produces `private UTF8String helper(InternalRow this)` which Janino + // rejects. + ctx.INPUT_ROW = "row" + + val baseClass = classOf[CometBatchKernel].getName + // Resolve shaded Arrow class names at compile time so generated source + // matches the abstract method signature after Maven relocation. + val valueVectorClass = classOf[ValueVector].getName + val fieldVectorClass = classOf[FieldVector].getName + + // Build the per-row body via Spark's doGenCode. + // + // `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex + // output) that `emitOutputWriter` factors out of the per-row body so they do not repeat on + // every row. Scalar outputs return an empty string here. + // + // TODO(method-size): perRowBody is inlined inside process's for-loop and not split. + // Sufficiently deep trees can exceed Janino's 64KB method size; wrap in + // ctx.splitExpressionsWithCurrentInputs when hit. + val (concreteOutClass, outputSetup, perRowBody) = { + // Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the + // hood, which populates `ctx.subexprFunctions` with per-row helper calls that write + // common subexpression results into `addMutableState`-allocated fields; the returned + // `ExprCode` then references those fields. 
`subexprFunctionsCode` is the concatenated + // helper invocation block, spliced into the per-row body by `defaultBody` (inside the + // NullIntolerant else-branch when that short-circuit fires, otherwise before + // `ev.code`). See the "Subexpression elimination" section of the object-level + // Scaladoc for why we use this variant rather than the WSCG one. + val ev = if (SQLConf.get.subexpressionEliminationEnabled) { + ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head + } else { + boundExpr.genCode(ctx) + } + val subExprsCode = ctx.subexprFunctionsCode + val (cls, setup, snippet) = + CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx) + (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode)) + } + + val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema) + val typedInputCasts = CometBatchKernelCodegenInput.emitInputCasts(inputSchema) + val decimalTypeByOrdinal = CometBatchKernelCodegenInput.decimalPrecisionByOrdinal(boundExpr) + val getters = + CometBatchKernelCodegenInput.emitTypedGetters(inputSchema, decimalTypeByOrdinal) + val nested = CometBatchKernelCodegenInput.emitNestedClasses(inputSchema) + val getArrayMethod = CometBatchKernelCodegenInput.emitGetArrayMethod(inputSchema) + val getStructMethod = CometBatchKernelCodegenInput.emitGetStructMethod(inputSchema) + val getMapMethod = CometBatchKernelCodegenInput.emitGetMapMethod(inputSchema) + + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificCometBatchKernel(references); + |} + | + |final class SpecificCometBatchKernel extends $baseClass { + | + | ${ctx.declareMutableStates()} + | + | $typedFieldDecls + | private int rowIdx; + | + | public SpecificCometBatchKernel(Object[] references) { + | super(references); + | ${ctx.initMutableStates()} + | } + | + | @Override + | public void init(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | $getters + | 
$getArrayMethod + | $getStructMethod + | $getMapMethod + | + | @Override + | public void process( + | $valueVectorClass[] inputs, + | $fieldVectorClass outRaw, + | int numRows) { + | $concreteOutClass output = ($concreteOutClass) outRaw; + | $typedInputCasts + | $outputSetup + | // Alias the kernel as `row` so Spark-generated `${ctx.INPUT_ROW}.method()` reads + | // resolve to the kernel's own typed getters. Helper methods that Spark splits off + | // via `splitExpressions` also take `InternalRow row` as a parameter; we pass `this` + | // implicitly since callers substitute INPUT_ROW which we've set to `row`. + | org.apache.spark.sql.catalyst.InternalRow row = this; + | for (int i = 0; i < numRows; i++) { + | this.rowIdx = i; + | $perRowBody + | } + | } + | + | ${ctx.declareAddedFunctions()} + | + |$nested + |} + """.stripMargin + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + GeneratedSource(code.body, code, ctx.references.toArray) + } + + /** + * Per-row body for the default path. For `NullIntolerant` expressions (null in any input -> + * null output), prepends a short-circuit that skips expression evaluation entirely when any + * input column is null this row, saving the full `ev.code` cost. Otherwise the standard shape: + * run `ev.code`, then `setNull` or write based on `ev.isNull`. + * + * `subExprsCode` is the CSE helper-invocation block; it writes common subexpression results + * into class fields that `ev.code` reads, so it must run before `ev.code`. Inside the + * short-circuit it lives in the else branch, skipping CSE for null rows. Empty when CSE is + * disabled or the tree has none. + */ + private def defaultBody( + boundExpr: Expression, + ev: ExprCode, + writeSnippet: String, + subExprsCode: String): String = { + boundExpr match { + case _ if isNullIntolerant(boundExpr) && allNullIntolerant(boundExpr) => + // Every node from root to leaf is either NullIntolerant or a leaf. 
That transitively + // guarantees "any BoundReference null at this row -> whole expression null", so we can + // short-circuit on the union of input ordinals. Breaking the chain with a non-null- + // propagating node like `coalesce` or `if` produces the wrong result (coalesce(null,x) + // is x, not null), so the check above rejects those shapes and falls through to the + // default branch which runs Spark's own null-aware ev.code. + val inputOrdinals = + boundExpr.collect { case b: BoundReference => b.ordinal }.distinct + val nullCheck = + if (inputOrdinals.isEmpty) "false" + else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ") + s""" + |if ($nullCheck) { + | output.setNull(i); + |} else { + | $subExprsCode + | ${ev.code} + | $writeSnippet + |} + """.stripMargin + case _ => + // Optimization: NonNullableOutputShortCircuit. + // When the bound expression declares `nullable = false`, the `if (ev.isNull)` branch is + // dead and HotSpot may or may not fold it (it depends on whether the expression's + // `doGenCode` made `ev.isNull` a `FalseLiteral` or a variable whose value is + // false-at-runtime but not a compile-time constant from Spark's side). Drop the guard + // at source level so we don't depend on JIT folding and keep the generated body + // minimal. + if (!boundExpr.nullable) { + s""" + |$subExprsCode + |${ev.code} + |$writeSnippet + """.stripMargin + } else { + s""" + |$subExprsCode + |${ev.code} + |if (${ev.isNull}) { + | output.setNull(i); + |} else { + | $writeSnippet + |} + """.stripMargin + } + } + } + + /** + * True iff every node in the expression tree is either `NullIntolerant` or a leaf we can safely + * consider null-propagating (`BoundReference` and `Literal`). 
Used to gate the `NullIntolerant`
+   * short-circuit in [[defaultBody]]: the short-circuit collects `BoundReference` ordinals from
+   * the whole tree and skips `ev.code` when any of them is null, which is only correct when every
+   * path from a leaf to the root propagates nulls. A non-propagating node (`Coalesce`, `If`,
+   * `CaseWhen`, `Concat`, etc.) anywhere in the tree invalidates this assumption: `coalesce(null,
+   * x)` is `x`, not null, so pre-nulling on any input null would produce the wrong result.
+   */
+  private def allNullIntolerant(expr: Expression): Boolean =
+    !expr.exists {
+      case _: BoundReference | _: Literal => false
+      case other => !isNullIntolerant(other)
+    }
+
+  /**
+   * Per-column compile-time invariants. The concrete Arrow vector class and whether the column is
+   * nullable are baked into the generated kernel's typed fields and branches. Part of the cache
+   * key: different vector classes or nullability produce different kernels.
+   *
+   * Sealed hierarchy so that complex types (array/map/struct) can carry their nested element
+   * shape recursively. Today scalar, array, struct, and map specs exist as subclasses. A
+   * companion `apply` / `unapply` preserves
+   * the original scalar-only construction and extractor shape so existing callers don't need to
+   * change.
+   */
+  sealed trait ArrowColumnSpec {
+    def vectorClass: Class[_ <: ValueVector]
+
+    def nullable: Boolean
+  }
+
+  /** Scalar column: one Arrow vector class per row slot, no nested structure. */
+  final case class ScalarColumnSpec(vectorClass: Class[_ <: ValueVector], nullable: Boolean)
+    extends ArrowColumnSpec
+
+  /**
+   * Array column: an Arrow `ListVector` wrapping a child spec. `elementSparkType` is the Spark
+   * `DataType` of the element so the nested-class getter emitter can choose the right template
+   * (e.g. `getUTF8String` for `StringType`, `getInt` for `IntegerType`).
The child spec carries
+   * the Arrow child vector class. Nested arrays (`Array<Array<Int>>`) work by the child being
+   * itself an `ArrayColumnSpec`.
+   */
+  final case class ArrayColumnSpec(
+      nullable: Boolean,
+      elementSparkType: DataType,
+      element: ArrowColumnSpec)
+    extends ArrowColumnSpec {
+    override def vectorClass: Class[_ <: ValueVector] = classOf[ListVector]
+  }
+
+  /**
+   * Struct column: an Arrow `StructVector` wrapping N typed child specs. Each entry carries the
+   * Spark field name (for schema identification in the cache key), the Spark `DataType` of the
+   * field (so per-field emitters pick the right read/write template), the child `ArrowColumnSpec`
+   * (so nested shapes like `Struct<Array<Int>>` compose by trait-level recursion), and the
+   * field's `nullable` bit (so non-nullable fields elide their per-row null check at source
+   * level). Nested structs (`Struct<Struct<Int>>`) work by the child being itself a
+   * `StructColumnSpec`.
+   */
+  final case class StructColumnSpec(nullable: Boolean, fields: Seq[StructFieldSpec])
+    extends ArrowColumnSpec {
+    override def vectorClass: Class[_ <: ValueVector] = classOf[StructVector]
+  }
+
+  /** One field entry on a [[StructColumnSpec]]. */
+  final case class StructFieldSpec(
+      name: String,
+      sparkType: DataType,
+      nullable: Boolean,
+      child: ArrowColumnSpec)
+
+  /**
+   * Map column: an Arrow `MapVector` (subclass of `ListVector`) whose data vector is a
+   * `StructVector` with a key field at ordinal 0 and a value field at ordinal 1. `key` and
+   * `value` are themselves `ArrowColumnSpec` so nested shapes (`Map<Array<Int>, Array<Int>>`,
+   * `Map<Struct<...>, ...>`) compose by trait-level recursion. Nullable map entries are controlled
+   * per-column by the outer map's validity; nullable keys and values are carried in the child
+   * specs' `nullable` bit.
+ */ + final case class MapColumnSpec( + nullable: Boolean, + keySparkType: DataType, + valueSparkType: DataType, + key: ArrowColumnSpec, + value: ArrowColumnSpec) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[MapVector] + } + + /** + * Result of compiling a bound [[Expression]] into a Janino kernel. The Spark-generated + * `factory` is stateless and safe to share across partitions; `freshReferences` regenerates the + * references array per kernel allocation. + * + * The references array can't be cached because some expressions (notably [[ScalaUDF]]) embed + * stateful `ExpressionEncoder` serializers via `ctx.addReferenceObj` that reuse an internal + * `UnsafeRow` / `byte[]` per `.apply(...)`. Sharing one serializer across partition kernels + * would race on that buffer. Re-running `genCode` is microseconds; Janino compile is + * milliseconds. Cache the expensive piece, refresh the cheap one. + * + * Mirrors Spark `WholeStageCodegenExec`: compile once per plan, instantiate per partition, + * `init(partitionIndex)`, iterate. + */ + final case class CompiledKernel(factory: GeneratedClass, freshReferences: () => Array[Any]) { + def newInstance(): CometBatchKernel = + factory.generate(freshReferences()).asInstanceOf[CometBatchKernel] + } + + /** + * Output of [[generateSource]]. `body` is the raw Java source Janino will compile; `code` is + * the post-`stripOverlappingComments` wrapper Janino actually takes as input; `references` are + * the runtime objects the generated constructor pulls from via `ctx.addReferenceObj` (cached + * patterns, replacement strings, etc.). Tests inspect `body` to assert the shape of the + * generated source. See `CometCodegenSourceSuite` for examples. + */ + final case class GeneratedSource(body: String, code: CodeAndComment, references: Array[Any]) + + object ArrowColumnSpec { + + /** Convenience constructor producing a [[ScalarColumnSpec]]. 
*/ + def apply(vectorClass: Class[_ <: ValueVector], nullable: Boolean): ArrowColumnSpec = + ScalarColumnSpec(vectorClass, nullable) + + /** + * Trait-level extractor that destructures only the scalar case. Pattern-match callers use + * `case ArrowColumnSpec(cls, nullable)` to filter on scalar specs and pull out their vector + * class and nullability in one step; complex specs return `None` and skip the case. + */ + def unapply(spec: ArrowColumnSpec): Option[(Class[_ <: ValueVector], Boolean)] = spec match { + case ScalarColumnSpec(c, n) => Some((c, n)) + case _ => None + } + } +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala new file mode 100644 index 0000000000..bebc2949e5 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -0,0 +1,1014 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.codegen + +import scala.collection.mutable + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression} +import org.apache.spark.sql.types._ + +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec} +import org.apache.comet.vector.CometPlainVector + +/** + * Input-side emitters for the Arrow-direct codegen kernel: kernel field declarations, per-batch + * input casts, top-level typed-getter switches, nested `InputArray_${path}` / + * `InputStruct_${path}` / `InputMap_${path}` classes per complex level, and the input-side + * type-support gate. Paired with [[CometBatchKernelCodegenOutput]] on the write side. + * + * Path encoding. Each position in the spec tree has a unique path string used as a suffix on + * vector fields and nested classes. From a column ordinal: root `col${ord}`, array element + * `${P}_e`, struct field `fi` `${P}_f${fi}`, map key `${P}_k`, map value `${P}_v`. + * + * Nested-class composition. A class at path `P` is a Spark `ArrayData` / `InternalRow` / + * `MapData` view of its Arrow vector. Each instance is allocated fresh per `getArray(i)` / + * `getStruct(i, n)` / `getMap(i)` call (constructor takes the slice and stores it in `final` + * fields), matching Spark's `ColumnarRow` / `ColumnarArray` model. JIT escape analysis usually + * scalarizes the allocation when the value is consumed locally; the consequence is that + * retain-by-reference consumers (e.g. `ArrayDistinct.nullSafeEval` stashing references in an + * `OpenHashSet`) get distinct identities and lazy reads work correctly. + */ +private[codegen] object CometBatchKernelCodegenInput { + + /** + * Primitive Arrow vector classes wrapped in [[CometPlainVector]] at input-cast time. 
+ * `CometPlainVector.get*` reads use `Platform.get*` against a cached buffer address; JIT + * inlines to branchless reads. `getBoolean` also caches the data byte for bit-packed reads. + * + * Not wrapped: `DecimalVector` (kernel inlines its precision-keyed fast/slow split), + * `VarCharVector` / `VarBinaryVector` (kernel emits inline unsafe reads to skip the redundant + * `isNullAt` inside `getUTF8String` / `getBinary`). + */ + private val primitiveArrowClasses: Set[Class[_]] = Set( + classOf[BitVector], + classOf[TinyIntVector], + classOf[SmallIntVector], + classOf[IntVector], + classOf[BigIntVector], + classOf[Float4Vector], + classOf[Float8Vector], + classOf[DateDayVector], + classOf[TimeStampMicroVector], + classOf[TimeStampMicroTZVector]) + private val cometPlainVectorName: String = classOf[CometPlainVector].getName + + /** + * Emit the kernel's typed vector-field declarations for every level of every input column's + * spec tree. + */ + def emitInputFieldDecls(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectVectorFieldDecls(path, spec, lines) + } + lines.mkString("\n ") + } + + /** + * Emit the per-batch cast statements. For a map column, casts the outer `MapVector`, then casts + * the inner `StructVector` (via a local variable) to extract key and value children via + * `getChildByOrdinal(0)` / `(1)`. For arrays, casts the outer `ListVector` and recurses via + * `getDataVector()`. For structs, casts the outer `StructVector` and recurses via + * `getChildByOrdinal(fi)`. + */ + def emitInputCasts(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectCasts(path, spec, s"inputs[$ord]", lines) + } + lines.mkString("\n ") + } + + /** + * Emit the kernel's typed-getter overrides. 
Each switches on column ordinal; with the inlined + * constant ordinal from `BoundReference.genCode`, JIT folds the switch to one branch and + * devirtualizes thanks to the final class. + * + * `decimalTypeByOrdinal` lets the decimal getter specialize per ordinal: when only a + * `DecimalType(precision <= 18)` `BoundReference` reads that ordinal, the emitted case skips + * the `BigDecimal` allocation and reads the unscaled long directly. + * + * TODO(unsafe-readers): primitive `v.get(i)` performs a bounds check that is redundant given `i + * in [0, numRows)`. + */ + def emitTypedGetters( + inputSchema: Seq[ArrowColumnSpec], + decimalTypeByOrdinal: Map[Int, Option[DecimalType]]): String = { + val withOrd = inputSchema.zipWithIndex + + val isNullCases = withOrd.map { case (spec, ord) => + if (!spec.nullable) { + s" case $ord: return false;" + } else { + // CometPlainVector exposes `isNullAt`; Arrow-typed fields expose `isNull`. Both check + // the validity bitmap with the same semantics. + val method = spec.vectorClass match { + case cls if wrapsInCometPlainVector(cls) => "isNullAt" + case _ => "isNull" + } + s" case $ord: return this.col$ord.$method(this.rowIdx);" + } + } + + val booleanCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[BitVector] => + s" case $ord: return this.col$ord.getBoolean(this.rowIdx);" + } + val byteCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[TinyIntVector] => + s" case $ord: return this.col$ord.getByte(this.rowIdx);" + } + val shortCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[SmallIntVector] => + s" case $ord: return this.col$ord.getShort(this.rowIdx);" + } + val intCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if cls == classOf[IntVector] || cls == classOf[DateDayVector] => + s" case $ord: return this.col$ord.getInt(this.rowIdx);" + } + val longCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if 
cls == classOf[BigIntVector] || + cls == classOf[TimeStampMicroVector] || + cls == classOf[TimeStampMicroTZVector] => + s" case $ord: return this.col$ord.getLong(this.rowIdx);" + } + val floatCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float4Vector] => + s" case $ord: return this.col$ord.getFloat(this.rowIdx);" + } + val doubleCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float8Vector] => + s" case $ord: return this.col$ord.getDouble(this.rowIdx);" + } + val decimalCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[DecimalVector] => + val known = decimalTypeByOrdinal.getOrElse(ord, None) + val valueAddr = s"this.col${ord}_valueAddr" + val slowField = s"this.col$ord" + val fastPath = emitDecimalFastBodyUnsafe(valueAddr, "this.rowIdx", " ") + val slowPath = emitDecimalSlowBody(slowField, "this.rowIdx", " ") + val body = known match { + case Some(dt) if dt.precision <= Decimal.MAX_LONG_DIGITS => fastPath + case Some(_) => slowPath + case None => + s""" if (precision <= ${Decimal.MAX_LONG_DIGITS}) { + |$fastPath + | } else { + |$slowPath + | }""".stripMargin + } + s""" case $ord: { + |$body + | }""".stripMargin + } + val binaryCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarBinaryVector] => + s""" case $ord: { + |${emitBinaryBodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + val utf8Cases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarCharVector] => + s""" case $ord: { + |${emitUtf8BodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + + Seq( + emitOrdinalSwitch("public boolean isNullAt(int ordinal)", "isNullAt", isNullCases), + emitOrdinalSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + emitOrdinalSwitch("public byte 
getByte(int ordinal)", "getByte", byteCases), + emitOrdinalSwitch("public short getShort(int ordinal)", "getShort", shortCases), + emitOrdinalSwitch("public int getInt(int ordinal)", "getInt", intCases), + emitOrdinalSwitch("public long getLong(int ordinal)", "getLong", longCases), + emitOrdinalSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + emitOrdinalSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + emitOrdinalSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), + emitOrdinalSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + emitOrdinalSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + private def wrapsInCometPlainVector(cls: Class[_]): Boolean = + primitiveArrowClasses.contains(cls) + + private def emitOrdinalSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } + + private def emitDecimalSlowBody(field: String, idx: String, ind: String): String = { + val cont = ind + " " + s"""${ind}java.math.BigDecimal bd = $field.getObject($idx); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.apply(bd, precision, scale);""".stripMargin + } + + private def emitDecimalFastBodyUnsafe(valueAddr: String, idx: String, ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}long unscaled = org.apache.spark.unsafe.Platform.getLong(null, + |$cont$valueAddr + (long) $i * 16L); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.createUnsafe(unscaled, precision, scale);""".stripMargin + } + + 
private def emitUtf8BodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}return org.apache.spark.unsafe.types.UTF8String + |$cont.fromAddress(null, $valueAddr + s, e - s);""".stripMargin + } + + /** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */ + private def castableIdx(idx: String): String = if (idx.contains(' ')) s"($idx)" else idx + + private def emitBinaryBodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}int len = e - s; + |${ind}byte[] out = new byte[len]; + |${ind}org.apache.spark.unsafe.Platform.copyMemory(null, $valueAddr + s, out, + |${cont}org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET, len); + |${ind}return out;""".stripMargin + } + + /** + * Build a per-ordinal map of the `DecimalType` observed on `BoundReference`s in the bound + * expression. Used by [[emitTypedGetters]] to emit a compile-time-specialized `getDecimal` case + * per ordinal. 
+ */ + def decimalPrecisionByOrdinal(boundExpr: Expression): Map[Int, Option[DecimalType]] = { + boundExpr + .collect { + case b: BoundReference if b.dataType.isInstanceOf[DecimalType] => + b.ordinal -> b.dataType.asInstanceOf[DecimalType] + } + .groupBy(_._1) + .map { case (ord, pairs) => + val distinct = pairs.map(_._2).toSet + ord -> (if (distinct.size == 1) Some(distinct.head) else None) + } + } + + /** + * Emit every nested class needed for every complex level of every input column. For an + * `ArrayColumnSpec` we emit `InputArray_${path}`; for a `StructColumnSpec` + * `InputStruct_${path}`; for a `MapColumnSpec` `InputMap_${path}` plus the `InputArray` classes + * for the key and value slices (because Spark's `MapData.keyArray()` / `valueArray()` return + * `ArrayData` - same view shape as any other array). + */ + def emitNestedClasses(inputSchema: Seq[ArrowColumnSpec]): String = { + val out = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + collectNestedClasses(s"col$ord", spec, out) + } + out.mkString("\n") + } + + /** + * Emit the kernel's `@Override public ArrayData getArray(int ordinal)` method. Each case reads + * `(startIdx, length)` from the outer `ListVector`'s offsets and allocates a fresh + * `InputArray_col${ord}` view over that slice. 
+ */ + def emitGetArrayMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: ArrayColumnSpec, ord) => + s""" case $ord: { + | int __idx = this.rowIdx; + | int __s = this.col$ord.getElementStartIndex(__idx); + | int __e = this.col$ord.getElementEndIndex(__idx); + | return new InputArray_col$ord(__s, __e - __s); + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getArray out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Emit the kernel's top-level `@Override public MapData getMap(int ordinal)` method when the + * input schema has at least one map-typed column at the top level. + */ + def emitGetMapMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: MapColumnSpec, ord) => + s""" case $ord: { + | int __idx = this.rowIdx; + | int __s = this.col$ord.getElementStartIndex(__idx); + | int __e = this.col$ord.getElementEndIndex(__idx); + | return new InputMap_col$ord(__s, __e - __s); + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getMap out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Emit the kernel's top-level `@Override public InternalRow getStruct(int ordinal, int + * numFields)` method when the input schema has at least one struct-typed column. 
+ */ + def emitGetStructMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: StructColumnSpec, ord) => + s""" case $ord: return new InputStruct_col$ord(this.rowIdx);""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getStruct out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Non-wrapped scalar columns that want a cached data-buffer address for inline unsafe reads. + * `DecimalVector` uses it for the short-precision fast path (`Platform.getLong`); + * `VarCharVector` / `VarBinaryVector` use it as the base address for `UTF8String.fromAddress` / + * `Platform.copyMemory`. See the unsafe-emitter block at the bottom of this file for why we + * inline rather than reuse `CometPlainVector`. + */ + private def needsValueAddrField(cls: Class[_]): Boolean = + cls == classOf[DecimalVector] || + cls == classOf[VarCharVector] || + cls == classOf[VarBinaryVector] + + /** Variable-width columns also want the offset-buffer address cached for `Platform.getInt`. */ + private def needsOffsetAddrField(cls: Class[_]): Boolean = + cls == classOf[VarCharVector] || cls == classOf[VarBinaryVector] + + /** + * Java method name for the null check on a column's typed field. Primitive scalars wrapped in + * [[CometPlainVector]] expose `isNullAt`; Arrow typed fields (complex containers, + * `DecimalVector`, `VarCharVector`, `VarBinaryVector`) expose `isNull`. Both read the validity + * bitmap. 
+ */ + private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match { + case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt" + case _ => "isNull" + } + + private def collectVectorFieldDecls( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + // Primitive scalars at any nesting depth wrap in CometPlainVector for JIT-inlined + // Platform.get* against a cached buffer address. DecimalVector / VarCharVector / + // VarBinaryVector stay on the Arrow typed field with cached data- (and offset-) buffer + // addresses for inline unsafe reads. + val fieldClass = + if (wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName + else sc.vectorClass.getName + out += s"private $fieldClass $path;" + if (needsValueAddrField(sc.vectorClass)) { + out += s"private long ${path}_valueAddr;" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"private long ${path}_offsetAddr;" + } + case ar: ArrayColumnSpec => + out += s"private ${classOf[ListVector].getName} $path;" + collectVectorFieldDecls(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += s"private ${classOf[StructVector].getName} $path;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectVectorFieldDecls(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += s"private ${classOf[MapVector].getName} $path;" + // Key and value vectors live at `${P}_k_e` / `${P}_v_e` so the `InputArray_${P}_k` / + // `InputArray_${P}_v` synthetic classes (which follow the array-element convention of + // reading from `${path}_e`) resolve their element reads correctly. 
+ collectVectorFieldDecls(s"${path}_k_e", mp.key, out) + collectVectorFieldDecls(s"${path}_v_e", mp.value, out) + } + + private def collectCasts( + path: String, + spec: ArrowColumnSpec, + source: String, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + if (wrapsInCometPlainVector(sc.vectorClass)) { + // Wrap in CometPlainVector so per-row reads go through Platform.get* against a final + // long buffer address. JIT inlines the one-liner getters, treating the address as a + // register-cached constant across the process loop. useDecimal128 = true matches + // Spark's 128-bit decimal storage. + out += s"this.$path = new $cometPlainVectorName($source, true);" + } else { + out += s"this.$path = (${sc.vectorClass.getName}) $source;" + } + if (needsValueAddrField(sc.vectorClass)) { + out += s"this.${path}_valueAddr = this.$path.getDataBuffer().memoryAddress();" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"this.${path}_offsetAddr = this.$path.getOffsetBuffer().memoryAddress();" + } + case ar: ArrayColumnSpec => + out += s"this.$path = (${classOf[ListVector].getName}) $source;" + collectCasts(s"${path}_e", ar.element, s"this.$path.getDataVector()", out) + case st: StructColumnSpec => + out += s"this.$path = (${classOf[StructVector].getName}) $source;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectCasts(s"${path}_f$fi", f.child, s"this.$path.getChildByOrdinal($fi)", out) + } + case mp: MapColumnSpec => + // MapVector's data vector is a StructVector with key at child 0 and value at child 1. + // Grab the struct through a local var and pull out the typed children. The key / value + // vectors live at the `_k_e` / `_v_e` paths so the synthetic `InputArray_${P}_k` / + // `InputArray_${P}_v` classes read them via the standard array-element convention. 
+ val structLocal = s"${path}__mapStruct" + out += s"this.$path = (${classOf[MapVector].getName}) $source;" + out += s"${classOf[StructVector].getName} $structLocal = " + + s"(${classOf[StructVector].getName}) this.$path.getDataVector();" + collectCasts(s"${path}_k_e", mp.key, s"$structLocal.getChildByOrdinal(0)", out) + collectCasts(s"${path}_v_e", mp.value, s"$structLocal.getChildByOrdinal(1)", out) + } + + private def collectNestedClasses( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case _: ScalarColumnSpec => () + case ar: ArrayColumnSpec => + out += emitArrayClass(path, ar) + collectNestedClasses(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += emitStructClass(path, st) + st.fields.zipWithIndex.foreach { case (f, fi) => + collectNestedClasses(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += emitMapClass(path) + // Emit InputArray_${path}_k and InputArray_${path}_v: the ArrayData views returned by + // `keyArray()` / `valueArray()`. Each reads from `${classPath}_e` per the array-element + // convention, which maps to the key / value vector at `${path}_k_e` / `${path}_v_e`. + out += emitArrayClass( + s"${path}_k", + ArrayColumnSpec(nullable = true, elementSparkType = mp.keySparkType, element = mp.key)) + out += emitArrayClass( + s"${path}_v", + ArrayColumnSpec( + nullable = true, + elementSparkType = mp.valueSparkType, + element = mp.value)) + // Recurse into the key / value specs at their canonical paths (${path}_k_e / + // ${path}_v_e) so nested complex keys / values get their own nested classes. + collectNestedClasses(s"${path}_k_e", mp.key, out) + collectNestedClasses(s"${path}_v_e", mp.value, out) + } + + /** + * Emit one `InputArray_${path}` nested class. Constructor takes the slice `(startIdx, length)` + * and stores both in `final` fields. Map key / value arrays share this shape over `${path}_k` / + * `${path}_v`. 
+ */ + private def emitArrayClass(path: String, spec: ArrayColumnSpec): String = { + val baseClassName = classOf[CometArrayData].getName + val elemPath = s"${path}_e" + val isNullAt = + s""" @Override + | public boolean isNullAt(int i) { + | return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i); + | }""".stripMargin + val elementGetter = emitArrayElementGetter(path, spec) + s""" private final class InputArray_$path extends $baseClassName { + | private final int startIndex; + | private final int length; + | + | InputArray_$path(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + |$isNullAt + | + |$elementGetter + | } + |""".stripMargin + } + + /** + * Emit the element getter body for a nested `InputArray_${path}`. Scalar element -> direct + * typed read. Complex element -> `getArray(i)` / `getStruct(i, n)` / `getMap(i)` allocates a + * fresh inner view over the appropriate slice. + * + * Reference-typed element getters (`getDecimal` / `getUTF8String` / `getBinary` / `getStruct` / + * `getArray` / `getMap`) prepend `if (isNullAt(i)) return null;` when the element is nullable. + * Reason: Spark's `CodeGenerator.setArrayElement` only emits a caller-side `isNullAt` check + * before `update(i, getX(j))` when `elementType` is a Java primitive; for reference types it + * relies on the source's getter to return `null` itself (Spark's own `ColumnarArray.getBinary` + * does the same). Without this guard, expressions like `Flatten.doGenCode` write our non-null + * shells / empty bytes / garbage decimals where Spark expects null, producing silently-wrong + * values or NPEs downstream. 
+ */ + private def emitArrayElementGetter(path: String, spec: ArrayColumnSpec): String = { + val elemPath = s"${path}_e" + val nullGuard = + if (spec.element.nullable) " if (isNullAt(i)) return null;\n" + else "" + spec.element match { + case _: ScalarColumnSpec => + emitArrayElementScalarGetter(spec.elementSparkType, elemPath, spec.element.nullable) + case _: ArrayColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int i) { + |$nullGuard int __idx = startIndex + i; + | int __s = $elemPath.getElementStartIndex(__idx); + | int __e = $elemPath.getElementEndIndex(__idx); + | return new InputArray_$elemPath(__s, __e - __s); + | }""".stripMargin + case _: StructColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int i, int numFields) { + |$nullGuard return new InputStruct_$elemPath(startIndex + i); + | }""".stripMargin + case _: MapColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int i) { + |$nullGuard int __idx = startIndex + i; + | int __s = $elemPath.getElementStartIndex(__idx); + | int __e = $elemPath.getElementEndIndex(__idx); + | return new InputMap_$elemPath(__s, __e - __s); + | }""".stripMargin + } + } + + /** + * Emit the scalar-element getter override for a nested `InputArray_${path}`. Only the getter + * matching the element type is overridden; any other getter inherits the base class's + * `UnsupportedOperationException`. Reference-typed getters (Decimal / String / Binary) prepend + * the null guard documented on [[emitArrayElementGetter]]. 
+ */ + private def emitArrayElementScalarGetter( + elemType: DataType, + childField: String, + elementNullable: Boolean): String = { + val nullGuard = + if (elementNullable) " if (isNullAt(i)) return null;\n" + else "" + elemType match { + case BooleanType => + s""" @Override + | public boolean getBoolean(int i) { + | return $childField.getBoolean(startIndex + i); + | }""".stripMargin + case ByteType => + s""" @Override + | public byte getByte(int i) { + | return $childField.getByte(startIndex + i); + | }""".stripMargin + case ShortType => + s""" @Override + | public short getShort(int i) { + | return $childField.getShort(startIndex + i); + | }""".stripMargin + case IntegerType | DateType => + s""" @Override + | public int getInt(int i) { + | return $childField.getInt(startIndex + i); + | }""".stripMargin + case LongType | TimestampType | TimestampNTZType => + s""" @Override + | public long getLong(int i) { + | return $childField.getLong(startIndex + i); + | }""".stripMargin + case FloatType => + s""" @Override + | public float getFloat(int i) { + | return $childField.getFloat(startIndex + i); + | }""".stripMargin + case DoubleType => + s""" @Override + | public double getDouble(int i) { + | return $childField.getDouble(startIndex + i); + | }""".stripMargin + case dt: DecimalType => + val body = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + emitDecimalFastBodyUnsafe(s"${childField}_valueAddr", "startIndex + i", " ") + } else { + emitDecimalSlowBody(childField, "startIndex + i", " ") + } + s""" @Override + | public org.apache.spark.sql.types.Decimal getDecimal( + | int i, int precision, int scale) { + |$nullGuard$body + | }""".stripMargin + case _: StringType => + s""" @Override + | public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i) { + |$nullGuard${emitUtf8BodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case BinaryType => + s""" @Override + | public byte[] 
getBinary(int i) { + |$nullGuard${emitBinaryBodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case other => + throw new UnsupportedOperationException( + s"nested ArrayData: unsupported element type $other") + } + } + + /** + * Emit one `InputStruct_${path}` nested class. Constructor takes `rowIdx` and stores it in a + * `final` field. Scalar getters switch on field ordinal; complex getters allocate fresh inner + * views (offsets computed for array / map children; rowIdx passed through for struct children). + */ + private def emitStructClass(path: String, spec: StructColumnSpec): String = { + val baseClassName = classOf[CometInternalRow].getName + val isNullCases = spec.fields.zipWithIndex.map { + case (f, fi) if !f.nullable => + s" case $fi: return false;" + case (f, fi) => + s" case $fi: return ${path}_f$fi.${nullCheckMethod(f.child)}(this.rowIdx);" + } + val scalarGetters = emitStructScalarGetters(path, spec) + val complexGetters = emitStructComplexGetters(path, spec) + s""" private final class InputStruct_$path extends $baseClassName { + | private final int rowIdx; + | + | InputStruct_$path(int outerRowIdx) { + | this.rowIdx = outerRowIdx; + | } + | + | @Override + | public int numFields() { + | return ${spec.fields.length}; + | } + | + | @Override + | public boolean isNullAt(int ordinal) { + | switch (ordinal) { + |${isNullCases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "InputStruct_$path.isNullAt out of range: " + ordinal); + | } + | } + | + |$scalarGetters + |$complexGetters + | } + |""".stripMargin + } + + // Scalar-read body templates. Each helper emits the per-type read statements parameterised + // on a row-index expression (`idx`), cached buffer addresses (`valueAddr`, `offsetAddr`) for + // unsafe reads, or the Arrow field for the decimal slow path. `ind` is the per-line indent. 
+ // + // The VarChar / VarBinary unsafe emitters below duplicate what CometPlainVector.getUTF8String + // / getBinary do, minus an internal `isNullAt` (redundant: caller already handled it) and + // dereferencing the offset buffer per call (we cache that). Once apache/datafusion-comet#4280 + // (offset-address caching) and #4279 (validity-bitmap byte cache) land upstream, both + // differences disappear and these emitters can be replaced by `CometPlainVector` reuse. + // The decimal-fast variant is independent: compile-time precision specialisation. + + private def emitStructScalarGetters(path: String, spec: StructColumnSpec): String = { + val withOrd = spec.fields.zipWithIndex + val scalarOrd = withOrd.filter { case (f, _) => f.child.isInstanceOf[ScalarColumnSpec] } + + // For nullable reference-typed struct fields, prepend `if (isNullAt(ord)) return null;` to + // honor Spark's contract that `getX(ord)` returns null on null positions for reference + // types. See [[emitArrayElementGetter]] for the same fix on nested array element getters. 
+ def nullGuardForCase(fi: Int, fieldNullable: Boolean): String = + if (fieldNullable) s" if (isNullAt($fi)) return null;\n" + else "" + + def fieldReadScalar(fi: Int, dt: DataType, fieldNullable: Boolean): String = { + val guard = nullGuardForCase(fi, fieldNullable) + dt match { + case BooleanType => + s" case $fi: return ${path}_f$fi.getBoolean(this.rowIdx);" + case ByteType => + s" case $fi: return ${path}_f$fi.getByte(this.rowIdx);" + case ShortType => + s" case $fi: return ${path}_f$fi.getShort(this.rowIdx);" + case IntegerType | DateType => + s" case $fi: return ${path}_f$fi.getInt(this.rowIdx);" + case LongType | TimestampType | TimestampNTZType => + s" case $fi: return ${path}_f$fi.getLong(this.rowIdx);" + case FloatType => + s" case $fi: return ${path}_f$fi.getFloat(this.rowIdx);" + case DoubleType => + s" case $fi: return ${path}_f$fi.getDouble(this.rowIdx);" + case BinaryType => + s""" case $fi: { + |$guard${emitBinaryBodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: StringType => + s""" case $fi: { + |$guard${emitUtf8BodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: DecimalType => + throw new IllegalStateException("decimal handled separately") + case other => + throw new UnsupportedOperationException( + s"nested InputStruct getter: unsupported field type $other") + } + } + + val booleanCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BooleanType => + fieldReadScalar(fi, BooleanType, f.nullable) + } + val byteCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ByteType => + fieldReadScalar(fi, ByteType, f.nullable) + } + val shortCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ShortType => + fieldReadScalar(fi, ShortType, f.nullable) + } + val intCases = scalarOrd.collect { + case (f, fi) if f.sparkType == IntegerType || f.sparkType == DateType => + 
fieldReadScalar(fi, IntegerType, f.nullable) + } + val longCases = scalarOrd.collect { + case (f, fi) + if f.sparkType == LongType || f.sparkType == TimestampType || + f.sparkType == TimestampNTZType => + fieldReadScalar(fi, LongType, f.nullable) + } + val floatCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == FloatType => + fieldReadScalar(fi, FloatType, f.nullable) + } + val doubleCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == DoubleType => + fieldReadScalar(fi, DoubleType, f.nullable) + } + val binaryCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BinaryType => + fieldReadScalar(fi, BinaryType, f.nullable) + } + val utf8Cases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[StringType] => + fieldReadScalar(fi, f.sparkType, f.nullable) + } + + val decimalCases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[DecimalType] => + val dt = f.sparkType.asInstanceOf[DecimalType] + val field = s"${path}_f$fi" + val body = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + emitDecimalFastBodyUnsafe(s"${field}_valueAddr", "this.rowIdx", " ") + } else { + emitDecimalSlowBody(field, "this.rowIdx", " ") + } + val guard = nullGuardForCase(fi, f.nullable) + s""" case $fi: { + |$guard$body + | }""".stripMargin + } + + Seq( + structSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + structSwitch("public byte getByte(int ordinal)", "getByte", byteCases), + structSwitch("public short getShort(int ordinal)", "getShort", shortCases), + structSwitch("public int getInt(int ordinal)", "getInt", intCases), + structSwitch("public long getLong(int ordinal)", "getLong", longCases), + structSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + structSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + structSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), 
+ structSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + structSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + private def emitStructComplexGetters(path: String, spec: StructColumnSpec): String = { + // Same null-guard rationale as `emitArrayElementGetter`: complex-typed (Array / Struct / Map) + // struct field getters must return null for null positions, since Spark's reference-type + // call sites rely on that contract. + def guardLine(fi: Int, fieldNullable: Boolean): String = + if (fieldNullable) s" if (isNullAt($fi)) return null;\n" + else "" + val getArrayCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[ArrayColumnSpec] => + val fieldPath = s"${path}_f$fi" + s""" case $fi: { + |${guardLine(fi, f.nullable)} int __idx = this.rowIdx; + | int __s = $fieldPath.getElementStartIndex(__idx); + | int __e = $fieldPath.getElementEndIndex(__idx); + | return new InputArray_$fieldPath(__s, __e - __s); + | }""".stripMargin + } + val getStructCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[StructColumnSpec] => + val fieldPath = s"${path}_f$fi" + if (f.nullable) { + s""" case $fi: { + |${guardLine( + fi, + f.nullable)} return new InputStruct_$fieldPath(this.rowIdx); + | }""".stripMargin + } else { + s" case $fi: return new InputStruct_$fieldPath(this.rowIdx);" + } + } + val getMapCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[MapColumnSpec] => + val fieldPath = s"${path}_f$fi" + s""" case $fi: { + |${guardLine(fi, f.nullable)} int __idx = this.rowIdx; + | int __s = $fieldPath.getElementStartIndex(__idx); + | int __e = $fieldPath.getElementEndIndex(__idx); + | return new InputMap_$fieldPath(__s, __e - __s); + | }""".stripMargin + } + Seq( + structSwitch( + "public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)", + "getArray", + 
getArrayCases), + structSwitch( + "public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields)", + "getStruct", + getStructCases), + structSwitch( + "public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)", + "getMap", + getMapCases)).mkString + } + + /** + * Emit one `InputMap_${path}` nested class. Constructor takes the slice `(startIndex, length)`; + * `keyArray()` / `valueArray()` allocate fresh `InputArray_${path}_k` / `InputArray_${path}_v` + * views over the same slice. + */ + private def emitMapClass(path: String): String = { + val baseClassName = classOf[CometMapData].getName + val keyPath = s"${path}_k" + val valPath = s"${path}_v" + s""" private final class InputMap_$path extends $baseClassName { + | private final int startIndex; + | private final int length; + | + | InputMap_$path(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData keyArray() { + | return new InputArray_$keyPath(this.startIndex, this.length); + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData valueArray() { + | return new InputArray_$valPath(this.startIndex, this.length); + | } + | } + |""".stripMargin + } + + private def structSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala new file mode 100644 index 0000000000..efa12416b2 --- /dev/null +++ 
b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types._ + +import org.apache.comet.CometArrowAllocator + +/** + * Output-side emitters for the Arrow-direct codegen kernel. Everything that writes a computed + * value into an Arrow output vector lives here: [[allocateOutput]], [[emitOutputWriter]] (the + * entry point for the kernel's top-level write), [[emitWrite]] (recursive per-type write), the + * output vector-class lookup, and the output-side type-support gate. + * + * Paired with [[CometBatchKernelCodegenInput]], which handles the symmetric input side. + */ +private[codegen] object CometBatchKernelCodegenOutput { + + /** + * Allocate an Arrow output vector matching `dataType`. 
Delegates to [[Utils.toArrowField]] + + * `Field.createVector` for the Spark -> Arrow mapping (handles `MapVector`'s non-null-key and + * non-null-entries invariants). + * + * For variable-length scalar outputs (`StringType`, `BinaryType`), `estimatedBytes` pre-sizes + * the data buffer to avoid mid-loop realloc; ignored for non-`BaseVariableWidthVector` roots, + * and not propagated into nested var-width children (those get default sizing because the + * parent's `allocateNew` resets child buffers). + * + * TODO(nested-varwidth-sizing): thread the estimate into nested var-width children. Arrow + * Java's child-vector hints are allocator-level, so this needs a small recursion or a heuristic + * that overshoots root size into known-leaf children. + * + * TODO(cached-write-buffer-addrs): mirror the input emitter's `_valueAddr` / `_offsetAddr` + * caching. Cache buffer addresses at `process` setup and emit `Platform.putByte` / + * `Platform.copyMemory` for VarChar / VarBinary / Decimal scalar outputs, bypassing `setSafe`'s + * realloc check. Depends on pre-allocated buffers (above). + * + * Closes the vector on any failure so a partially-initialised tree doesn't leak buffers. + */ + def allocateOutput( + dataType: DataType, + name: String, + numRows: Int, + estimatedBytes: Int = -1): FieldVector = + allocateOutput( + Utils.toArrowField(name, dataType, nullable = true, "UTC"), + numRows, + estimatedBytes) + + /** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. 
*/ + def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector = { + val vec = field.createVector(CometArrowAllocator).asInstanceOf[FieldVector] + try { + vec.setInitialCapacity(numRows) + vec match { + case v: BaseVariableWidthVector if estimatedBytes > 0 => + v.allocateNew(estimatedBytes.toLong, numRows) + case _ => + vec.allocateNew() + } + vec + } catch { + case t: Throwable => + try vec.close() + catch { + case _: Throwable => () + } + throw t + } + } + + /** + * Returns `(concreteVectorClassName, batchSetup, perRowSnippet)` for the expression's output + * type at the root of the generated kernel. `output` is already cast to + * `concreteVectorClassName` in `process`'s prelude, so `emitWrite`'s complex-type branches can + * hoist child casts straight off `output` without re-casting it per row. + */ + def emitOutputWriter( + dataType: DataType, + valueTerm: String, + ctx: CodegenContext): (String, String, String) = { + val cls = outputVectorClass(dataType) + val emit = emitWrite("output", "i", valueTerm, dataType, ctx) + (cls, emit.setup, emit.perRow) + } + + /** + * Concrete Arrow vector class name for the given output type. The name is used to cast `outRaw` + * to the right type at the top of the generated `process` method, so that subsequent writes + * through `emitWrite` can call vector-specific methods without further casts. 
+ */ + private def outputVectorClass(dataType: DataType): String = dataType match { + case BooleanType => classOf[BitVector].getName + case ByteType => classOf[TinyIntVector].getName + case ShortType => classOf[SmallIntVector].getName + case IntegerType => classOf[IntVector].getName + case LongType => classOf[BigIntVector].getName + case FloatType => classOf[Float4Vector].getName + case DoubleType => classOf[Float8Vector].getName + case _: DecimalType => classOf[DecimalVector].getName + case _: StringType => classOf[VarCharVector].getName + case BinaryType => classOf[VarBinaryVector].getName + case DateType => classOf[DateDayVector].getName + case TimestampType => classOf[TimeStampMicroTZVector].getName + case TimestampNTZType => classOf[TimeStampMicroVector].getName + case _: ArrayType => classOf[ListVector].getName + case _: StructType => classOf[StructVector].getName + case _: MapType => classOf[MapVector].getName + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.outputVectorClass: unsupported output type $other") + } + + /** + * Composable write emitter. Returns an [[OutputEmit]] whose `setup` declares once-per-batch + * typed child-vector casts (hoisted above the `process` loop) and whose `perRow` writes + * `source` into `targetVec` at `idx`. `targetVec` is assumed pre-cast to the right Arrow class + * (root prelude cast or a parent's setup cast). + * + * Scalars emit `perRow` only. Complex types emit both: setup for child casts, perRow for the + * loop / null guards / recursive writes. Inner `emitWrite` setup bubbles up so deep child casts + * land at the batch prelude. + */ + private def emitWrite( + targetVec: String, + idx: String, + source: String, + dataType: DataType, + ctx: CodegenContext): OutputEmit = dataType match { + case BooleanType => + OutputEmit("", s"$targetVec.set($idx, $source ? 
1 : 0);") + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType | + TimestampType | TimestampNTZType => + // All scalar primitives and date/time types share the direct `set(idx, value)` shape. + // Spark's codegen already emits the correct primitive Java type for each; Arrow's + // typed vectors accept the matching primitive in their `set` overloads. + OutputEmit("", s"$targetVec.set($idx, $source);") + case dt: DecimalType => + // Optimization: DecimalOutputShortFastPath. + // For precision <= 18 the unscaled value fits in a signed long; pass it straight to + // `DecimalVector.setSafe(int, long)` and skip the `java.math.BigDecimal` allocation + // `setSafe(int, BigDecimal)` requires. For p > 18 the BigDecimal path is unavoidable. + val write = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + s"$targetVec.setSafe($idx, $source.toUnscaledLong());" + } else { + s"$targetVec.setSafe($idx, $source.toJavaBigDecimal());" + } + OutputEmit("", write) + case _: StringType => + // Optimization: Utf8OutputOnHeapShortcut. + // `UTF8String` is internally a `(base, offset, numBytes)` view. When the base is a + // `byte[]` (common case: Spark string functions allocate results on-heap), pass the + // existing byte[] directly to `VarCharVector.setSafe(int, byte[], int, int)` via the + // encoded offset and skip the redundant `getBytes()` allocation. Off-heap passthrough + // (rare on output side) falls back to `getBytes()`. + // + // TODO(utf8-unsafe-write): the output-side equivalent of the input emitter's + // `UTF8String.fromAddress` zero-copy read would cache the data buffer address once per + // batch and write via `Platform.copyMemory` + manual offset/validity buffer updates, + // bypassing `setSafe`'s realloc check. Coupled with `cached-write-buffer-addrs` and a + // pre-allocated buffer (root-only `estimatedBytes` today). Not done because perf payoff + // is unmeasured against this PR's workloads. 
+ val bBase = ctx.freshName("utfBase") + val bLen = ctx.freshName("utfLen") + val bArr = ctx.freshName("utfArr") + OutputEmit( + "", + s"""Object $bBase = $source.getBaseObject(); + |int $bLen = $source.numBytes(); + |if ($bBase instanceof byte[]) { + | $targetVec.setSafe($idx, (byte[]) $bBase, + | (int) ($source.getBaseOffset() + | - org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET), + | $bLen); + |} else { + | byte[] $bArr = $source.getBytes(); + | $targetVec.setSafe($idx, $bArr, 0, $bArr.length); + |}""".stripMargin) + case BinaryType => + // Spark's BinaryType value is already a `byte[]`. + OutputEmit("", s"$targetVec.setSafe($idx, $source, 0, $source.length);") + case ArrayType(elementType, containsNull) => + // Complex-type output: recursive per-row write. + // Spark's `doGenCode` for ArrayType-returning expressions produces an `ArrayData` value + // (usually `GenericArrayData` / `UnsafeArrayData`). We iterate its elements, write each + // one into the Arrow `ListVector`'s child, and bracket with `startNewValue` / + // `endValue`. The element write recurses through `emitWrite` on the list's child vector, + // so any scalar we support becomes a valid array element. Nested complex types (Array of + // Array, Array of Struct) work by the same recursion. `targetVec` is a `ListVector` at + // the call site (either `output` at root or a hoisted child cast); we only need to cast + // its data vector, and that cast goes into setup. + // + // Optimization: NullableElementElision. When `containsNull == false`, the element + // `isNullAt` guard is dead by Spark's own type-system contract, so we drop it at source + // level rather than relying on JIT folding. 
+ val childVar = ctx.freshName("outListChild") + val childClass = outputVectorClass(elementType) + val arrVar = ctx.freshName("arr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val elemSource = emitSpecializedGetterExpr(arrVar, jVar, elementType) + val inner = emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx) + val setup = + (s"$childClass $childVar = ($childClass) $targetVec.getDataVector();" +: + Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + val elementWrite = if (containsNull) { + s"""if ($arrVar.isNullAt($jVar)) { + | $childVar.setNull($childIdx + $jVar); + | } else { + | ${inner.perRow} + | }""".stripMargin + } else { + inner.perRow + } + val perRow = + s"""org.apache.spark.sql.catalyst.util.ArrayData $arrVar = $source; + |int $nVar = $arrVar.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | $elementWrite + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case st: StructType => + // Complex-type output: recursive per-row write to a StructVector. + // Spark's `doGenCode` for StructType-returning expressions produces an `InternalRow` + // value (`GenericInternalRow` / `UnsafeRow` / ScalaUDF encoder output). Typed child-vector + // casts are hoisted to setup (once per batch); the per-row body references the hoisted + // names. `StructVector` writes are flat-indexed (same `$idx` as the struct's outer slot). + // + // Branchless optimization: for each field whose `nullable == false` on the + // [[StructType]], we skip the `row.isNullAt($fi)` guard at source level. Non-nullable + // fields in Spark are a contract that the producer does not emit nulls for that field, + // and matching that contract here lets HotSpot emit a straight write path per field + // rather than a branch. 
+ val rowVar = ctx.freshName("row") + val perField = st.fields.zipWithIndex.map { case (field, fi) => + val childVar = ctx.freshName("outStructChild") + val childClass = outputVectorClass(field.dataType) + val childDecl = + s"$childClass $childVar = ($childClass) $targetVec.getChildByOrdinal($fi);" + val fieldSource = emitSpecializedGetterExpr(rowVar, fi.toString, field.dataType) + val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx) + val write = + if (!field.nullable) { + inner.perRow + } else { + s"""if ($rowVar.isNullAt($fi)) { + | $childVar.setNull($idx); + |} else { + | ${inner.perRow} + |}""".stripMargin + } + val perFieldSetup = (Seq(childDecl) ++ Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + (perFieldSetup, write) + } + val setup = perField.map(_._1).mkString("\n") + val perFieldWrites = perField.map(_._2).mkString("\n") + val perRow = + s"""org.apache.spark.sql.catalyst.InternalRow $rowVar = $source; + |$targetVec.setIndexDefined($idx); + |$perFieldWrites""".stripMargin + OutputEmit(setup, perRow) + case mt: MapType => + // Complex-type output: recursive per-row write to a MapVector. + // Spark's `doGenCode` for MapType-returning expressions produces a `MapData` value + // (`ArrayBasedMapData` / `UnsafeMapData` / ScalaUDF encoder output). Typed child-vector + // casts for the entries struct and the key/value children are hoisted to setup (once per + // batch); the per-row body references them. + // + // Per-row shape: + // 1. Read keyArray / valueArray from the MapData source. + // 2. Open a new map entry via `startNewValue(idx)`; returns the base index into the + // entries StructVector for this row's key/value pairs. + // 3. For each key/value pair: set the entries struct slot defined (map values can be + // null, but the struct slot itself is defined), write the key (always non-null by + // Spark/Arrow invariant), then write the value with a null-guard on + // `vals.isNullAt(j)`. Both writes recurse through `emitWrite`. 
+ // 4. Close the map entry with `endValue(idx, n)`. + val entriesVar = ctx.freshName("outMapEntries") + val keyVar = ctx.freshName("outMapKey") + val valVar = ctx.freshName("outMapVal") + val mapSrc = ctx.freshName("mapSrc") + val keyArr = ctx.freshName("keyArr") + val valArr = ctx.freshName("valArr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val structClass = classOf[StructVector].getName + val keyClass = outputVectorClass(mt.keyType) + val valClass = outputVectorClass(mt.valueType) + val keySrcExpr = emitSpecializedGetterExpr(keyArr, jVar, mt.keyType) + val valSrcExpr = emitSpecializedGetterExpr(valArr, jVar, mt.valueType) + val keyEmit = emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx) + val valEmit = emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx) + val setup = + (Seq( + s"$structClass $entriesVar = ($structClass) $targetVec.getDataVector();", + s"$keyClass $keyVar = ($keyClass) $entriesVar.getChildByOrdinal(0);", + s"$valClass $valVar = ($valClass) $entriesVar.getChildByOrdinal(1);") ++ + Seq(keyEmit.setup, valEmit.setup).filter(_.nonEmpty)).mkString("\n") + val valueWrite = if (mt.valueContainsNull) { + s"""if ($valArr.isNullAt($jVar)) { + | $valVar.setNull($childIdx + $jVar); + | } else { + | ${valEmit.perRow} + | }""".stripMargin + } else { + valEmit.perRow + } + val perRow = + s"""org.apache.spark.sql.catalyst.util.MapData $mapSrc = $source; + |org.apache.spark.sql.catalyst.util.ArrayData $keyArr = $mapSrc.keyArray(); + |org.apache.spark.sql.catalyst.util.ArrayData $valArr = $mapSrc.valueArray(); + |int $nVar = $mapSrc.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | $entriesVar.setIndexDefined($childIdx + $jVar); + | ${keyEmit.perRow} + | $valueWrite + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case other => + throw new 
UnsupportedOperationException( + s"CometBatchKernelCodegen.emitWrite: unsupported output type $other") + } + + /** + * Java expression that reads a typed value out of a Spark `SpecializedGetters` reference (which + * both `ArrayData` and `InternalRow` implement) at a given ordinal/index. Used by the + * `ArrayType` and `StructType` branches of [[emitWrite]] to source each element / field for its + * recursive inner write. + */ + private def emitSpecializedGetterExpr(target: String, idx: String, elemType: DataType): String = + elemType match { + case BooleanType => s"$target.getBoolean($idx)" + case ByteType => s"$target.getByte($idx)" + case ShortType => s"$target.getShort($idx)" + case IntegerType | DateType => s"$target.getInt($idx)" + case LongType | TimestampType | TimestampNTZType => s"$target.getLong($idx)" + case FloatType => s"$target.getFloat($idx)" + case DoubleType => s"$target.getDouble($idx)" + case dt: DecimalType => s"$target.getDecimal($idx, ${dt.precision}, ${dt.scale})" + case _: StringType => s"$target.getUTF8String($idx)" + case BinaryType => s"$target.getBinary($idx)" + case ArrayType(_, _) => s"$target.getArray($idx)" + case _: MapType => s"$target.getMap($idx)" + case _: StructType => + val numFields = elemType.asInstanceOf[StructType].fields.length + s"$target.getStruct($idx, $numFields)" + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.emitSpecializedGetterExpr: unsupported type $other") + } + + /** + * Split output for a complex-type write: `setup` holds once-per-batch declarations (typed + * child-vector casts) and lives outside the per-row for-loop; `perRow` holds the statements + * executed for each row. Scalar writes have empty setup. 
+ */ + private case class OutputEmit(setup: String, perRow: String) +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala b/common/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala new file mode 100644 index 0000000000..e94ac5dea2 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Throwing-default base for [[InternalRow]] in the Arrow-direct codegen kernel. Subclasses + * override only the getters their input shape needs; centralising the throws absorbs forward- + * compat breakage when Spark adds abstract methods. 
+ * + * Two consumers: the compiled kernel itself (the orchestrator sets `ctx.INPUT_ROW = "row"` and + * aliases `InternalRow row = this;` so `BoundReference.genCode` reads against `this`); and + * `InputStruct_${path}` nested classes that back `getStruct(ord, n)`. + * + * Siblings [[CometArrayData]] (used by `InputArray_*`) and [[CometMapData]] (used by + * `InputMap_*`) cover the other two Spark data-shape abstractions. The `get(ordinal, dataType)` + * dispatch shared with `CometArrayData` lives in [[CometSpecializedGettersDispatch]]. + */ +abstract class CometInternalRow extends InternalRow with CometInternalRowShim { + + override def numFields: Int = unsupported("numFields") + + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + + override def get(ordinal: Int, dataType: DataType): AnyRef = + CometSpecializedGettersDispatch.get(this, ordinal, dataType) + + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + + override def getByte(ordinal: Int): Byte = unsupported("getByte") + + override def getShort(ordinal: Int): Short = unsupported("getShort") + + override def getInt(ordinal: Int): Int = unsupported("getInt") + + override def getLong(ordinal: Int): Long = unsupported("getLong") + + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + + override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + + override def getMap(ordinal: Int): MapData = 
unsupported("getMap") + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + + override def update(i: Int, value: Any): Unit = unsupported("update") + + override def copy(): InternalRow = unsupported("copy") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this row shape") +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometMapData.scala b/common/src/main/scala/org/apache/comet/codegen/CometMapData.scala new file mode 100644 index 0000000000..9fb716ff04 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometMapData.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} + +/** + * Throwing-default base for [[MapData]] in the Arrow-direct codegen kernel. Codegen-emitted + * `InputMap_${path}` subclasses override `numElements`, `keyArray`, and `valueArray`. + * + * Consumer: `InputMap_${path}` nested classes per `MapType` input column. 
They back `getMap(ord)` + * and route `keyArray()` / `valueArray()` through `InputArray_*` views (instances of + * [[CometArrayData]]) over the same backing key / value vectors. + * + * Sibling shims [[CometInternalRow]] and [[CometArrayData]] cover row and array shapes. `MapData` + * does not extend `SpecializedGetters`, so this base does not mix in + * [[org.apache.comet.shims.CometInternalRowShim]] or delegate to + * [[CometSpecializedGettersDispatch]]. + */ +abstract class CometMapData extends MapData { + + override def keyArray(): ArrayData = unsupported("keyArray") + + override def valueArray(): ArrayData = unsupported("valueArray") + + override def copy(): MapData = unsupported("copy") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this map shape") + + override def toString(): String = { + val n = + try numElements().toString + catch { + case _: Throwable => "?" + } + s"${getClass.getSimpleName}(numElements=$n)" + } + + override def numElements(): Int = unsupported("numElements") +} diff --git a/common/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala b/common/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala new file mode 100644 index 0000000000..4ca0b22933 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types._ + +/** + * Shared `SpecializedGetters.get(ordinal, dataType)` dispatch used by [[CometInternalRow]] and + * [[CometArrayData]]. Spark codegen paths (notably `SafeProjection` for deserializing `ScalaUDF` + * struct arguments) and interpreted-eval fallbacks (`ArrayDistinct.nullSafeEval` etc.) call the + * generic `get` instead of the typed getter, so both kernel-side subclasses need a non-throwing + * implementation. The body would be byte-for-byte the same in both classes; centralising it here + * keeps them in sync. + * + * Complex types (`StructType` / `ArrayType` / `MapType`) return whatever the typed getter + * returns. The codegen template allocates a fresh `InputStruct_*` / `InputArray_*` / `InputMap_*` + * with `final` slice fields per call (`ColumnarRow`-style), so retain-by-reference consumers like + * `OpenHashSet` get distinct identities and lazy reads work. 
+ */ +private[codegen] object CometSpecializedGettersDispatch { + + def get(g: SpecializedGetters, ordinal: Int, dataType: DataType): AnyRef = { + if (g.isNullAt(ordinal)) return null + dataType match { + case BooleanType => java.lang.Boolean.valueOf(g.getBoolean(ordinal)) + case ByteType => java.lang.Byte.valueOf(g.getByte(ordinal)) + case ShortType => java.lang.Short.valueOf(g.getShort(ordinal)) + case IntegerType | DateType => java.lang.Integer.valueOf(g.getInt(ordinal)) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.valueOf(g.getLong(ordinal)) + case FloatType => java.lang.Float.valueOf(g.getFloat(ordinal)) + case DoubleType => java.lang.Double.valueOf(g.getDouble(ordinal)) + case _: StringType => g.getUTF8String(ordinal) + case BinaryType => g.getBinary(ordinal) + case dt: DecimalType => g.getDecimal(ordinal, dt.precision, dt.scale) + case st: StructType => g.getStruct(ordinal, st.size) + case _: ArrayType => g.getArray(ordinal) + case _: MapType => g.getMap(ordinal) + case other => + throw new UnsupportedOperationException( + s"${g.getClass.getSimpleName}: get for dataType $other not implemented") + } + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala deleted file mode 100644 index 5e020ae74a..0000000000 --- a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.udf - -import java.util.UUID -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.sql.catalyst.expressions.Expression - -/** - * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan - * time the serde layer registers a lambda expression under a unique key; at execution time the - * UDF retrieves it by that key (passed as a scalar argument). - */ -object CometLambdaRegistry { - - private val registry = new ConcurrentHashMap[String, Expression]() - - def register(expression: Expression): String = { - val key = UUID.randomUUID().toString - registry.put(key, expression) - key - } - - def get(key: String): Expression = { - val expr = registry.get(key) - if (expr == null) { - throw new IllegalStateException( - s"Lambda expression not found in registry for key: $key. " + - "This indicates a lifecycle issue between plan creation and execution.") - } - expr - } - - def remove(key: String): Unit = { - registry.remove(key) - } - - // Visible for testing - def size(): Int = registry.size() -} diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index 5b6652d90a..6b435c4064 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -30,12 +30,18 @@ import org.apache.arrow.vector.ValueVector * - The returned vector's length must match `numRows`. 
* * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count. - * UDFs that always have at least one batch-length input can derive length from the inputs and - * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF) - * need `numRows` to know how many rows to produce. + * UDFs that always have at least one batch-length input can read length from it and ignore + * `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the + * codegen dispatcher) need `numRows` to know how many rows to produce. * - * Implementations must have a public no-arg constructor and must be stateless: a single instance - * per class is cached and shared across native worker threads for the lifetime of the JVM. + * Implementations must have a public no-arg constructor. A fresh instance is created per Spark + * task attempt per class and reused for every call within that task. Instances may hold per-task + * state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task + * completion. Do not hold state that must persist across tasks. + * + * At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future + * per partition and Tokio polls one future per worker, so the per-task instance is never touched + * concurrently even if the task's future migrates between Tokio workers across batches. */ trait CometUDF { def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector diff --git a/common/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala b/common/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala new file mode 100644 index 0000000000..edfd3175d9 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf.codegen + +import java.nio.ByteBuffer +import java.util.{Collections, LinkedHashMap} +import java.util.concurrent.atomic.AtomicLong + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression} +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types.{BinaryType, DataType, StringType} + +import org.apache.comet.codegen.{CometBatchKernel, CometBatchKernelCodegen} +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} +import org.apache.comet.udf.CometUDF + +/** + * Arrow-direct codegen dispatcher. For each (bound `Expression`, input Arrow schema) pair, + * compiles a specialized [[CometBatchKernel]] on first encounter and caches it; subsequent + * batches with the same shape reuse the compile. + * + * Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound Expression bytes. + * Args 1..N are the data columns the `BoundReference`s read, in ordinal order. 
The bytes + * self-describe the expression so the path works in cluster mode without executor-side state. + * + * Three caches at different scopes: the JVM-wide compile cache (`kernelCache` on the companion); + * the per-task UDF-instance cache in `CometUdfBridge.INSTANCES`; and per-partition kernel state + * on this instance (`activeKernel`, `activeKey`, `activePartition`) managed by [[ensureKernel]]. + * Each layer covers a distinct lifetime: JVM (compiled bytecode, immutable), task (UDF instance, + * isolated from worker reuse), partition (kernel mutable state for `Rand` / + * `MonotonicallyIncreasingID` / etc.). + */ +class CometScalaUDFCodegen extends CometUDF { + + /** + * Per-partition kernel instance cache. The compile cache stores the compiled `GeneratedClass`; + * the kernel '''instance''' holds per-row mutable state (`Rand`'s `XORShiftRandom`, + * `MonotonicallyIncreasingID`'s counter, etc.) that must advance across batches in one + * partition and reset across partitions. Allocating per partition gets that right. + * + * Plain `var`s are safe: this dispatcher is per-task (`CometUdfBridge.INSTANCES` keys by + * `taskAttemptId`) and Spark drives one partition per task, so [[ensureKernel]] never sees + * concurrent access. A different partition or expression triggers a fresh allocation. 
   */
  private var activeKernel: CometBatchKernel = _
  private var activeKey: CometScalaUDFCodegen.CacheKey = _
  private var activePartition: Int = -1

  /**
   * Entry point called by the native runtime. Arg 0 carries the closure-serialized bound
   * expression; args 1..N are the data columns. Compiles (or reuses) a kernel specialized to
   * this (expression, input-shape) pair, allocates the output vector, and runs the batch. On
   * success, ownership of the returned vector passes to the caller; on failure the vector is
   * closed so Arrow buffers do not leak.
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(
      inputs.length >= 1,
      "CometScalaUDFCodegen requires at least 1 input (serialized expression), " +
        s"got ${inputs.length}")
    val exprVec = inputs(0).asInstanceOf[VarBinaryVector]
    require(
      exprVec.getValueCount >= 1 && !exprVec.isNull(0),
      "CometScalaUDFCodegen requires non-null serialized expression bytes at arg 0")
    val bytes = exprVec.get(0)

    // TODO(dict-encoded): kernels assume materialized inputs; dict-encoded vectors would fail the
    // cast in `specFor` below. Fix is to materialize at the dispatcher (via
    // `CDataDictionaryProvider`) or widen `emitTypedGetters` with a dict-index + lookup path.

    val numDataCols = inputs.length - 1
    val dataCols = new Array[ValueVector](numDataCols)
    val specs = new Array[ArrowColumnSpec](numDataCols)
    var di = 0
    while (di < numDataCols) {
      val v = inputs(di + 1)
      dataCols(di) = v
      specs(di) = specFor(v)
      di += 1
    }
    // `numRows` is authoritative rather than derived from the inputs: a zero-arg ScalaUDF has
    // no batch-length column to read the row count from.
    val n = numRows
    val specsSeq = specs.toIndexedSeq

    val key = CometScalaUDFCodegen.CacheKey(ByteBuffer.wrap(bytes), specsSeq)
    val entry = CometScalaUDFCodegen.lookupOrCompile(key, bytes, specsSeq)

    val partitionId = CometScalaUDFCodegen.currentPartitionIndex()
    val kernel = ensureKernel(entry.compiled, key, partitionId)

    val out = CometBatchKernelCodegen.allocateOutput(
      entry.outputField,
      n,
      estimatedOutputBytes(entry.outputType, dataCols))
    try {
      kernel.process(dataCols, out, n)
      out.setValueCount(n)
      out
    } catch {
      case t: Throwable =>
        // Best-effort close so the Arrow allocation is not leaked; the original failure is the
        // one worth surfacing, so a secondary failure from close() is deliberately swallowed.
        try out.close()
        catch {
          case _: Throwable => ()
        }
        throw t
    }
  }

  /**
   * Return the kernel instance for this (expression, shape, partition), allocating a fresh one
   * when the partition or cache key changed. A fresh allocation plus `init(partitionId)` resets
   * per-partition mutable state (e.g. `Rand`'s RNG, `MonotonicallyIncreasingID`'s counter).
   */
  private def ensureKernel(
      compiled: CometBatchKernelCodegen.CompiledKernel,
      key: CometScalaUDFCodegen.CacheKey,
      partitionId: Int): CometBatchKernel = {
    if (activeKernel == null ||
      activePartition != partitionId || activeKey != key) {
      activeKernel = compiled.newInstance()
      activeKernel.init(partitionId)
      activeKey = key
      activePartition = partitionId
    }
    activeKernel
  }

  /**
   * Did any row in this batch set the null bit? Carried per column on the cache key, so batches
   * with different nullability map to different kernels (no correctness risk). The
   * `nullable=false` compile emits `return false` from `isNullAt` and, paired with the
   * `BoundReference` tree rewrite in `lookupOrCompile`, lets Spark skip the null branch at source
   * level rather than via JIT folding.
   *
   * Workloads that flip nullability frequently can cache up to `2^numCols` kernel variants per
   * expression; common-case stable nullability stays at one.
   */
  private def nullable(v: ValueVector): Boolean = v.getNullCount != 0

  /**
   * Build the compile-time spec for one input Arrow vector. Recurses on complex types; scalars
   * produce a [[ScalarColumnSpec]] carrying the concrete Arrow vector class and nullability.
   * Spark `DataType`s on complex children come from [[Utils.fromArrowField]] so the Arrow ->
   * Spark mapping stays in one place.
   */
  private def specFor(v: ValueVector): ArrowColumnSpec = v match {
    case map: MapVector =>
      // MapVector extends ListVector; match it first. Its data vector is a StructVector with
      // child 0 = key and child 1 = value.
      val struct = map.getDataVector.asInstanceOf[StructVector]
      val keyVec = struct.getChildByOrdinal(0).asInstanceOf[ValueVector]
      val valueVec = struct.getChildByOrdinal(1).asInstanceOf[ValueVector]
      MapColumnSpec(
        nullable = nullable(map),
        keySparkType = Utils.fromArrowField(keyVec.getField),
        valueSparkType = Utils.fromArrowField(valueVec.getField),
        key = specFor(keyVec),
        value = specFor(valueVec))
    case list: ListVector =>
      val child = list.getDataVector
      ArrayColumnSpec(nullable(list), Utils.fromArrowField(child.getField), specFor(child))
    case struct: StructVector =>
      val fieldSpecs = (0 until struct.size()).map { fi =>
        val childVec = struct.getChildByOrdinal(fi).asInstanceOf[ValueVector]
        val field = struct.getField.getChildren.get(fi)
        StructFieldSpec(
          name = field.getName,
          sparkType = Utils.fromArrowField(field),
          nullable = field.isNullable,
          child = specFor(childVec))
      }
      StructColumnSpec(nullable(struct), fieldSpecs)
    case _: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector |
        _: Float4Vector | _: Float8Vector | _: DecimalVector | _: VarCharVector |
        _: VarBinaryVector | _: DateDayVector | _: TimeStampMicroVector |
        _: TimeStampMicroTZVector =>
      ScalarColumnSpec(v.getClass.asInstanceOf[Class[_ <: ValueVector]], nullable(v))
    case other =>
      throw new UnsupportedOperationException(
        s"CometScalaUDFCodegen: unsupported Arrow vector ${other.getClass.getSimpleName}")
  }

  /**
   * Estimate output byte capacity for variable-length output types. Sums the data-buffer sizes of
   * variable-length input vectors as an upper bound for typical transform expressions (replace,
   * upper, lower, substring, concat on the same inputs). Underestimates are still corrected by
   * `setSafe`; this just reduces the odds of mid-loop reallocation. Returns -1 (no hint) for
   * fixed-width output types.
   */
  private def estimatedOutputBytes(outputType: DataType, dataCols: Array[ValueVector]): Int = {
    outputType match {
      case _: StringType | _: BinaryType =>
        var sum = 0
        var i = 0
        while (i < dataCols.length) {
          dataCols(i) match {
            case v: BaseVariableWidthVector => sum += v.getDataBuffer.writerIndex().toInt
            case _ => // no size hint for fixed-width vector types
          }
          i += 1
        }
        sum
      case _ => -1
    }
  }
}

object CometScalaUDFCodegen {

  // JVM-wide LRU compile cache: access-ordered LinkedHashMap with removeEldestEntry eviction,
  // wrapped in synchronizedMap. All compound operations go through the `synchronized` block in
  // `lookupOrCompile`.
  private val CacheCapacity: Int = 128
  private val kernelCache: java.util.Map[CacheKey, CacheEntry] =
    Collections.synchronizedMap(
      new LinkedHashMap[CacheKey, CacheEntry](CacheCapacity, 0.75f, true) {
        override def removeEldestEntry(
            eldest: java.util.Map.Entry[CacheKey, CacheEntry]): Boolean =
          size() > CacheCapacity
      })
  // Observability counters. Incremented under the `kernelCache.synchronized` block in
  // `lookupOrCompile` so counter increments and cache mutations cannot interleave. Read via
  // [[stats]]; reset via [[resetStats]] for tests.
  private val compileCount = new AtomicLong(0)
  private val cacheHitCount = new AtomicLong(0)

  /** Returns a snapshot of cache counters and current size. Cheap; safe to call anytime. */
  def stats(): DispatcherStats =
    DispatcherStats(compileCount.get(), cacheHitCount.get(), kernelCache.size())

  /** Reset counters to zero. Leaves the compile cache intact. Intended for tests. */
  def resetStats(): Unit = {
    compileCount.set(0)
    cacheHitCount.set(0)
  }

  /**
   * Test-facing snapshot of compiled kernel signatures: `(input Arrow vector classes in ordinal
   * order, output Spark DataType)` per cache entry. Lets tests assert specialization shape, not
   * just result correctness. Drops `ArrowColumnSpec.nullable` so a single assertion matches both
   * `nullable=true` and `nullable=false` variants of the same expression.
   */
  def snapshotCompiledSignatures(): Set[(IndexedSeq[Class[_ <: ValueVector]], DataType)] = {
    kernelCache.synchronized {
      import scala.jdk.CollectionConverters._
      kernelCache
        .entrySet()
        .asScala
        .iterator
        .map { e =>
          (e.getKey.specs.map(_.vectorClass), e.getValue.outputType)
        }
        .toSet
    }
  }

  /**
   * Look up the compiled kernel for `key`, compiling and inserting on miss. Deserialization,
   * compilation, and cache mutation all happen under the cache monitor so a given key compiles
   * at most once per JVM.
   */
  private def lookupOrCompile(
      key: CacheKey,
      bytes: Array[Byte],
      specs: IndexedSeq[ArrowColumnSpec]): CacheEntry = {
    kernelCache.synchronized {
      val existing = kernelCache.get(key)
      if (existing != null) {
        cacheHitCount.incrementAndGet()
        existing
      } else {
        // Use a classloader that can see Spark classes. The Comet native runtime calls us on a
        // Tokio worker thread where the context classloader may not be set to Spark's task
        // loader, so fall back to the loader that loaded `Expression` itself if needed.
        val loader = Option(Thread.currentThread().getContextClassLoader)
          .getOrElse(classOf[Expression].getClassLoader)
        val rawExpr = SparkEnv.get.closureSerializer
          .newInstance()
          .deserialize[Expression](ByteBuffer.wrap(bytes), loader)
        // Tighten BoundReference.nullable based on the observed batch. The plan-time value is
        // conservative (the column may be null somewhere in the query's execution), but for
        // this specific batch we know. Rewriting lets Spark's `BoundReference.genCode` skip the
        // `isNull` branch at source level rather than leaving it to JIT constant-folding.
        // Correctness is preserved by the cache key: a later batch with nulls on this column has
        // a different `specs`, so it hits a different kernel compiled with nullable=true.
        val boundExpr = rewriteBoundReferences(rawExpr, specs)
        val compiled = CometBatchKernelCodegen.compile(boundExpr, specs)
        val outputField =
          Utils.toArrowField("codegen_result", boundExpr.dataType, nullable = true, "UTC")
        val entry = CacheEntry(compiled, boundExpr.dataType, outputField)
        kernelCache.put(key, entry)
        compileCount.incrementAndGet()
        entry
      }
    }
  }

  /**
   * Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to
   * `nullable=false` when the corresponding input column in `specs` is non-nullable for this
   * batch. Only tightens; never relaxes. Expressions outside the `BoundReference` leaves are
   * unchanged.
   */
  private def rewriteBoundReferences(
      expr: Expression,
      specs: IndexedSeq[ArrowColumnSpec]): Expression = {
    expr.transform {
      case BoundReference(ord, dt, true)
          if ord >= 0 && ord < specs.length && !specs(ord).nullable =>
        BoundReference(ord, dt, nullable = false)
      // Fall through unchanged: non-BoundReference nodes and BoundReferences that are already
      // non-nullable or point at a nullable column in this batch.
      case other => other
    }
  }

  /**
   * Partition index for the generated kernel's `init`. Expressions whose `doGenCode` calls
   * `addPartitionInitializationStatement` (e.g. `Rand`, `Randn`, `Uuid`) reseed mutable state
   * from this. Falls back to 0 when the dispatcher is exercised outside a Spark task (unit tests)
   * so an absent `TaskContext` does not fail the call; the result is still deterministic for that
   * fallback.
   */
  private def currentPartitionIndex(): Int =
    Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)

  /**
   * Cache key: serialized expression bytes plus per-column compile-time invariants.
   *
   * `hashCode` walks `bytesKey` per lookup, so for large ScalaUDF closures it scales with closure
   * size. TODO(perf-cache-key): if this becomes hot, options are a driver-precomputed hash piggy-
   * backed through the proto, a per-instance last-key memoization, or a two-tier cache keyed on
   * the generated source string.
   */
  final case class CacheKey(bytesKey: ByteBuffer, specs: IndexedSeq[ArrowColumnSpec])

  /**
   * Snapshot of dispatcher cache counters and current size. Intended for tests, logging, and
   * future integration with Spark SQL metrics. Not thread-synchronized across the three fields
   * (each read is atomic, but they are not read atomically together); snapshots taken during
   * concurrent activity may show a consistent individual-field view but a slightly inconsistent
   * combined view. Fine for reporting, not for assertions that require cross-field invariants.
   */
  final case class DispatcherStats(compileCount: Long, cacheHitCount: Long, cacheSize: Int) {
    def hitRate: Double =
      if (totalLookups == 0) 0.0 else cacheHitCount.toDouble / totalLookups.toDouble

    def totalLookups: Long = compileCount + cacheHitCount
  }

  // Compiled kernel plus its output shape, cached together so `evaluate` allocates the output
  // vector without re-deriving the Arrow field per batch.
  private case class CacheEntry(
      compiled: CometBatchKernelCodegen.CompiledKernel,
      outputType: DataType,
      outputField: Field)
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** + * Per-profile view of expression traits that shifted shape across Spark versions. Spark 3.x has a + * `NullIntolerant` marker trait and no scalar-expression `Stateful` concept at all (the notion + * was added in 4.x as a boolean method on `Expression`). Routing checks through one shim lets the + * dispatcher ask "is this expression null-intolerant / stateful" without sprinkling version + * pattern matches through the codebase. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.isInstanceOf[NullIntolerant] + + // No scalar `Stateful` trait in 3.x. Aggregate/window/generator stateful cases are rejected + // elsewhere in `canHandle`, so treating all scalar expressions as non-stateful here is + // conservative-correct on this profile. + def isStateful(expr: Expression): Boolean = false + + // No collation / `ResolvedCollation` concept in 3.x, so no `Unevaluable` leaf slips past the + // dispatcher's guard here. 
+ def isCodegenInertUnevaluable(expr: Expression): Boolean = false +} diff --git a/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..e71d301d48 --- /dev/null +++ b/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +/** + * Per-profile extension point mixed into `CometInternalRow` and `CometArrayData`. Spark 4.x added + * new abstract getters on `SpecializedGetters` (`getVariant` in 4.0, `getGeography` and + * `getGeometry` in 4.1) that both `InternalRow` and `ArrayData` concrete subclasses must + * implement. Spark 3.x has none of these; this trait is empty so the shared classes compile + * unchanged on that profile. 
+ */ +trait CometInternalRowShim diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..20c6d47816 --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.VariantVal + +/** + * Throwing-default implementations for `SpecializedGetters` methods that were added in Spark 4.0: + * `getVariant`. The Janino-generated kernel subclasses `CometInternalRow` (rows) and + * `CometArrayData` (array inputs), and each must satisfy every abstract method on the interface; + * without these defaults the compiled class fails its abstract-method check at class-load time. + * `GeographyVal` and `GeometryVal` were added in 4.1, so this profile's shim does not override + * those getters. 
+ */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") +} diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..3d277e7505 --- /dev/null +++ b/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing-default implementations for `SpecializedGetters` methods added in Spark 4.x: + * `getVariant` (4.0), `getGeography` and `getGeometry` (4.1). The Janino-generated kernel + * subclasses `CometInternalRow` (rows) and `CometArrayData` (array inputs), and each must satisfy + * every abstract method on the interface; without these defaults the compiled class fails its + * abstract-method check at class-load time. 
+ */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..3d277e7505 --- /dev/null +++ b/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing-default implementations for `SpecializedGetters` methods added in Spark 4.x: + * `getVariant` (4.0), `getGeography` and `getGeometry` (4.1). 
The Janino-generated kernel + * subclasses `CometInternalRow` (rows) and `CometArrayData` (array inputs), and each must satisfy + * every abstract method on the interface; without these defaults the compiled class fails its + * abstract-method check at class-load time. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala b/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala new file mode 100644 index 0000000000..2d86258014 --- /dev/null +++ b/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, ResolvedCollation} + +/** + * Spark 4.x replaced the `NullIntolerant` marker trait with a boolean method on `Expression`, and + * introduced a `stateful` boolean method covering scalar expressions that carry per-row state + * (e.g. `Rand`, `Uuid`). Neither concept exists as a trait in 4.x, so pattern matches against + * them would fail to compile. This shim routes the checks through the method form. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.nullIntolerant + def isStateful(expr: Expression): Boolean = expr.stateful + + // `ResolvedCollation` is an `Unevaluable` leaf that only lives in `Collate.collation` as a + // type-level marker. `Collate.genCode` passes through to its child and never touches the + // collation slot, so the leaf is never invoked in generated code. Spark 4.1 analyzes it away, + // but 4.0 leaves it in the tree, so the dispatcher's `Unevaluable` guard trips on 4.0 without + // this exemption. + def isCodegenInertUnevaluable(expr: Expression): Boolean = expr match { + case _: ResolvedCollation => true + case _ => false + } +} diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 314a0a51bd..ea4a59a46f 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -43,6 +43,7 @@ to read more. 
Supported Data Types Supported Operators Supported Expressions + ScalaUDF Codegen Dispatch Configuration Settings Compatibility Guide Understanding Comet Plans diff --git a/docs/source/user-guide/latest/jvm_udf_dispatch.md b/docs/source/user-guide/latest/jvm_udf_dispatch.md new file mode 100644 index 0000000000..f5e9e807ef --- /dev/null +++ b/docs/source/user-guide/latest/jvm_udf_dispatch.md @@ -0,0 +1,53 @@ + + +# ScalaUDF codegen dispatch + +Comet can route Spark `ScalaUDF` expressions through a JVM-side kernel that processes Arrow batches directly, instead of falling back to Spark for the whole operator. The kernel is compiled per `(expression, input schema)` pair via Janino and reused across batches of the same query. Surrounding native operators stay on the Comet path. The cost is one JNI roundtrip per batch. + +## Configuration + +| Key | Default | Description | +| ------------------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | +| `spark.comet.exec.scalaUDF.codegen.enabled` | `true` | When `true`, eligible `ScalaUDF`s route through the dispatcher. When `false`, plans containing a `ScalaUDF` fall back to Spark for that operator. | + +## Supported + +- User functions registered via `udf(...)`, `spark.udf.register(...)` (Scala or Java functional interfaces), or SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`. +- Scalar input and output types: `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`, `Decimal`, `String`, `Binary`, `Date`, `Timestamp`, `TimestampNTZ`. +- Complex input and output types with arbitrary nesting: `ArrayType`, `StructType`, `MapType`. +- Composition with other Catalyst expressions inside the user function's argument tree (e.g. `myUdf(upper(s))` binds the whole tree and compiles into one kernel). 
+ +## Not supported + +- Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, the legacy `UserDefinedAggregateFunction`). +- Table UDFs and generators. +- Python `@udf` and Pandas `@pandas_udf`. +- Hive `GenericUDF` and `SimpleUDF`. +- `CalendarIntervalType` arguments and return types. + +## Behavior + +- Non-deterministic expressions referenced from the UDF's argument tree (`rand`, `uuid`, `monotonically_increasing_id`) produce per-partition sequences consistent with Spark. The kernel instance lives for one Spark task; state resets at task boundaries. +- `TaskContext.get()` inside the user function returns the driving Spark task's context even though the kernel runs on a Tokio worker thread. +- The user function must be closure-serializable. The same function that works with Spark's executor execution works here. + +## Known limitations + +- Each query analysis recompiles the kernel once. Spark's analyzer produces a fresh `ScalaUDF` instance per query, and the encoders embedded in that instance carry attribute references with fresh ids that the cache key cannot canonicalize across queries. Within one query, multiple batches of the same shape reuse the compiled kernel. diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index f5b04cc51d..ecb05eb91f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -462,8 +462,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( }; // Capture the driving Spark task's TaskContext as a JNI global reference when - // non-null. The `Arc>` releases its global ref on drop, so cleanup - // is automatic when the ExecutionContext drops. + // non-null. The `Arc>` releases its global ref on drop, so + // cleanup is automatic when the ExecutionContext drops. 
let task_context = if !task_context_obj.is_null() { Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?)) } else { diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b00f140026..722ed50e1e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -183,8 +183,11 @@ pub struct PhysicalPlanner { partition: i32, session_ctx: Arc, query_context_registry: Arc, - /// Captured at `createPlan` time on `ExecutionContext`; see that struct for the - /// propagation rationale. `None` when no driving Spark task is available. + /// Spark `TaskContext` captured on the driving Spark task thread and stashed on the + /// [`ExecutionContext`] at `createPlan` time. Threaded into every [`JvmScalarUdfExpr`] the + /// planner builds so the JNI bridge can install it as the thread-local `TaskContext` on + /// the Tokio worker that drives the UDF. `None` when no driving Spark task is available + /// (unit tests, direct native driver runs). task_context: Option>>>, } @@ -205,20 +208,27 @@ impl PhysicalPlanner { } } - pub fn with_exec_id(mut self, exec_context_id: i64) -> Self { - self.exec_context_id = exec_context_id; - self + pub fn with_exec_id(self, exec_context_id: i64) -> Self { + Self { + exec_context_id, + partition: self.partition, + session_ctx: Arc::clone(&self.session_ctx), + query_context_registry: Arc::clone(&self.query_context_registry), + task_context: self.task_context, + } } - /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned - /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as - /// the thread-local on the Tokio worker driving the UDF. - pub fn with_task_context( - mut self, - task_context: Option>>>, - ) -> Self { - self.task_context = task_context; - self + /// Attach a propagated Spark `TaskContext` global reference. Called by the JNI `executePlan` + /// entry with whatever was captured at `createPlan` time. 
The planner clones this `Option` + /// into every `JvmScalarUdfExpr` it builds. + pub fn with_task_context(self, task_context: Option>>>) -> Self { + Self { + exec_context_id: self.exec_context_id, + partition: self.partition, + session_ctx: self.session_ctx, + query_context_registry: self.query_context_registry, + task_context, + } } /// Return session context of this planner. @@ -741,6 +751,13 @@ impl PhysicalPlanner { to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| { GeneralError("JvmScalarUdf missing return_type".to_string()) })?); + // Invariant: task_context is propagated for every JvmScalarUdfExpr built during + // normal execution. The TEST_EXEC_CONTEXT_ID path is the only context in which + // task_context may legitimately be None (unit tests, direct native driver runs). + debug_assert!( + self.task_context.is_some() || self.exec_context_id == TEST_EXEC_CONTEXT_ID, + "task_context must be set for non-test execution" + ); Ok(Arc::new(JvmScalarUdfExpr::new( udf.class_name.clone(), args, diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index d72323c961..f95d3cc174 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -231,7 +231,8 @@ pub struct JVMClasses<'a> { /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, /// The CometUdfBridge class used to dispatch JVM scalar UDFs. - /// `None` if the class is not on the classpath. + /// `None` if the class is not on the classpath; the JVM-UDF dispatch path + /// reports a clear error rather than crashing executor init. pub comet_udf_bridge: Option>, } @@ -304,6 +305,9 @@ impl JVMClasses<'_> { comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { + // Optional: if the bridge class is absent (e.g. 
comet shading + // dropped org.apache.comet.udf.*), record None and clear the + // pending JVM exception so other JNI calls keep working. let bridge = CometUdfBridge::new(env).ok(); if env.exception_check() { env.exception_clear(); diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs index 4ed25de6ee..0c6f9672ae 100644 --- a/native/spark-expr/src/jvm_udf/mod.rs +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -59,6 +59,10 @@ impl JvmScalarUdfExpr { return_nullable: bool, task_context: Option>>>, ) -> Self { + debug_assert!( + !class_name.is_empty(), + "JvmScalarUdfExpr requires a non-empty class name" + ); Self { class_name, args, @@ -120,10 +124,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { } fn evaluate(&self, batch: &RecordBatch) -> DFResult { - // Step 1: evaluate child expressions to get Arrow arrays. Scalar children - // (e.g. literal patterns) are sent as length-1 vectors rather than expanded - // to batch-row count, so the JVM bridge does not pay an O(rows) copy for - // values that never vary across the batch. + // Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than + // expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for + // values that never vary across the batch. The JVM side gets `numRows` directly via + // the bridge so it doesn't need the scalar to carry batch length. let arrays: Vec = self .args .iter() @@ -133,7 +137,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { }) .collect::>()?; - // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers. // The JVM writes into the out_array/out_schema slots and reads from the in_ slots. let in_ffi_arrays: Vec> = arrays .iter() @@ -157,7 +160,13 @@ impl PhysicalExpr for JvmScalarUdfExpr { .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64) .collect(); - // Allocate output FFI slots. 
+ debug_assert!(!self.class_name.is_empty(), "class_name must not be empty"); + debug_assert_eq!( + in_arr_ptrs.len(), + in_sch_ptrs.len(), + "input array and schema pointer counts must match" + ); + let mut out_array = Box::new(FFI_ArrowArray::empty()); let mut out_schema = Box::new(FFI_ArrowSchema::empty()); let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64; @@ -166,22 +175,20 @@ impl PhysicalExpr for JvmScalarUdfExpr { let class_name = self.class_name.clone(); let n_args = arrays.len(); - // Step 3: attach a JNI env for this thread and call the static bridge method. JVMClasses::with_env(|env| { let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { CometError::from(ExecutionError::GeneralError( "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \ - class was not found on the JVM classpath." + class was not found on the JVM classpath. Set \ + spark.comet.exec.scalaUDF.codegen.enabled=false to disable this path." .to_string(), )) })?; - // Build the JVM String for the class name. let jclass_name = env .new_string(&class_name) .map_err(|e| CometError::JNI { source: e })?; - // Build the long[] arrays for input pointers. let in_arr_java = env .new_long_array(n_args) .map_err(|e| CometError::JNI { source: e })?; @@ -196,9 +203,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { .set_region(env, 0, &in_sch_ptrs) .map_err(|e| CometError::JNI { source: e })?; - // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard - // leaves the worker thread's current TaskContext.get() in place. The borrow must - // outlive `call_static_method_unchecked`. + // Resolve the TaskContext reference once before building the arg array so the + // borrow lives until `call_static_method_unchecked` returns. When no TaskContext + // was propagated, pass a null object so the bridge's null-guard leaves the thread- + // local alone. 
let null_task_context = JObject::null(); let task_context_ref: &JObject = match &self.task_context { Some(gref) => gref.as_obj(), @@ -229,7 +237,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { Ok(()) })?; - // Step 4: import the result from the FFI slots filled by the JVM. // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap // allocation is freed by the move), and `from_ffi` wraps it in an Arc that // keeps the JVM-installed release callback alive until the resulting @@ -237,7 +244,19 @@ impl PhysicalExpr for JvmScalarUdfExpr { // exactly once when the Box drops at end of scope. let result_data = unsafe { from_ffi(*out_array, &out_schema) } .map_err(|e| CometError::Arrow { source: e })?; - Ok(ColumnarValue::Array(make_array(result_data))) + let result_array = make_array(result_data); + + // The JVM may produce arrays with different field names (e.g. Arrow Java's + // ListVector uses "$data$" for child fields) than what DataFusion expects + // (e.g. "item"). Cast to the declared return_type to normalize schema. + let result_array = if result_array.data_type() != &self.return_type { + arrow::compute::cast(&result_array, &self.return_type) + .map_err(|e| CometError::Arrow { source: e })? + } else { + result_array + }; + + Ok(ColumnarValue::Array(result_array)) } fn children(&self) -> Vec<&Arc> { diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala new file mode 100644 index 0000000000..3acdcbcf4b --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF} +import org.apache.spark.sql.types.BinaryType + +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.codegen.CometBatchKernelCodegen +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * Routes scalar `ScalaUDF` expressions (user-registered Scala and Java UDFs) through the + * Arrow-direct codegen dispatcher. `ScalaUDF.doGenCode` emits compilable Java that invokes the + * user function via `ctx.addReferenceObj`, so the codegen path picks it up unchanged: we + * serialize the bound tree, the closure serializer carries the function reference across the + * wire, and the Janino-compiled kernel loads the function and invokes it in a tight batch loop. + * + * Not covered here: + * - Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, old UDAF API) - different + * bridge contract. + * - Table UDFs (`UserDefinedTableFunction`) - generator shape; `canHandle` rejects. + * - Python / Pandas UDFs - different runtime. 
+ * - Hive UDFs (`HiveGenericUDF` / `HiveSimpleUDF`) - separate expression classes; would need + * their own serde. + * + * Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, the plan falls back to + * Spark for the enclosing operator; `ScalaUDF` has no native path so there is no in-between + * option. + */ +object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { + + override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { + withInfo( + expr, + s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " + + "so the plan falls back to Spark") + return None + } + + // Bind the tree against the set of AttributeReferences it actually reads, so the compiled + // kernel's Spark-codegen path resolves ordinals relative to the data args we send as inputs + // rather than the full input schema. + val attrs = expr.collect { case a: AttributeReference => a }.distinct + val boundExpr = BindReferences.bindReference(expr, AttributeSeq(attrs)) + + // Gate on canHandle before serializing: prevents unsupported input / output shapes from + // reaching the Janino compiler at execute time and surfaces the reason via withInfo. + CometBatchKernelCodegen.canHandle(boundExpr) match { + case Some(reason) => + withInfo(expr, reason) + return None + case None => + } + + // Serialize the bound tree via Spark's closure serializer. The serializer respects the task + // context classloader (so user UDF jars are visible) and matches the machinery Spark uses to + // ship closures across the wire. The bytes become arg 0 of the JvmScalarUdf proto; the + // dispatcher identifies the expression to compile from them, which makes the path work in + // cluster mode without executor-side driver registry state. 
+ val serializer = SparkEnv.get.closureSerializer.newInstance() + val buffer = serializer.serialize(boundExpr) + val bytes = new Array[Byte](buffer.remaining()) + buffer.get(bytes) + val exprArg = exprToProtoInternal(Literal(bytes, BinaryType), inputs, binding) + .getOrElse(return None) + + val dataArgs = + attrs.map(a => exprToProtoInternal(a, inputs, binding).getOrElse(return None)) + val returnTypeProto = serializeDataType(expr.dataType).getOrElse(return None) + + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName(classOf[CometScalaUDFCodegen].getName) + .addArgs(exprArg) + dataArgs.foreach(udfBuilder.addArgs) + udfBuilder + .setReturnType(returnTypeProto) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index dced2e4da8..620ff3974e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -258,6 +258,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[MakeDecimal] -> CometMakeDecimal, classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, classOf[ScalarSubquery] -> CometScalarSubquery, + classOf[ScalaUDF] -> CometScalaUDF, classOf[SparkPartitionID] -> CometSparkPartitionId, classOf[SortOrder] -> CometSortOrder, classOf[StaticInvoke] -> CometStaticInvoke, diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 63936a94b7..5a6d764ff6 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -236,7 +236,12 @@ class CometArrayExpressionSuite extends 
CometTestBase with AdaptiveSparkPlanHelp test("ArrayInsertUnsupportedArgs") { // This test checks that the else branch in ArrayInsert // mapping to the comet is valid and fallback to spark is working fine. - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArrayInsert]) -> "true") { + // Disable the codegen dispatcher so the `idx` ScalaUDF child returns None from its serde, + // which is what drives ArrayInsert's "unsupported arguments" branch. With the dispatcher + // enabled, ScalaUDF routes through codegen and the whole plan runs native. + withSQLConf( + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false", + CometConf.getExprAllowIncompatConfigKey(classOf[ArrayInsert]) -> "true") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, 10000) @@ -247,7 +252,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), - Set("scalaudf is not supported", "unsupported arguments for ArrayInsert")) + Set("ScalaUDF has no native path", "unsupported arguments for ArrayInsert")) } } } diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala new file mode 100644 index 0000000000..1bcc6117b3 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.io.File +import java.text.SimpleDateFormat + +import scala.util.Random + +import org.apache.commons.io.FileUtils +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.DataTypeSupport.isComplexType +import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions} +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * Randomized tests for the Arrow-direct codegen dispatcher: schema-driven coverage of every input + * vector class, plus a decimal precision-scale sweep across the `Decimal.MAX_LONG_DIGITS=18` + * boundary at varying null densities. Extends [[CometTestBase]] (not [[CometFuzzTestBase]]) + * because the base's `shuffle` x `nativeC2R` cross-product `test()` override is irrelevant for + * projection-only queries. + */ +class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + /** Random schema with primitives plus shallow arrays and structs. No maps, no deep nesting. */ + private var mixedTypesFilename: String = _ + + /** Random schema with deeply nested arrays / structs / maps. */ + private var nestedTypesFilename: String = _ + + /** Asia/Kathmandu has a non-zero minute offset (UTC+5:45); good for timezone edge cases. 
*/ + private val defaultTimezone = "Asia/Kathmandu" + + override def beforeAll(): Unit = { + super.beforeAll() + val tempDir = System.getProperty("java.io.tmpdir") + val random = new Random(42) + val dataGenOptions = DataGenOptions( + generateNegativeZero = false, + baseDate = new SimpleDateFormat("YYYY-MM-DD hh:mm:ss") + .parse("2024-05-25 12:34:56") + .getTime) + + mixedTypesFilename = + s"$tempDir/CometCodegenDispatchFuzzSuite_${System.currentTimeMillis()}.parquet" + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { + val schemaGenOptions = + SchemaGenOptions(generateArray = true, generateStruct = true) + ParquetGenerator.makeParquetFile( + random, + spark, + mixedTypesFilename, + 1000, + schemaGenOptions, + dataGenOptions) + } + + nestedTypesFilename = + s"$tempDir/CometCodegenDispatchFuzzSuite_nested_${System.currentTimeMillis()}.parquet" + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { + val schemaGenOptions = + SchemaGenOptions(generateArray = true, generateStruct = true, generateMap = true) + val schema = FuzzDataGenerator.generateNestedSchema( + random, + numCols = 10, + minDepth = 2, + maxDepth = 4, + options = schemaGenOptions) + ParquetGenerator.makeParquetFile( + random, + spark, + nestedTypesFilename, + schema, + 1000, + dataGenOptions) + } + + spark.read.parquet(mixedTypesFilename).createOrReplaceTempView("t1") + spark.read.parquet(nestedTypesFilename).createOrReplaceTempView("t2") + } + + protected override def afterAll(): Unit = { + super.afterAll() + FileUtils.deleteDirectory(new File(mixedTypesFilename)) + FileUtils.deleteDirectory(new File(nestedTypesFilename)) + } + + private val RowCount: Int = 512 + private val nullDensities: Seq[Double] = Seq(0.0, 0.1, 0.5, 1.0) + // (precision, scale) shapes spanning both sides of `Decimal.MAX_LONG_DIGITS=18`: small short, + // boundary short with varying scale, 
just-past-boundary long, and max decimal128. + private val decimalShapes: Seq[(Int, Int)] = Seq((9, 2), (18, 0), (18, 9), (19, 0), (38, 10)) + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key, "true") + + /** + * Resets dispatcher stats, runs `f`, then asserts the codegen path actually ran for at least + * one batch. Without this, a silent serde fallback would let the fuzz pass trivially because + * both Spark and whatever-Comet-ran-instead agree with Spark. + */ + private def assertCodegenRan(f: => Unit): Unit = { + CometScalaUDFCodegen.resetStats() + f + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected at least one codegen dispatcher invocation during this query, got $after") + } + + /** + * Identity ScalaUDF for one of the 14 primitive types in + * [[org.apache.comet.testing.SchemaGenOptions.defaultPrimitiveTypes]]. Returns the registered + * name when the type maps to a known Scala arg, or `None` for shapes we choose not to probe. + * `BigDecimal` UDF args are encoded as `DecimalType(38, 18)`; Spark inserts an implicit cast + * around the call but the underlying column read still hits our kernel's `getDecimal` at the + * column's native precision. 
+ */ + private def registerIdentityUdfFor(dt: DataType, name: String): Option[String] = dt match { + case _: BooleanType => spark.udf.register(name, (x: Boolean) => x); Some(name) + case _: ByteType => spark.udf.register(name, (x: Byte) => x); Some(name) + case _: ShortType => spark.udf.register(name, (x: Short) => x); Some(name) + case _: IntegerType => spark.udf.register(name, (x: Int) => x); Some(name) + case _: LongType => spark.udf.register(name, (x: Long) => x); Some(name) + case _: FloatType => spark.udf.register(name, (x: Float) => x); Some(name) + case _: DoubleType => spark.udf.register(name, (x: Double) => x); Some(name) + case _: DecimalType => + spark.udf.register(name, (x: java.math.BigDecimal) => x); Some(name) + case _: DateType => spark.udf.register(name, (x: java.sql.Date) => x); Some(name) + case _: TimestampType => + spark.udf.register(name, (x: java.sql.Timestamp) => x); Some(name) + case _: TimestampNTZType => + spark.udf.register(name, (x: java.time.LocalDateTime) => x); Some(name) + case _: StringType => spark.udf.register(name, (x: String) => x); Some(name) + case _: BinaryType => spark.udf.register(name, (x: Array[Byte]) => x); Some(name) + case _ => None + } + + /** + * Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map + * column, regardless of element type. + * + * Avoids `Seq[T]` / `Map[K, V]` UDF arg materialization: Spark's `MapObjects.doGenCode` reads + * each element unconditionally and null-checks afterward, so on null positions of a + * dictionary-encoded primitive Arrow vector the garbage ID buffer feeds + * `dictionary.decodeToLong/decodeToFloat` and throws `ArrayIndexOutOfBoundsException`. Bug + * reproduces in pure Spark; `cardinality(col)` exercises `getArray`/`getMap` without entering + * the element deserializer. 
+ */ + private lazy val cardinalityProbeUdf: String = { + val name = "sz_complex" + spark.udf.register(name, (i: Int) => i) + name + } + + test("identity ScalaUDF over every primitive column") { + val primitiveFields = + spark.table("t1").schema.fields.filterNot(f => isComplexType(f.dataType)) + assert(primitiveFields.nonEmpty, "expected at least one primitive column in random schema") + for (field <- primitiveFields) { + val udfName = s"id_${field.name}" + registerIdentityUdfFor(field.dataType, udfName) match { + case Some(_) => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName(${field.name}) FROM t1") + } + case None => + fail( + s"primitive column ${field.name}: ${field.dataType} not in identity UDF catalog; " + + "extend registerIdentityUdfFor") + } + } + } + + test("complex-probe ScalaUDF on every complex column") { + val complexFields = spark.table("t1").schema.fields.filter(f => isComplexType(f.dataType)) + assert(complexFields.nonEmpty, "expected at least one complex column in random schema") + for (field <- complexFields) { + probeComplexColumn(field, viewName = "t1") + } + } + + test("complex-probe ScalaUDF on top-level columns of deeply nested schema") { + for (field <- spark.table("t2").schema.fields) { + probeComplexColumn(field, viewName = "t2") + } + } + + /** + * Element-level fuzz for nested array reads: `ArrayMax.doGenCode` walks every element of every + * row, calling the kernel's nested element getter — the path the unsafe-getter optimization + * touches and which the cardinality probe deliberately skips. 
+ */ + test("array_max element fuzz: every Array column") { + val arrayPrimitiveFields = spark.table("t1").schema.fields.filter { + case StructField(_, ArrayType(elemDt, _), _, _) if !isComplexType(elemDt) => true + case _ => false + } + assert( + arrayPrimitiveFields.nonEmpty, + "expected at least one Array column in random schema") + for (field <- arrayPrimitiveFields) { + val ArrayType(elemDt, _) = field.dataType: @unchecked + val udfName = s"id_arrmax_${field.name}" + registerIdentityUdfFor(elemDt, udfName) match { + case Some(_) => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName(array_max(${field.name})) FROM t1") + } + case None => + fail( + s"array column ${field.name} elem ${elemDt} not in identity UDF catalog; " + + "extend registerIdentityUdfFor") + } + } + } + + /** + * Map variant of the array element fuzz: `map_keys` / `map_values` produce arrays the kernel + * walks via `ArrayMax`, exercising the map's per-row offset chain (MapVector -> entries + * StructVector -> child) that the array test alone wouldn't catch. 
+ */ + test("array_max element fuzz: map_keys / map_values on Map columns") { + val mapPrimitiveFields = spark.table("t2").schema.fields.filter { + case StructField(_, MapType(kDt, vDt, _), _, _) + if !isComplexType(kDt) && !isComplexType(vDt) => + true + case _ => false + } + for (field <- mapPrimitiveFields) { + val MapType(kDt, vDt, _) = field.dataType: @unchecked + registerIdentityUdfFor(kDt, s"id_mapk_${field.name}").foreach { udf => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udf(array_max(map_keys(${field.name}))) FROM t2") + } + } + registerIdentityUdfFor(vDt, s"id_mapv_${field.name}").foreach { udf => + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $udf(array_max(map_values(${field.name}))) FROM t2") + } + } + } + } + + /** + * Doubly-nested array element fuzz: `flatten(arr)` collapses `Array>` into `Array` + * (exercising the outer-array element getter that returns each inner ArrayData), then + * `array_max` walks the leaf X primitives. Closes the gap that the singly-nested + * `array_max(arr)` test alone leaves on doubly-nested primitive arrays. + */ + test("array_max element fuzz: flatten on Array> columns") { + val nestedArrayPrimitiveFields = spark.table("t2").schema.fields.filter { + case StructField(_, ArrayType(ArrayType(elemDt, _), _), _, _) if !isComplexType(elemDt) => + true + case _ => false + } + for (field <- nestedArrayPrimitiveFields) { + val ArrayType(ArrayType(elemDt, _), _) = field.dataType: @unchecked + val udfName = s"id_arrflat_${field.name}" + registerIdentityUdfFor(elemDt, udfName).foreach { _ => + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $udfName(array_max(flatten(${field.name}))) FROM t2") + } + } + } + } + + /** + * Element-level fuzz for `Array>`. `array_distinct` is a non-HOF unary expression + * that hashes each element to dedupe; struct hashing is field-wise, so the kernel emits element + * reads on each struct's fields. 
(Tried `array_sort` first; it's a `HigherOrderFunction` whose + * `CodegenFallback` mark trips the dispatcher's reject — the lambda gap documented on + * `CometBatchKernelCodegen.canHandle`.) `cardinality` consumes without materialization. Asserts + * the optimizer keeps `ArrayDistinct` so the coverage isn't vacuously folded. + */ + test("array_distinct element fuzz: Array> columns") { + val arrayStructFields = spark.table("t1").schema.fields.filter { + case StructField(_, ArrayType(_: StructType, _), _, _) => true + case _ => false + } + spark.udf.register("id_int_arrdistinct", (i: Int) => i) + for (field <- arrayStructFields) { + val q = s"SELECT id_int_arrdistinct(cardinality(array_distinct(${field.name}))) FROM t1" + val df = sql(q) + val plan = df.queryExecution.optimizedPlan.toString + val planLower = plan.toLowerCase + assert( + planLower.contains("array_distinct") || planLower.contains("arraydistinct"), + s"optimizer eliminated array_distinct on column ${field.name}; coverage would be " + + s"vacuous. plan=\n$plan") + assertCodegenRan { + checkSparkAnswerAndOperator(df) + } + } + } + + private def probeCardinality(accessor: String, viewName: String): Unit = { + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName") + } + } + + /** + * Top-level Array / Map → cardinality probe. Struct → drill into each scalar child via + * `GetStructField`; nested Array / Map sub-fields also get the cardinality probe (depth bound: + * deeper struct-of-struct nesting is skipped to keep the sweep finite). 
+ */ + private def probeComplexColumn(field: StructField, viewName: String): Unit = { + field.dataType match { + case _: ArrayType | _: MapType => + probeCardinality(field.name, viewName) + + case st: StructType => + for (subField <- st.fields) { + val accessor = s"${field.name}.${subField.name}" + subField.dataType match { + case _: ArrayType | _: MapType => probeCardinality(accessor, viewName) + case dt if !isComplexType(dt) => + val udfName = s"id_${field.name}_${subField.name}" + registerIdentityUdfFor(dt, udfName).foreach { _ => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName($accessor) FROM $viewName") + } + } + case _ => // deeper struct nesting skipped + } + } + + case _ => + } + } + + /** Random `BigDecimal` values fitting `(precision, scale)`, with `nullDensity` of them null. */ + private def generateDecimals( + seed: Long, + precision: Int, + scale: Int, + nullDensity: Double): Seq[java.math.BigDecimal] = { + val rng = new Random(seed) + val intDigits = precision - scale + // `BigInt.apply(bits, rng)` samples uniformly on `[0, 2^bits - 1]`; bound to the decimal's + // integer-part range (10^intDigits - 1) so the result fits the schema. `BigInteger.bitLength` + // would overshoot slightly; min with the exact max is cheap insurance. 
+ val intMax = BigInt(10).pow(intDigits) - 1 + val bits = math.max(intMax.bitLength, 1) + (0 until RowCount).map { _ => + if (rng.nextDouble() < nullDensity) null + else { + val mag = BigInt(bits, rng).min(intMax) + val signed = if (rng.nextBoolean()) -mag else mag + new java.math.BigDecimal(signed.bigInteger, scale) + } + } + } + + private def withDecimalTable(decimalType: String, values: Seq[java.math.BigDecimal])( + f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + if (values.nonEmpty) { + val rows = values.map { v => + if (v == null) "(NULL)" else s"(${v.toPlainString})" + } + rows.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.mkString(", ")}") + } + } + f + } + } + + for { + density <- nullDensities + (precision, scale) <- decimalShapes + } { + test(s"decimal identity precision=$precision scale=$scale nullDensity=$density") { + spark.udf.register("dec_id_fuzz", (d: java.math.BigDecimal) => d) + val seed = ((precision * 31L) + scale) * 31L + density.hashCode + val values = generateDecimals(seed, precision, scale, density) + withDecimalTable(s"DECIMAL($precision, $scale)", values) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT dec_id_fuzz(d) FROM t")) + } + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala new file mode 100644 index 0000000000..d8113549a1 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala @@ -0,0 +1,1243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.arrow.vector._ +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.api.java.UDF1 +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.types._ + +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * Smoke tests for the Arrow-direct codegen dispatcher. Runs ScalaUDF queries across the scalar + * and complex type surface, composed UDF trees, subquery reuse, `TaskContext` propagation, and + * per-task cache isolation, asserting results match Spark. + */ +class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key, "true") + + private def withSubjects(values: String*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val rows = values + .map(v => if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + /** + * Composition smoke tests. Demonstrate that the codegen dispatcher handles nested expression + * trees in one compile per (tree, schema) pair, not one JNI hop per sub-expression. Each test + * wraps the query in `assertCodegenDidWork` to prove the codegen path ran rather than silently + * falling back to Spark. 
+ */ + private def assertCodegenDidWork(f: => Unit): Unit = { + CometScalaUDFCodegen.resetStats() + f + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Stronger form of [[assertCodegenDidWork]]: asserts the full expression subtree compiled into + * at most one kernel. A "one JNI crossing per nesting level" implementation would produce one + * cache entry per sub-expression and `compileCount` of N. `<=` rather than `==` because the + * cache is JVM-wide; a prior test may have produced a hit (compileCount==0). The activity check + * guards against silent Spark fallback where the first two asserts pass vacuously. + */ + private def assertOneKernelForSubtree(f: => Unit): Unit = { + CometScalaUDFCodegen.resetStats() + val sizeBefore = CometScalaUDFCodegen.stats().cacheSize + f + val after = CometScalaUDFCodegen.stats() + assert(after.compileCount <= 1, s"expected <= 1 compile for the composed subtree, got $after") + val grew = after.cacheSize - sizeBefore + assert(grew <= 1, s"expected cache to grow by <= 1 entry, grew by $grew; stats=$after") + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Assert the compile cache contains a kernel matching the given input Arrow vector classes (in + * ordinal order) and output `DataType`. A specialization check: if a future change loses + * vector-class discrimination on the cache key, `checkSparkAnswerAndOperator` still passes + * (Spark answers correctly) but this assertion fails. Cache is JVM-wide so a prior test's + * compile counts; pair with `assertCodegenDidWork` to also prove this test ran the dispatcher. + * + * Compares by simple name because `common` shades `org.apache.arrow`; a direct + * `classOf[VarCharVector]` here (unshaded) wouldn't match the shaded class the dispatcher + * actually stores. 
+ */ + private def assertKernelSignaturePresent( + inputs: Seq[Class[_ <: ValueVector]], + output: DataType): Unit = { + val sigs = CometScalaUDFCodegen.snapshotCompiledSignatures() + val expectedNames = inputs.map(_.getSimpleName).toIndexedSeq + val present = sigs.exists { case (cached, dt) => + dt == output && cached.map(_.getSimpleName) == expectedNames + } + assert( + present, + s"expected kernel signature $expectedNames -> $output; " + + s"cache had ${sigs.map { case (c, d) => (c.map(_.getSimpleName), d) }}") + } + + /** + * Multi-column smoke tests. The dispatcher compiles the whole bound expression tree, including + * composed sub-expressions that reference multiple columns. Verify end-to-end correctness + * against Spark for a handful of representative shapes. + */ + private def withTwoStringCols(rows: (String, String)*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (c1 STRING, c2 STRING) USING parquet") + if (rows.nonEmpty) { + val tuples = rows.map { case (a, b) => + val av = if (a == null) "NULL" else s"'${a.replace("'", "''")}'" + val bv = if (b == null) "NULL" else s"'${b.replace("'", "''")}'" + s"($av, $bv)" + } + sql(s"INSERT INTO t VALUES ${tuples.mkString(", ")}") + } + f + } + } + + test("ScalaUDF over concat(c1, c2) suppresses the null short-circuit") { + // Concat is not NullIntolerant. The dispatcher's short-circuit guard inspects every node in + // the bound tree and must skip the whole-tree null short-circuit because one child is + // non-NullIntolerant. The kernel therefore delegates null handling to Spark's generated + // code (which handles Concat(null, x) = x correctly) rather than returning null for any + // null input. Without the guard, null inputs would produce null outputs even where Spark + // produces a non-null concatenation. 
+ spark.udf.register("tag", (s: String) => if (s == null) "N" else s"[${s}]") + withTwoStringCols(("abc", "123"), ("abc", null), (null, "123"), (null, null), ("zz", "zz")) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT tag(concat(c1, c2)) FROM t")) + } + } + } + + test("disabled mode bypasses the dispatcher") { + // When the per-feature config is off, `CometScalaUDF.convert` returns None and the enclosing + // operator falls back to Spark. The dispatcher's counters must not move. We do not assert + // `checkSparkAnswerAndOperator` here because ScalaUDF has no Comet-native path, so the + // project runs on the JVM Spark path under this configuration. + spark.udf.register("noopStr", (s: String) => s) + CometScalaUDFCodegen.resetStats() + withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { + withSubjects("disabled_1", null) { + checkSparkAnswer(sql("SELECT noopStr(s) FROM t")) + } + } + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount == 0 && after.cacheHitCount == 0, + s"expected no dispatcher activity under disabled config, got $after") + } + + test("per-batch nullability produces distinct compiles for null-present vs null-absent") { + // Same ScalaUDF + same Arrow vector class + different observed nullability should hit + // different cache keys, because `ArrowColumnSpec.nullable` flips when the batch has no + // nulls. We don't assert on per-run deltas because Spark's partitioning can split the + // subject table so the first query alone sees both nullability variants across different + // partitions. Instead, assert the total invariant: across both queries we see at least two + // compiles, proving the cache key discriminated on nullability. 
+ spark.udf.register("nullabilityMarker", (s: String) => if (s == null) null else s + "!") + CometScalaUDFCodegen.resetStats() + + withSubjects("nullability_marker_1", null, "nullability_marker_2") { + checkSparkAnswerAndOperator(sql("SELECT nullabilityMarker(s) FROM t")) + } + withSubjects("nullability_marker_3", "nullability_marker_4") { + checkSparkAnswerAndOperator(sql("SELECT nullabilityMarker(s) FROM t")) + } + val after = CometScalaUDFCodegen.stats() + + assert( + after.compileCount >= 2, + "expected at least two compiles across both nullability distributions (one per " + + s"nullable=true/false variant); got $after") + } + + test("dispatcher caches the compiled kernel across batches of one query") { + // Within a single query, the dispatcher compiles a kernel for the (expression, schema) pair + // once and reuses it across every subsequent batch of the same shape. Force multiple batches + // by lowering the Comet batch size with a row count well above it, then assert at least one + // cache hit happened during the query. + // + // We deliberately do not assert cross-query cache reuse: Spark's analyzer produces a fresh + // `ScalaUDF` instance per query resolution, and the encoders embedded in that instance + // contain `AttributeReference`s with fresh `ExprId`s that our `BindReferences.bindReference` + // does not recurse into. The closure-serialized cache key bytes therefore drift across + // queries even when the registered function and schema are identical, so each new query of a + // ScalaUDF pays one compile up front and amortizes within itself. This is an acceptable + // amortization story (a few tens of milliseconds per query), not a behavior we can or do + // promise across queries. 
+ spark.udf.register("kernelCacheMarker", (s: String) => if (s == null) null else s + "_kc") + val rows = (0 until 256).map(i => s"row_$i") + CometScalaUDFCodegen.resetStats() + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "32") { + withSubjects(rows: _*) { + checkSparkAnswerAndOperator(sql("SELECT kernelCacheMarker(s) FROM t")) + } + } + val stats = CometScalaUDFCodegen.stats() + assert(stats.compileCount >= 1, s"expected at least one compile during the query, got $stats") + assert( + stats.cacheHitCount >= 1, + s"expected at least one cache hit across batches of the same query, got $stats") + } + + test("per-partition kernel preserves Nondeterministic state across batches") { + // Wrap `monotonically_increasing_id()` as the argument of a ScalaUDF so the whole tree + // (including the stateful MonotonicallyIncreasingID child) routes through the dispatcher. + // Per-partition kernel caching means the id counter advances across batches within a + // partition; without it, every batch would restart at 0 and the UDF output would disagree + // with Spark's. The UDF body is a trivial identity; we're testing state correctness of the + // Nondeterministic child across batches, not the UDF logic. + spark.udf.register("idPassthrough", (id: Long) => id) + val rows = (0 until 4096).map(i => s"row_$i") + withSubjects(rows: _*) { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT s, idPassthrough(monotonically_increasing_id()) FROM t")) + } + } + } + + test("per-task cache isolates UDF state across sequential task runs in one session") { + // Regression guard for the cache-scoping invariant on CometUdfBridge: instances live for + // exactly one Spark task and are dropped on task completion, so a stateful kernel sees a + // fresh instance per task. Running the same `monotonically_increasing_id()`-carrying query + // twice in one session must produce identical results each run. 
Under a cache that outlived + // a task and got reused by the next one, the counter would continue from the previous run's + // final value and the second run's IDs would diverge. Under a cache that was keyed by Tokio + // worker thread rather than task attempt ID, worker reuse across tasks would cause the same + // leak whenever the second task happened to be polled by the same worker. + val rows = (0 until 2048).map(i => s"row_$i") + withSubjects(rows: _*) { + val q = "SELECT s, monotonically_increasing_id() AS mid FROM t" + val first = sql(q).collect().map(r => (r.getString(0), r.getLong(1))).toSeq + val second = sql(q).collect().map(r => (r.getString(0), r.getLong(1))).toSeq + assert( + first == second, + s"per-task cache leaked state across runs: first=${first.take(5)} second=${second.take(5)}") + } + } + + /** + * Scalar ScalaUDF smoke tests. These prove that user-registered UDFs route through the codegen + * dispatcher rather than forcing a whole-plan Spark fallback. Spark's `ScalaUDF.doGenCode` + * already emits compilable Java that calls the user function via `ctx.addReferenceObj`, so the + * dispatcher's compile path picks it up for free. Tests that user-registered UDFs route through + * the dispatcher rather than forcing whole-plan Spark fallback. + */ + + test("registered string ScalaUDF routes through dispatcher") { + spark.udf.register("shout", (s: String) => if (s == null) null else s.toUpperCase + "!") + withSubjects("Abc", "xyz", null, "mixed") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT shout(s) FROM t")) + } + } + } + + test("registered Java UDF1 routes through dispatcher") { + // Java API path: `spark.udf.register(name, UDF1<...>, returnType)`. Spark wraps the Java + // functional interface in a Scala function and produces a `ScalaUDF` expression at plan + // time, so the dispatcher handles it the same as a Scala-registered UDF. Sanity check that + // both registration paths land on the same routing code. 
+ spark.udf.register( + "javaLen", + new UDF1[String, Integer] { + override def call(s: String): Integer = if (s == null) -1 else s.length + }, + IntegerType) + withSubjects("abc", "hello", null, "x") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT javaLen(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("multi-arg ScalaUDF over string + literal routes through dispatcher") { + spark.udf.register( + "prepend", + (prefix: String, s: String) => if (s == null) null else prefix + s) + withSubjects("one", "two", null) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT prepend('[', s) FROM t")) + } + } + } + + test("ScalaUDF as a child of a native Spark expression") { + // The ScalaUDF routes through the dispatcher as a sub-expression; the surrounding `length` + // runs through Comet's native scalar function path. This exercises the cross-boundary + // composition where a dispatcher-compiled kernel returns a UTF8String that a native Comet + // expression then consumes. + spark.udf.register("wrap", (s: String) => if (s == null) null else s"|$s|") + withSubjects("abc", "def", null) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT length(wrap(s)) FROM t")) + } + } + } + + test("composed ScalaUDFs outer(inner(s)) fuse into one kernel") { + // Two user UDFs stacked, both operating on String. The dispatcher binds the whole tree and + // Spark's codegen emits two `ctx.addReferenceObj` calls inside one generated method. Races + // on the `ExpressionEncoder` serializers in `references` would show up here since each UDF + // contributes its own stateful serializer; the `freshReferences` closure in `CompiledKernel` + // is what keeps this correct across partitions. 
+ spark.udf.register("inner", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("outer", (s: String) => if (s == null) null else s"<$s>") + withSubjects("abc", null, "xyz", "MiXeD") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT outer(inner(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), StringType) + } + } + + test("ScalaUDFs of different types compose: isShort(len(s))") { + // Exercises an input type transition: String -> Int -> Boolean. Two user UDFs with + // different I/O type shapes in one tree, one Janino compile. + spark.udf.register("len", (s: String) => if (s == null) -1 else s.length) + spark.udf.register("isShort", (i: Int) => i < 5) + withSubjects("ab", "abcdef", null, "hi") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT isShort(len(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BooleanType) + } + } + + test("three-deep ScalaUDF composition lvl3(lvl2(lvl1(s)))") { + // Three user UDFs stacked in one tree: String -> String -> String -> Int. The fused kernel + // carries three `ctx.addReferenceObj` calls. `assertOneKernelForSubtree` asserts that the + // whole chain collapses into a single compile rather than one per nesting level. + // Input rows intentionally exclude nulls: per-batch nullability is a cache-key dimension + // (`nullable()` reads `getNullCount != 0`), so a null-present batch compiles a second kernel + // specialized for `nullable=true`. Null handling through composed UDFs is covered by the + // other composition tests above. 
+ spark.udf.register("lvl1", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lvl2", (s: String) => if (s == null) null else s.reverse) + spark.udf.register("lvl3", (s: String) => if (s == null) -1 else s.length) + withSubjects("abc", "hello world", "x") { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT lvl3(lvl2(lvl1(s))) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("multi-column ScalaUDF composition join(upperU(c1), lowerU(c2))") { + // One multi-arg user UDF consuming two other user UDFs, each on a different input column. + // The bound tree has two BoundReferences, and the kernel is specialized on two VarCharVector + // columns. `assertOneKernelForSubtree` asserts that the two-branch composition fuses into a + // single kernel rather than one per branch or one per UDF. + // Input rows intentionally exclude nulls (see note on the three-deep test above). + spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase) + spark.udf.register( + "joinU", + (a: String, b: String) => if (a == null || b == null) null else s"$a-$b") + withTwoStringCols(("Abc", "XYZ"), ("Foo", "bar"), ("baz", "Bar"), ("Hi", "Lo")) { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT joinU(upperU(c1), lowerU(c2)) FROM t")) + } + assertKernelSignaturePresent( + Seq(classOf[VarCharVector], classOf[VarCharVector]), + StringType) + } + } + + /** + * Type-surface ScalaUDF tests. Each exercises a distinct Arrow input vector class plus the + * matching output writer end to end. + * + * Backed by parquet tables with declared column types rather than `spark.range` projections: + * derived `cast(id as int)` columns get folded into the plan and leave the `BoundReference` on + * the underlying long, not the projected int. 
A declared parquet column keeps the Arrow vector + * the dispatcher sees aligned with the UDF's signature. + */ + private def withTypedCol(sqlType: String, valueLiterals: String*)(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (c $sqlType) USING parquet") + if (valueLiterals.nonEmpty) { + val rows = valueLiterals.map(v => s"($v)").mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + } + f + } + } + + test("ScalaUDF on IntegerType (IntVector, getInt)") { + spark.udf.register("doubleIt", (i: Int) => i * 2) + withTypedCol("INT", "1", "2", "100") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT doubleIt(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[IntVector]), IntegerType) + } + } + + test("ScalaUDF on LongType (BigIntVector, getLong)") { + spark.udf.register("inc", (l: Long) => l + 1L) + withTypedCol("BIGINT", "1", "2", "100") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT inc(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BigIntVector]), LongType) + } + } + + test("ScalaUDF on DoubleType (Float8Vector, getDouble)") { + spark.udf.register("halve", (d: Double) => d / 2.0) + withTypedCol("DOUBLE", "1.5", "2.5", "100.0") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT halve(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[Float8Vector]), DoubleType) + } + } + + test("ScalaUDF on FloatType (Float4Vector, getFloat)") { + spark.udf.register("scaleF", (f: Float) => f * 1.5f) + withTypedCol("FLOAT", "CAST(1.5 AS FLOAT)", "CAST(2.5 AS FLOAT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT scaleF(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[Float4Vector]), FloatType) + } + } + + test("ScalaUDF on BooleanType (BitVector, getBoolean)") { + spark.udf.register("neg", (b: Boolean) => !b) + withTypedCol("BOOLEAN", "TRUE", "FALSE", "TRUE") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT neg(c) FROM t")) + } + 
assertKernelSignaturePresent(Seq(classOf[BitVector]), BooleanType) + } + } + + test("ScalaUDF on ShortType (SmallIntVector, getShort)") { + spark.udf.register("incS", (s: Short) => (s + 1).toShort) + withTypedCol( + "SMALLINT", + "CAST(1 AS SMALLINT)", + "CAST(2 AS SMALLINT)", + "CAST(30000 AS SMALLINT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT incS(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[SmallIntVector]), ShortType) + } + } + + test("ScalaUDF on ByteType (TinyIntVector, getByte)") { + spark.udf.register("incB", (b: Byte) => (b + 1).toByte) + withTypedCol("TINYINT", "CAST(1 AS TINYINT)", "CAST(2 AS TINYINT)", "CAST(100 AS TINYINT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT incB(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TinyIntVector]), ByteType) + } + } + + test("ScalaUDF on DateType (DateDayVector, getInt)") { + // Date input flows through the Int getter because DateType is physically int. The UDF takes + // java.sql.Date and Spark's encoder handles the int -> Date materialization. 
+ spark.udf.register( + "nextDay", + (d: java.sql.Date) => if (d == null) null else new java.sql.Date(d.getTime + 86400000L)) + withTypedCol("DATE", "DATE'2024-01-01'", "DATE'2024-06-15'", "DATE'1970-01-01'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT nextDay(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[DateDayVector]), DateType) + } + } + + test("ScalaUDF on TimestampType (TimeStampMicroTZVector, getLong)") { + spark.udf.register( + "plusSecond", + (t: java.sql.Timestamp) => + if (t == null) null else new java.sql.Timestamp(t.getTime + 1000L)) + withTypedCol( + "TIMESTAMP", + "TIMESTAMP'2024-01-01 12:00:00'", + "TIMESTAMP'2024-06-15 23:59:59'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT plusSecond(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TimeStampMicroTZVector]), TimestampType) + } + } + + test("ScalaUDF on TimestampNTZType (TimeStampMicroVector, getLong)") { + spark.udf.register( + "plusDayNtz", + (ldt: java.time.LocalDateTime) => if (ldt == null) null else ldt.plusDays(1)) + withTypedCol( + "TIMESTAMP_NTZ", + "TIMESTAMP_NTZ'2024-01-01 12:00:00'", + "TIMESTAMP_NTZ'2024-06-15 23:59:59'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT plusDayNtz(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TimeStampMicroVector]), TimestampNTZType) + } + } + + test("ScalaUDF returning DateType") { + spark.udf.register("epochDay", (_: Int) => java.sql.Date.valueOf("1970-01-01")) + withTypedCol("INT", "1", "2", "3") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT epochDay(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[IntVector]), DateType) + } + } + + test("ScalaUDF returning TimestampType") { + spark.udf.register("mkTs", (s: Long) => new java.sql.Timestamp(s * 1000L)) + withTypedCol("BIGINT", "0", "1700000000", "1750000000") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT mkTs(c) FROM t")) + } + 
assertKernelSignaturePresent(Seq(classOf[BigIntVector]), TimestampType) + } + } + + test("ScalaUDF returning TimestampNTZType") { + spark.udf.register( + "mkTsNtz", + (s: Long) => java.time.LocalDateTime.ofEpochSecond(s, 0, java.time.ZoneOffset.UTC)) + withTypedCol("BIGINT", "0", "1700000000", "1750000000") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT mkTsNtz(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BigIntVector]), TimestampNTZType) + } + } + + test("ScalaUDF returning a different type than its input") { + // String -> Int transition forces the output writer to switch from VarChar to Int. Exercises + // the `IntegerType` output path end to end from a user UDF. + spark.udf.register("codePoint", (s: String) => if (s == null) 0 else s.codePointAt(0)) + withSubjects("abc", "A", null, "!") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT codePoint(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("ScalaUDF returning BinaryType (VarBinaryVector output writer)") { + // Binary output writer path, exercised here by a user UDF for the first time. Before this + // the writer only had direct-compile unit tests. + spark.udf.register("bytes", (s: String) => if (s == null) null else s.getBytes("UTF-8")) + withSubjects("abc", null, "hello") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT bytes(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BinaryType) + } + } + + test("ScalaUDF on BinaryType (VarBinaryVector, getBinary)") { + // Binary input getter path: VarBinaryVector with byte[] reads via Spark's `getBinary` getter. 
+ spark.udf.register("blen", (b: Array[Byte]) => if (b == null) -1 else b.length) + withTable("t") { + sql("CREATE TABLE t (b BINARY) USING parquet") + sql("INSERT INTO t VALUES (CAST('abc' AS BINARY)), (CAST('hello' AS BINARY)), (NULL)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT blen(b) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarBinaryVector]), IntegerType) + } + } + + test("ScalaUDF returning ArrayType(StringType) (ListVector output writer)") { + // First use of the ArrayType output path end-to-end. The UDF returns a `Seq[String]`, + // which Spark encodes as `ArrayType(StringType, containsNull = true)`. The dispatcher's + // canHandle accepts it (ArrayType is supported when its element type is supported), + // allocateOutput builds a ListVector with an inner VarCharVector, and emitWrite recurses + // into the StringType case for the per-element UTF8 on-heap shortcut. End-to-end answer + // matches Spark. + spark.udf.register( + "splitComma", + (s: String) => if (s == null) null else s.split(",", -1).toSeq) + withSubjects("a,b,c", "x", null, "", "one,,three") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT splitComma(s) FROM t")) + } + } + } + + test("ScalaUDF returning ArrayType(IntegerType)") { + // Exercises ArrayType output with a primitive element. emitWrite's ArrayType case + // recurses into the IntegerType case for the inner write; no byte[] allocation involved. + spark.udf.register( + "asLengths", + (s: String) => if (s == null) null else s.split(",").map(_.length).toSeq) + withSubjects("a,bb,ccc", null, "xyzzy") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT asLengths(s) FROM t")) + } + } + } + + test("zero-column ScalaUDF produces one row per input row") { + // Non-deterministic (so Spark doesn't constant-fold) with a deterministic body (so + // Spark-vs-Comet comparison stays honest). 
The expression has no `AttributeReference`, + // so the serde produces an empty data-arg list and the dispatcher has no data column to + // read the batch size from. Guards the `numRows` path through the JNI bridge. + import org.apache.spark.sql.functions.udf + val alwaysHello = udf(() => "hello").asNondeterministic() + spark.udf.register("helloU", alwaysHello) + withSubjects("a", "b", null, "c") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT helloU() FROM t")) + } + } + } + + /** + * Decimal tests. The dispatcher's `getDecimal` getter specializes on the `BoundReference`'s + * `DecimalType.precision` at source-generation time: precision <= 18 emits an unscaled-long + * fast path via `Decimal.createUnsafe`, precision > 18 emits a `BigDecimal + Decimal.apply` + * slow path. These smoke tests exercise both sides of the split end to end and verify Spark and + * Comet agree on correctness across typical decimal workloads. + */ + private def withDecimalTable(decimalType: String, values: Seq[String])(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + val rows = values.map(v => if (v == null) "(NULL)" else s"($v)").mkString(", ") + if (values.nonEmpty) sql(s"INSERT INTO t VALUES $rows") + f + } + } + + test("ScalaUDF over Decimal(9, 2) (short precision, fast path)") { + // Short-precision identity UDF. The column's DecimalType has precision 9, so the generated + // getter for ordinal 0 emits only the unscaled-long fast path. The UDF's Scala-side signature + // uses `java.math.BigDecimal`, which Spark's encoder pins at DecimalType(38, 18); the implicit + // Cast from DECIMAL(9, 2) -> DECIMAL(38, 18) runs inside Spark's generated code, not via our + // kernel's getter, so the fast path still fires on the column read. 
+ spark.udf.register("decId9_2", (d: java.math.BigDecimal) => d) + withDecimalTable("DECIMAL(9, 2)", Seq("0.00", "1.50", "-1.50", "9999.99", "-9999.99", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId9_2(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(18, 0) (max short precision, fast path)") { + // Boundary precision: 18 is the last value for which the unscaled representation fits in a + // signed 64-bit long. The fast path must still be selected. + spark.udf.register("decId18_0", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(18, 0)", + Seq("0", "1", "-1", "999999999999999999", "-999999999999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId18_0(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(18, 9) (max short precision with scale, fast path)") { + // Same precision as above but with scale 9 to exercise the fractional side of the long + // decimal. Spark `Decimal` stores both as the same unscaled long; only the `scale` parameter + // differs. + spark.udf.register("decId18_9", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(18, 9)", + Seq("0.000000000", "1.123456789", "-1.123456789", "999999999.999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId18_9(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(19, 0) (just past short precision, slow path)") { + // First precision where the unscaled value can exceed `Long.MAX_VALUE`. The generated getter + // must emit only the slow path; the fast-path marker must be absent in the compiled kernel. 
+ spark.udf.register("decId19_0", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(19, 0)", + Seq("0", "1", "-1", "9999999999999999999", "-9999999999999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId19_0(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(38, 10) (max precision, slow path)") { + // Max decimal128 precision. Exercises the `getObject + Decimal.apply` branch and the + // end-to-end BigDecimal conversion path with a non-trivial scale. + spark.udf.register("decId38_10", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(38, 10)", + Seq( + "0.0000000000", + "1.1234567890", + "-1.1234567890", + "9999999999999999999999999999.0000000000", + null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId38_10(d) FROM t")) + } + } + } + + test("ScalaUDF sees TaskContext.partitionId() per partition") { + // Direct probe: register a ScalaUDF that reads TaskContext.partitionId() and returns it. + // Spark's own task thread has TaskContext set, so each partition's rows carry that + // partition's index. For the dispatcher to match Spark, the invocation thread must see a + // live TaskContext. With the `createPlan`-time TaskContext capture + bridge-side + // `TaskContext.setTaskContext` install (see `CometUdfBridge.evaluate` and + // `CometTaskContextShim`), Tokio workers see the propagated TaskContext and the UDF + // returns the real partitionId. Without that propagation, `TaskContext.get()` returns null + // on the Tokio thread and the sentinel (-1) leaks through, diverging from Spark. 
+ spark.udf.register( + "pid", + (_: Long) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "pid(id) as p") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF sees TaskContext from fully-native parquet plan") { + // The `spark.range`-based test above runs through `CometSparkRowToColumnar`, which executes + // on a Spark task thread where TaskContext is live even without explicit propagation. The + // fully-native path through `CometNativeScan` runs the JVM UDF bridge on a Tokio worker + // thread where TaskContext.get() would otherwise be null. This test forces that path by + // sourcing from a Parquet table written as multiple files (so the native read produces + // multiple partitions) and asserting the UDF still sees the per-partition TaskContext via + // the `createPlan`-time capture + bridge-side install. + spark.udf.register( + "pidP", + (_: Int) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + // Multiple INSERT statements -> multiple parquet files -> multiple read splits -> + // multiple partitions. + sql("INSERT INTO t VALUES (1), (2), (3), (4)") + sql("INSERT INTO t VALUES (5), (6), (7), (8)") + sql("INSERT INTO t VALUES (9), (10), (11), (12)") + sql("INSERT INTO t VALUES (13), (14), (15), (16)") + checkSparkAnswerAndOperator(sql("SELECT x, pidP(x) AS p FROM t")) + } + } + + test("Rand seeded per partition across a multi-partition table") { + // Rand.doGenCode registers an XORShiftRandom via ctx.addMutableState and seeds it via + // ctx.addPartitionInitializationStatement. That init statement runs inside our kernel's + // `init(int partitionIndex)`, called once per kernel allocation. Spark seeds + // `XORShiftRandom(seed + partitionIndex)` per partition, so different partitions produce + // different sequences for the same seed. 
Matching Spark across partitions requires the + // kernel to see the real partition index, which the dispatcher derives from + // `TaskContext.get().partitionId()` — live on this path thanks to the bridge-level + // TaskContext propagation. Composing with a ScalaUDF (identity on Double here) forces the + // tree through codegen dispatch so the Rand evaluation runs inside our kernel's init + // rather than via Spark's normal codegen. + spark.udf.register("dblId", (d: Double) => d) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "dblId(rand(42)) as r") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF composed with reused scalar subquery across projection and filter") { + // The same scalar subquery appears in two sites: the projection (which the dispatcher + // compiles into a fused kernel) and the filter (a separate operator). Each site holds its + // own `ScalarSubquery` expression instance with its own `@volatile result` field. Each + // surrounding operator's inherited `SparkPlan.waitForSubqueries` populates its instance's + // `result` before the dispatcher's bridge serializes the expression. The populated value + // travels through closure serialization into the cache key's bytes, so different subquery + // values compile distinct kernels. Exercises the full subquery-correctness invariant + // documented on `CometBatchKernelCodegen.canHandle`. + spark.udf.register("addOne", (i: Int) => i + 1) + withTable("t", "t2") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3), (4), (5)") + sql("CREATE TABLE t2 (v INT) USING parquet") + sql("INSERT INTO t2 VALUES (2), (4)") + checkSparkAnswerAndOperator( + sql("SELECT addOne(x) + (SELECT max(v) FROM t2) AS r " + + "FROM t WHERE addOne(x) < (SELECT max(v) FROM t2) * 2")) + } + } + + /** + * ArrayType input. 
The dispatcher emits a nested `InputArray_col0` final class per array-typed + * input column; Spark's generated `getArray(ord)` resolves to our kernel's switch which returns + * the pre-allocated instance after resetting its start/length against the list's offsets. + * Element reads go through the typed child-vector field with no `ArrayData` copy or boxing. + * + * Each smoke test exercises the same serde/transport path at a different element type so the + * nested getter emitter's scalar-element cases are each covered: `StringType` (zero-copy + * `UTF8String.fromAddress`), `IntegerType` (primitive direct), and `DecimalType(p <= 18)` + * (decimal128 fast path). + */ + private def withArrayTable(colType: String, insertRows: String)(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (a $colType) USING parquet") + sql(s"INSERT INTO t VALUES $insertRows") + f + } + } + + test("ScalaUDF taking Seq[String] reads through nested ArrayData class") { + spark.udf.register( + "headOrNull", + (arr: Seq[String]) => if (arr == null || arr.isEmpty) null else arr.head) + withArrayTable( + "ARRAY", + "(array('a', 'b', 'c')), (array('x')), (null), (array()), (array('alone'))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT headOrNull(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[String] iterating all elements") { + spark.udf.register( + "concatArr", + (arr: Seq[String]) => if (arr == null) null else arr.mkString("|")) + withArrayTable( + "ARRAY", + "(array('one', 'two', 'three')), (array('solo')), (null), (array())") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT concatArr(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[Int] hits primitive element getter") { + spark.udf.register("sumArr", (arr: Seq[Int]) => if (arr == null) -1 else arr.sum) + withArrayTable( + "ARRAY", + "(array(1, 2, 3)), (array(-5, 5)), (array()), (null), (array(42))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT 
sumArr(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[BigDecimal] hits short-precision decimal fast path") { + // DecimalType(10, 2) is well inside p <= 18, so the nested-array `getDecimal` emits the + // unscaled-long fast path (see `emitNestedArrayElementGetter`). A `BigDecimal` UDF argument + // forces Spark's encoder to call `getDecimal(i, 10, 2)` on our nested ArrayData for each + // element, which exercises that code path end to end. + spark.udf.register( + "sumDecArr", + (arr: Seq[java.math.BigDecimal]) => + if (arr == null) null + else { + var acc = java.math.BigDecimal.ZERO + arr.foreach(v => if (v != null) acc = acc.add(v)) + acc + }) + withArrayTable( + "ARRAY<DECIMAL(10, 2)>", + "(array(1.23, 4.56)), (array(-9.99)), (null), (array())") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT sumDecArr(a) FROM t")) + } + } + } + + // ============================================================================================= + // StructType + MapType + nested-composition smoke tests. Source tests prove the emitted Java + // is well-shaped; these tests prove Janino compiles it and the runtime roundtrip matches + // Spark. + // ============================================================================================= + + test("ScalaUDF composes with struct-field access reading Struct.age") { + // Keeps the UDF arg scalar (Int) but puts a `GetStructField` under it so the codegen + // dispatcher compiles the struct-input read path (`row.getStruct(0, 2).getInt(1)`). 
+ spark.udf.register("doubleInt", (i: Int) => i * 2) + withTable("t") { + sql("CREATE TABLE t (s STRUCT<name: STRING, age: INT>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'alice', 'age', 30)), " + + "(named_struct('name', 'bob', 'age', 42)), " + + "(null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT doubleInt(s.age) FROM t")) + } + } + } + + test("ScalaUDF taking full Struct value (case class arg)") { + // Case-class UDF arguments: test data must not include null top-level rows. + // `ScalaUDF.scalaConverter` applies Spark's `ExpressionEncoder.Deserializer` on every row + // to materialize the case-class instance. The generated deserializer has a + // `newInstance(NameAgePair)` step that throws `EXPRESSION_DECODING_FAILED` on a null input, + // independent of the dispatcher. Case-class UDF tests omit null top-level rows; other + // tests with plain `Seq` / `Map` args can include nulls because the deserializer hands null + // to the UDF body which handles it. 
+ spark.udf.register("fmtPair", (r: NameAgePair) => s"${r.name}:${r.age}") + withTable("t") { + sql("CREATE TABLE t (s STRUCT<name: STRING, age: INT>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'alice', 'age', 30)), " + + "(named_struct('name', 'bob', 'age', 42))") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT fmtPair(s) FROM t")) + } + } + } + + test("ScalaUDF returning Struct (case class output)") { + spark.udf.register("makePair", (i: Int) => NameAgePair(s"n$i", i)) + withTypedCol("INT", "1", "2", "3") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT makePair(c) FROM t")) + } + } + } + + test("ScalaUDF taking Map<String, Int>") { + spark.udf.register("sumMap", (m: Map[String, Int]) => if (m == null) -1 else m.values.sum) + withTable("t") { + sql("CREATE TABLE t (m MAP<STRING, INT>) USING parquet") + sql("INSERT INTO t VALUES (map('a', 1, 'b', 2)), (map()), (null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT sumMap(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map<Int, Int> (primitive key and value)") { + // Map with non-string keys: exercises the primitive-key element getter on the input side + // and the corresponding writer on the output side. Spark's encoder for `Map[Int, Int]` calls + // `getInt(0)` / `getInt(1)` on the entries struct, hitting the kernel's typed scalar getter + // for each side rather than the UTF8 path. 
+ spark.udf.register( + "incValues", + (m: Map[Int, Int]) => if (m == null) null else m.map { case (k, v) => k -> (v + 1) }) + withTable("t") { + sql("CREATE TABLE t (m MAP<INT, INT>) USING parquet") + sql("INSERT INTO t VALUES (map(1, 10, 2, 20)), (map()), (null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT incValues(m) FROM t")) + } + } + } + + test("ScalaUDF returning Map<String, Int>") { + spark.udf.register( + "singletonMap", + (s: String, i: Int) => if (s == null) null else Map(s -> i)) + withTable("t") { + sql("CREATE TABLE t (s STRING, i INT) USING parquet") + sql("INSERT INTO t VALUES ('a', 1), ('b', 2), (null, 3)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT singletonMap(s, i) FROM t")) + } + } + } + + test("ScalaUDF taking Map<String, Array<Int>> exercises nested composition") { + spark.udf.register( + "totalLens", + (m: Map[String, Seq[Int]]) => if (m == null) -1 else m.values.flatten.sum) + withTable("t") { + sql("CREATE TABLE t (m MAP<STRING, ARRAY<INT>>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', array(1, 2, 3), 'b', array(10))), " + + "(map()), " + + "(null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT totalLens(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Array<Array<Int>> (nested array input + output)") { + // Exercises nested-array input reads and nested-list output writes in one call: the inner + // `InputArray_col0_e` class on the input side and the recursive emitWrite on the output. 
+ spark.udf.register( + "reverseRows", + (arr: Seq[Seq[Int]]) => if (arr == null) null else arr.map(_.reverse)) + withTable("t") { + sql("CREATE TABLE t (a ARRAY>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(array(array(1, 2, 3), array(4, 5))), " + + "(array(array())), " + + "(null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT reverseRows(a) FROM t")) + } + } + } + + test("ScalaUDF round-trips Struct>") { + // Struct with a complex field on both sides: input reads go through InputStruct_col0 + + // InputArray_col0_f1, output writes through StructVector + ListVector. + // Null top-level rows omitted - case-class arg; see the note on `fmtPair` above. + spark.udf.register( + "growItems", + (r: NameItems) => + if (r == null) null else NameItems(r.name, if (r.items == null) null else r.items :+ 0)) + withTable("t") { + sql("CREATE TABLE t (s STRUCT>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'a', 'items', array(1, 2))), " + + "(named_struct('name', 'b', 'items', array()))") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT growItems(s) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map> (nested value both sides)") { + // Map input read goes through InputMap_col0 + InputArray_col0_v (the complex-value side); + // output write emits MapVector + entries Struct + per-value ListVector inside the map's + // entries struct. + spark.udf.register( + "sortValues", + (m: Map[String, Seq[Int]]) => + if (m == null) null + else m.map { case (k, v) => k -> (if (v == null) null else v.sorted) }) + withTable("t") { + sql("CREATE TABLE t (m MAP>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', array(3, 1, 2), 'b', array(10))), " + + "(map()), " + + "(null)") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT sortValues(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map>") { + // Struct value inside a map, both sides. 
Null top-level rows omitted - the map value is a + // case class; see the note on `fmtPair` above. + spark.udf.register( + "tagValues", + (m: Map[String, XyPair]) => + if (m == null) null + else + m.map { case (k, v) => k -> (if (v == null) null else XyPair(v.x + 1, s"<${v.y}>")) }) + withTable("t") { + sql("CREATE TABLE t (m MAP>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', named_struct('x', 1, 'y', 'one'))), " + + "(map())") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT tagValues(m) FROM t")) + } + } + } + + // ============================================================================================= + // Regression tests pinning specific kernel bugs first surfaced in CometCodegenDispatchFuzzSuite. + // Each is the smallest deterministic input that triggered the bug; kept post-fix as a guard + // against future regression. + // ============================================================================================= + + test("array_distinct on Array> retains element identity across hash set") { + // Fuzz signal: cardinality(array_distinct(arr_of_struct)) returns 1 where Spark returns 2. + // Hypothesis: the kernel's InputStruct wrapper backing array_distinct's element reads is + // reused without resetting per-element state, so every hashed element looks identical and + // distinct collapses the array to a single entry. 
+ spark.udf.register("idIntDistinct", (i: Int) => i) + withTable("t") { + sql("CREATE TABLE t (s ARRAY>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 1, 'b', 'x'))), " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 2, 'b', 'y'))), " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 2, 'b', 'y'), " + + "named_struct('a', 1, 'b', 'x')))") + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT idIntDistinct(cardinality(array_distinct(s))) FROM t")) + } + } + } + + test("array_max(flatten(arr)) on Array> with mixed null inner arrays") { + // Fuzz signal: array_max(flatten(arr)) returns empty byte arrays where Spark returns the + // actual max binary, with the empties sorting to the front of the output. Pattern points at + // cross-batch state pollution. Generate 100 rows of varied outer/inner shape, longer + // binaries, mixed nulls; force multiple batches with a small batch size. + spark.udf.register("idBinFlat", (b: Array[Byte]) => b) + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "16") { + withTable("t") { + sql("CREATE TABLE t (a ARRAY>) USING parquet") + val rows = (0 until 100).map { i => + if (i % 11 == 0) { + "(NULL)" + } else { + val outerSize = (i % 5) + 1 + val inners = (0 until outerSize).map { j => + val pick = (i * 7 + j) % 13 + if (pick == 0) "array()" + else if (pick == 1) "NULL" + else { + val innerSize = ((i + j) % 4) + 1 + val bytes = (0 until innerSize).map { k => + val len = ((i + j + k) % 8) + 1 + val hex = (0 until len) + .map(b => f"${(i * 13 + j * 17 + k * 5 + b) & 0xff}%02x") + .mkString + s"X'$hex'" + } + "array(" + bytes.mkString(", ") + ")" + } + } + s"(array(${inners.mkString(", ")}))" + } + } + sql(s"INSERT INTO t VALUES ${rows.mkString(", ")}") + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT idBinFlat(array_max(flatten(a))) FROM t")) + } + } + } + } + + // 
============================================================================================= + // Regression tests for nested reference-type getter null-handling. Spark's + // `CodeGenerator.setArrayElement` (called from e.g. `Flatten.doGenCode`) only emits an + // `isNullAt` check before `array.update(i, getX(j))` when the element is a Java primitive + // (`int`/`long`/etc.). For reference-typed elements (Binary, String, Decimal, Struct, Array, + // Map) it emits `array.update(i, getX(j))` unconditionally, relying on the source's getter to + // return `null` for null positions itself (Spark's own `ColumnarArray.getBinary` does + // `if (isNullAt(...)) return null;`). Our nested `InputArray_*.getX` getters do not honor that + // contract, so any inner null at a reference-typed position becomes an empty-bytes / empty- + // string / garbage-decimal / non-null-shell value in the flattened output. Each test below + // pins one reference-type variant so the fix can be verified per type. + // ============================================================================================= + + test("array_max(flatten(arr)) on Array> with null inner Binary returns null") { + spark.udf.register("idBin", (b: Array[Byte]) => b) + withArrayTable( + "ARRAY>", + "(array(array(NULL))), " + + "(array(array(NULL, NULL))), " + + "(array(array(), array(NULL)))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT idBin(array_max(flatten(a))) FROM t")) + } + } + } + + test("array_max(flatten(arr)) on Array> with null inner String returns null") { + spark.udf.register("idStr", (s: String) => s) + withArrayTable( + "ARRAY>", + "(array(array(NULL))), " + + "(array(array(NULL, NULL))), " + + "(array(array(), array(NULL)))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT idStr(array_max(flatten(a))) FROM t")) + } + } + } + + test( + "array_max(flatten(arr)) on Array> with null inner Decimal " + + "(short-precision fast path)") { + 
spark.udf.register("idDec10", (d: java.math.BigDecimal) => d) + withArrayTable( + "ARRAY>", + "(array(array(CAST(NULL AS DECIMAL(10, 2))))), " + + "(array(array(" + + "CAST(NULL AS DECIMAL(10, 2)), CAST(NULL AS DECIMAL(10, 2))))), " + + "(array(array(), array(CAST(NULL AS DECIMAL(10, 2)))))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT idDec10(array_max(flatten(a))) FROM t")) + } + } + } + + test( + "array_max(flatten(arr)) on Array> with null inner Decimal " + + "(long-precision slow path)") { + spark.udf.register("idDec30", (d: java.math.BigDecimal) => d) + withArrayTable( + "ARRAY>", + "(array(array(CAST(NULL AS DECIMAL(30, 2))))), " + + "(array(array(" + + "CAST(NULL AS DECIMAL(30, 2)), CAST(NULL AS DECIMAL(30, 2))))), " + + "(array(array(), array(CAST(NULL AS DECIMAL(30, 2)))))") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT idDec30(array_max(flatten(a))) FROM t")) + } + } + } + + // Note: a runtime regression test for nullable nested `getStruct` / `getArray` / `getMap` would + // need a + // non-HOF expression that reads null elements after `flatten`. Spark's optimizer rules + // (`SimplifyExtractValueOps` and friends) tend to rewrite the obvious candidates + // (`element_at(flatten(arr), 1).x`, `flatten(arr)[i].x`) into shapes our dispatcher rejects + // without a clean reason, and the only iteration paths over complex elements without + // simplification go through HOFs (`array_filter`, `transform`) which our `canHandle` rejects + // (TODO(hof-lambdas) on `CometBatchKernelCodegen`). Static coverage of the emitter for these + // three getters lives in `CometCodegenSourceSuite` instead. +} + +/** + * Case class used by the struct-input / struct-output smoke tests. Must be declared at file scope + * (not inside the test class) so Spark's TypeTag-based UDF encoder can resolve the Spark + * `StructType` schema from the Scala class. 
+ */ +private case class NameAgePair(name: String, age: Int) + +private case class NameItems(name: String, items: Seq[Int]) + +private case class XyPair(x: Int, y: String) diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala new file mode 100644 index 0000000000..c6e42d432b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -0,0 +1,1086 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.types._ + +import org.apache.comet.codegen.CometBatchKernelCodegen +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} + +// Resolve Arrow vector classes through the codegen object so tests see the same `Class` objects +// the shaded `common` module sees. A direct `classOf[org.apache.arrow.vector.VarCharVector]` here +// would be the unshaded class from the test classpath, which is not `==` to the shaded class the +// production pattern-matches against. + +/** + * Generated-source inspection tests. These exercise `CometBatchKernelCodegen.generateSource` and + * assert on the emitted Java directly, without invoking Janino. The goal is to catch regressions + * in the optimizations we claim the dispatcher applies: + * + * - `NullIntolerant` short-circuit wraps `ev.code` in `if (any-input-null) { setNull; } else { + * ev.code; write; }`. + * - Non-nullable column declaration emits `return false;` from `isNullAt(ord)` and, when the + * dispatcher rewrites the `BoundReference`, Spark's `doGenCode` stops emitting its own + * `row.isNullAt(ord)` probe. + * - Zero-copy string reads route through `UTF8String.fromAddress`. + * + * These are the smallest durable tests that the claimed optimizations actually reach the + * generated Java, and they document the shapes future contributors should preserve. 
+ */ +class CometCodegenSourceSuite extends AnyFunSuite { + + private val varCharVectorClass = + CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector") + + private val nullableString = ArrowColumnSpec(varCharVectorClass, nullable = true) + private val nonNullableString = ArrowColumnSpec(varCharVectorClass, nullable = false) + + private def gen( + expr: org.apache.spark.sql.catalyst.expressions.Expression, + specs: ArrowColumnSpec*): String = + CometBatchKernelCodegen.generateSource(expr, specs.toIndexedSeq).body + + test("non-nullable column emits literal-false isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + src.contains("case 0: return false;"), + s"expected non-nullable isNullAt to return literal false; got:\n$src") + } + + test("non-nullable BoundReference elides Spark's own isNullAt probe in the expression body") { + // When the BoundReference carries `nullable=false`, Spark's `doGenCode` skips the + // `row.isNullAt(ord)` branch at source level. This is the payoff of the tree-rewrite in + // `CometScalaUDFCodegen.lookupOrCompile`: subsequent expressions over the same column + // compile to tighter source rather than relying on JIT to constant-fold `isNullAt`. 
+ val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + !src.contains("row.isNullAt(0)"), + s"expected Spark's BoundReference null probe to be elided; got:\n$src") + } + + test("nullable column emits delegated isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("case 0: return this.col0.isNull(this.rowIdx);"), + s"expected nullable isNullAt to delegate to the Arrow vector; got:\n$src") + } + + test("VarCharVector getUTF8String uses zero-copy fromAddress") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("org.apache.spark.unsafe.types.UTF8String"), + s"expected UTF8String reference; got:\n$src") + assert(src.contains(".fromAddress("), s"expected zero-copy fromAddress read; got:\n$src") + } + + test("NullIntolerant expression emits input-null short-circuit before ev.code") { + // Upper is NullIntolerant (null in -> null out). Expect the default body to prepend + // `if (this.col0.isNull(i)) { setNull; } else { ... }` so null rows skip the whole + // expression eval, not just the setNull write. + val expr = Upper(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("this.col0.isNull(i)"), + s"expected NullIntolerant short-circuit on input ordinal 0; got:\n$src") + assert( + src.contains("output.setNull(i);"), + s"expected setNull emission for short-circuited null rows; got:\n$src") + } + + test("NullIntolerant short-circuit emitted when every node is NullIntolerant") { + // Length(Upper(BoundReference)): Length is NullIntolerant, Upper is NullIntolerant, + // BoundReference is a leaf. Every path from a leaf to the root propagates nulls, so the + // short-circuit heuristic ("any input null -> output null") holds. 
+ val expr = Length(Upper(BoundReference(0, StringType, nullable = true))) + val src = gen(expr, nullableString) + assert( + src.contains("if (this.col0.isNull(i))"), + s"expected short-circuit on col0 when every node is NullIntolerant; got:\n$src") + } + + test("NullIntolerant short-circuit skipped when a non-NullIntolerant node breaks the chain") { + // Concat is not NullIntolerant; null in some args doesn't necessarily produce a null + // result. The short-circuit heuristic would be incorrect here (short-circuiting on c0 or c1 + // being null would skip evaluation, but Concat's null handling differs). Expect the + // default path without the `if (colX.isNull(i) || colY.isNull(i))` wrapper, letting Spark's + // own `ev.code` handle nulls correctly. + val nullable1 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val nullable2 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val expr = Length( + Concat( + Seq( + BoundReference(0, StringType, nullable = true), + BoundReference(1, StringType, nullable = true)))) + val src = gen(expr, nullable1, nullable2) + assert( + !src.contains("this.col0.isNull(i) || this.col1.isNull(i)"), + "expected no pre-null short-circuit when Concat breaks the NullIntolerant chain; " + + s"got:\n$src") + } + + test("canHandle rejects CodegenFallback expressions") { + val expr = FakeCodegenFallback(BoundReference(0, StringType, nullable = true)) + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject CodegenFallback") + assert( + reason.get.contains("FakeCodegenFallback"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("canHandle accepts Nondeterministic expressions (per-partition kernel handles state)") { + // Per-partition kernel instance caching in `CometScalaUDFCodegen.ensureKernel` advances + // mutable state across batches in one partition, so Rand/Uuid/etc. produce the expected + // sequences. 
The previous canHandle rejection was conservative; with that caching in + // place, accepting Nondeterministic is correct. + val expr = FakeNondeterministic() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isEmpty, s"expected canHandle to accept Nondeterministic; got $reason") + } + + test("canHandle rejects Unevaluable expressions") { + val expr = FakeUnevaluable() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject Unevaluable") + assert( + reason.get.contains("FakeUnevaluable"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("CSE collapses a repeated subtree to one evaluation in the generated body") { + // `Add(Length(Upper(c0)), Length(Upper(c0)))` has `Length(Upper(c0))` as a common subtree. + // Length.doGenCode emits `$value.numChars()` on every Spark version the project targets, + // which makes it a stable activation marker. Upper's own doGenCode text drifts across + // versions (Spark 3.5 emits `UTF8String.toUpperCase()`, Spark 4 emits + // `CollationSupport.Upper.exec*` via collation-aware codegen), so we avoid it as a marker. + // When CSE fires, `Length(Upper(c0))` compiles into one `subExpr_*` helper whose body calls + // `numChars()` once; both uses in the `Add` read the cached result from mutable state. + // Without CSE, each Add child would emit its own `numChars()` call. + val upperOrd0 = Upper(BoundReference(0, StringType, nullable = true)) + val lenUpper = Length(upperOrd0) + val expr = Add(lenUpper, lenUpper) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val occurrences = "\\.numChars\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 1, + "expected CSE to collapse repeated Length evaluation to 1 numChars() call, " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + // Additional proof: CSE emitted a `subExpr_` helper method. 
Without CSE the generator would + // have inlined the repeated subtree into the main body with no helper at all. + assert( + result.body.contains("subExpr_0(row)"), + s"expected CSE helper invocation; got:\n${CodeFormatter.format(result.code)}") + } + + test("CSE does not fire on non-deterministic expressions (regression guard)") { + // `Add(Rand(0), Rand(0))` is two structurally identical non-deterministic subtrees. CSE must + // not collapse them: each Rand call must produce an independent draw. Spark's CSE + // (`EquivalentExpressions.updateExprInMap`) filters non-deterministic expressions via + // `expr.deterministic`, so the two Rands stay separate. This test is a regression guard + // against Spark ever relaxing that check and against us accidentally applying CSE outside + // the `generateExpressions` path (which respects the filter). `Rand.doGenCode` emits one + // `$rng.nextDouble()` call per evaluation, so two Rands produce two `.nextDouble()` calls + // in the body; one-call output would indicate incorrect CSE. + val expr = Add(Rand(Literal(0L, LongType)), Rand(Literal(0L, LongType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty) + val occurrences = "\\.nextDouble\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 2, + "expected two independent Rand evaluations (no CSE on nondeterministic), " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + } + + test("DecimalVector getDecimal specializes to unscaled-long fast path for short precision") { + // Mirrors Spark's `UnsafeRow.getDecimal` split at `Decimal.MAX_LONG_DIGITS` (18), done at + // codegen time rather than at runtime. The dispatcher reads the `BoundReference`'s + // `DecimalType` at source-generation time and emits only the fast-path branch when + // `precision <= 18`. 
The fast path reads the low 8 bytes of the 16-byte Arrow decimal128 + // slot directly as a signed long via `ArrowBuf.getLong` and wraps with + // `Decimal.createUnsafe`, avoiding the `BigDecimal` allocation `DecimalVector.getObject` + // would perform. For precision > 18 the generator emits only the slow-path branch + // (`getObject + Decimal.apply`); see the companion test below. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".createUnsafe("), + "expected Decimal.createUnsafe call on fast path; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains("Platform.getLong(") && + result.body.contains("this.col0_valueAddr"), + "expected unsafe Platform.getLong against cached valueAddr; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains(".getObject("), + "expected specialized fast path (no BigDecimal fallback branch in source); got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known short-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector getDecimal specializes to BigDecimal slow path for long precision") { + // Companion to the fast-path test. For `DecimalType(p, s)` with `p > 18`, the unscaled value + // can exceed 64 bits, so the generator emits only the `getObject + Decimal.apply` branch. + // The fast path markers must be absent so the generated source is minimal for this column. 
+ val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".getObject(") && result.body.contains(".apply("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".createUnsafe("), + "expected no fast-path emission for long-precision column; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known long-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses unscaled-long fast path for short-precision output") { + // The output writer specializes on the root expression's DecimalType precision. For + // precision <= 18 the Decimal's unscaled long is passed directly to + // `DecimalVector.setSafe(int, long)`, avoiding the BigDecimal allocation that + // `toJavaBigDecimal()` performs. Use a simple expression that produces a DecimalType output: + // `BoundReference(0, DecimalType(18, 2))` has output type DecimalType(18, 2), which is what + // the generator specializes on. 
+ val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toUnscaledLong()"), + s"expected toUnscaledLong call on fast path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toJavaBigDecimal("), + "expected no BigDecimal allocation for short-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses BigDecimal slow path for long-precision output") { + // Companion to the fast-path output test. Precision > 18 can have unscaled values exceeding + // 64 bits, so the writer must fall back to the BigDecimal path. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toJavaBigDecimal("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toUnscaledLong()"), + "expected no unscaled-long write for long-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("VarCharVector setSafe uses on-heap UTF8String shortcut") { + // The UTF8String output writer avoids the `byte[] b = $value.getBytes()` allocation when + // the UTF8String is on-heap by passing its backing byte[] directly to + // `VarCharVector.setSafe(int, byte[], int, int)`. Spark's string functions allocate their + // result on-heap, so this path hits for typical string expressions. Off-heap fallback + // (for passthrough of zero-copy input reads) stays as the else branch. 
+ // + // Markers: `getBaseObject()` (inspecting the backing), `instanceof byte[]` (the branch), + // and `Platform.BYTE_ARRAY_OFFSET` (the on-heap offset math). + val expr = Upper(BoundReference(0, StringType, nullable = true)) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + assert( + result.body.contains(".getBaseObject()"), + s"expected UTF8String.getBaseObject call; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("instanceof byte[]"), + s"expected on-heap instanceof branch; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("Platform.BYTE_ARRAY_OFFSET"), + "expected on-heap offset math via Platform.BYTE_ARRAY_OFFSET; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains(".getBytes()"), + s"expected off-heap getBytes fallback; got:\n${CodeFormatter.format(result.code)}") + } + + test("non-nullable root expression omits the `if (isNull)` branch in default body") { + // When the bound expression claims `nullable = false`, the default body drops the + // `if (ev.isNull) output.setNull(i);` guard entirely. `Length` on a non-nullable column is + // itself non-nullable (Length.nullable = child.nullable = false), so the writer goes + // straight to the setSafe/set call. A bare Length would not reach that default branch: + // Length is NullIntolerant, so its tree takes the NullIntolerant short-circuit instead. + // Exercising the default branch needs an expression that is non-nullable but whose tree + // is not fully NullIntolerant; wrapping Length in Coalesce gives exactly that shape. + // `Coalesce(Seq(Length(col_non_null), Literal(0)))` has + // nullable=false (Coalesce is non-null when any child is) and Coalesce itself is not + // NullIntolerant, so the default branch runs. Assert `setNull` is absent. 
+ val expr = Coalesce( + Seq(Length(BoundReference(0, StringType, nullable = false)), Literal(0, IntegerType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nonNullableString)) + assert( + !result.body.contains("output.setNull(i);"), + "expected no setNull for a non-nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("nullable root expression keeps the `if (isNull)` branch in default body") { + // Baseline: when the root expression is nullable, the setNull branch must still be emitted. + // Uses Coalesce with a nullable child so the Coalesce itself remains nullable. Guards the + // NonNullableOutputShortCircuit optimization against over-firing. + val expr = Coalesce( + Seq( + Length(BoundReference(0, StringType, nullable = true)), + BoundReference(1, IntegerType, nullable = true))) + val result = CometBatchKernelCodegen.generateSource( + expr, + IndexedSeq( + nullableString, + ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true))) + assert( + result.body.contains("output.setNull(i);"), + "expected setNull branch for a nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("ArrayType(StringType) output emits ListVector startNewValue/endValue recursion") { + // CreateArray over a BoundReference(StringType) produces ArrayType(StringType). emitWrite's + // ArrayType case should emit: + // - ListVector cast of output + // - child VarCharVector extraction via getDataVector + // - startNewValue + per-element loop + endValue + // - the per-element write recursing into the StringType case (which uses the UTF8 on-heap + // shortcut marker `instanceof byte[]`) + // Focus markers: ListVector cast, VarCharVector child cast, startNewValue, endValue, and + // the inner UTF8 shortcut branch. 
+ val expr = + CreateArray( + Seq(BoundReference(0, StringType, nullable = true), Literal.create("x", StringType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val src = result.body + val formatted = CodeFormatter.format(result.code) + assert(src.contains("ListVector"), s"expected ListVector in emitted body; got:\n$formatted") + assert(src.contains(".startNewValue("), s"expected startNewValue call; got:\n$formatted") + assert(src.contains(".endValue("), s"expected endValue call; got:\n$formatted") + assert( + src.contains(".getDataVector()"), + s"expected child vector extraction; got:\n$formatted") + assert( + src.contains("instanceof byte[]"), + s"expected inner UTF8 on-heap shortcut for string elements; got:\n$formatted") + } + + test("MapType output emits MapVector startNewValue/endValue + per-pair writes") { + // CreateMap produces MapType(k, v). emitWrite's MapType case should emit: + // - MapVector cast of output + // - entries StructVector extraction + // - typed key / value child casts via getChildByOrdinal(0) / (1) + // - startNewValue / endValue bracketing + // - setIndexDefined on each struct entry + // - keyArray() / valueArray() retrieval from the MapData source + // Non-null literals here mean `valueContainsNull == false`, so the value-side null guard is + // elided; the existence and elision of the `isNullAt` guard are exercised by the dedicated + // [[NullableElementElision]] tests below. 
+ val expr = CreateMap( + Seq( + Literal.create("a", StringType), + Literal(1, IntegerType), + Literal.create("b", StringType), + Literal(2, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + Seq( + "MapVector", + "StructVector", + ".startNewValue(", + ".endValue(", + ".setIndexDefined(", + ".keyArray()", + ".valueArray()").foreach { marker => + assert(src.contains(marker), s"expected $marker in MapType output emission; got:\n$src") + } + } + + test("ArrayType output elides isNullAt on the element loop when containsNull is false") { + // CreateArray over only-non-null Literals produces ArrayType(elementType, containsNull=false). + // The element write should drop the `arr.isNullAt(j)` guard at source level rather than + // relying on JIT folding. + val expr = CreateArray(Seq(Literal(1, IntegerType), Literal(2, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + assert( + !src.contains(".isNullAt("), + s"expected no isNullAt in element loop when containsNull=false; got:\n$src") + assert(src.contains(".startNewValue("), s"expected startNewValue still emitted; got:\n$src") + } + + test("ArrayType output keeps isNullAt on the element loop when containsNull is true") { + // CreateArray with at least one nullable child produces containsNull=true; the element + // null-guard must survive. 
+ val expr = + CreateArray(Seq(BoundReference(0, IntegerType, nullable = true), Literal(2, IntegerType))) + val intSpec = ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body + assert( + src.contains(".isNullAt("), + s"expected isNullAt in element loop when containsNull=true; got:\n$src") + } + + test("MapType output keeps value isNullAt when valueContainsNull is true") { + // A nullable Int BoundReference used as the CreateMap value column makes + // valueContainsNull=true. The value-side null-guard emitted for the map's value writes + // must survive. + val expr = + CreateMap( + Seq(Literal.create("a", StringType), BoundReference(0, IntegerType, nullable = true))) + val intSpec = ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body + assert( + src.contains(".isNullAt("), + s"expected isNullAt on the value-write branch when valueContainsNull=true; got:\n$src") + } + + test("ArrayType(StringType) input emits InputArray_col0 nested class with UTF8 child getter") { + // Array input with string elements: the kernel must expose a `getArray(0)` that hands Spark's + // `doGenCode` an `ArrayData` view onto the Arrow `ListVector`'s child `VarCharVector`. + // Markers: the nested class declaration with a slice constructor, the typed child getter + // using `fromAddress`, and a `getArray` switch on the ordinal that allocates a fresh view. 
+ val varCharChildSpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = varCharChildSpec) + val expr = Size(BoundReference(0, ArrayType(StringType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("class InputArray_col0"), + s"expected nested ArrayData class for array col0; got:\n$src") + assert( + src.contains("InputArray_col0(int startIdx, int len)"), + s"expected InputArray_col0 to take a slice via constructor; got:\n$src") + assert( + src.contains("getElementStartIndex(") && src.contains("getElementEndIndex("), + s"expected list-offset reads at the call site; got:\n$src") + assert( + src.contains("public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i)"), + s"expected element-type-specific UTF8String getter; got:\n$src") + assert( + src.contains(".fromAddress("), + s"expected zero-copy UTF8 read inside the nested ArrayData; got:\n$src") + assert( + src.contains("public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)"), + s"expected kernel-level getArray switch; got:\n$src") + assert( + src.contains("return new InputArray_col0("), + s"expected getArray to allocate a fresh InputArray_col0 view; got:\n$src") + } + + test("ArrayType(IntegerType) input emits primitive int getter in nested class") { + val intChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = intChildSpec) + val expr = Size(BoundReference(0, ArrayType(IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("public int getInt(int i)"), + s"expected primitive int getter on nested array class; got:\n$src") + // Scalar-element fast path reads directly 
off the typed child vector; no BigDecimal / + // fromAddress scaffolding should leak in. + assert( + !src.contains(".fromAddress("), + s"int element getter should not wrap with UTF8 fromAddress; got:\n$src") + } + + test( + "ArrayType(DecimalType) short-precision input emits decimal128 fast-path via getLong in " + + "nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(10, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(10, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + // Fast path markers: reads the low 8 bytes of the decimal128 slot via getLong + createUnsafe. + // The slow path would go through getObject + Decimal.apply. + assert( + src.contains(".getLong(") && src.contains(".createUnsafe("), + s"expected decimal-input short-precision fast path in nested class; got:\n$src") + assert( + !src.contains(".getObject("), + s"short-precision decimal element should not use BigDecimal slow path; got:\n$src") + } + + test("ArrayType(DecimalType) long-precision input emits BigDecimal slow path in nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(30, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(30, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains(".getObject(") && src.contains("Decimal$.MODULE$"), + s"expected BigDecimal slow path for p>18 element; got:\n$src") + } + + // 
============================================================================================ + // Nested-type tests. Each case verifies that a complex-within-complex shape emits a full + // nested-class tree (outer + inner), wired together through the path-suffix naming + // convention: `_e` for array element, `_f${fi}` for struct field fi. Scalar-element / scalar- + // field leaves reuse the typed-getter templates already covered by the single-depth tests. + // ============================================================================================ + + private def generate(expr: Expression, specs: IndexedSeq[ArrowColumnSpec]): String = + CometBatchKernelCodegen.generateSource(expr, specs).body + + test("Array> emits outer + inner array classes with fresh inner allocation") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = innerArray) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputArray_col0_e "), + s"expected both outer and inner array classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_e("), + s"expected outer class to allocate a fresh inner array view per call; got:\n$src") + assert( + src.contains("public int getInt(int i)"), + s"expected innermost scalar getter for IntegerType element; got:\n$src") + } + + test("Array> emits array class allocating fresh InputStruct_col0_e") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = 
true)))) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + element = innerStruct) + val elemType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val expr = Size(BoundReference(0, ArrayType(elemType), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputStruct_col0_e "), + s"expected array-of-struct nested classes; got:\n$src") + assert( + src.contains("return new InputStruct_col0_e(startIndex + i)"), + s"expected array getStruct to allocate a fresh inner struct view; got:\n$src") + } + + test("Struct> emits outer + inner struct classes") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "s", + StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + nullable = true, + innerStruct))) + val innerType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val outerType = StructType(Seq(StructField("s", innerType, nullable = true)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputStruct_col0_f0 "), + s"expected outer + inner struct classes; got:\n$src") + assert( + src.contains("return new InputStruct_col0_f0(this.rowIdx)"), + s"expected outer struct getStruct to allocate a fresh inner struct view; got:\n$src") + assert( + src.contains("public int getInt(int ordinal)"), + s"expected innermost getInt on 
InputStruct_col0_f0; got:\n$src") + } + + test("Struct> emits struct class allocating fresh InputArray_col0_f0") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = true, innerArray))) + val structType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = true)).toArray) + val expr = Size(GetStructField(BoundReference(0, structType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputArray_col0_f0 "), + s"expected struct-of-array nested classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_f0("), + s"expected struct getArray to allocate a fresh inner array view; got:\n$src") + } + + test("Map emits InputMap_col0 + keyArray / valueArray views") { + val keySpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = StringType, + valueSparkType = IntegerType, + key = keySpec, + value = valueSpec) + val expr = Size(BoundReference(0, MapType(StringType, IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + assert( + src.contains("class InputMap_col0 "), + s"expected InputMap_col0 nested class; got:\n$src") + assert( + src.contains("class InputArray_col0_k ") && src.contains("class InputArray_col0_v "), + s"expected key/value array view classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_k(this.startIndex, this.length)"), + s"expected keyArray to allocate a fresh view over the 
map slice; got:\n$src") + assert( + src.contains("return new InputArray_col0_v(this.startIndex, this.length)"), + s"expected valueArray to allocate a fresh view over the map slice; got:\n$src") + assert( + src.contains("public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)"), + s"expected kernel-level getMap switch; got:\n$src") + assert( + src.contains("return new InputMap_col0("), + s"expected getMap to allocate a fresh InputMap_col0 view; got:\n$src") + } + + test("Map, Array> emits complex key and complex value views") { + val keyElem = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val keyArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = keyElem) + val valueElem = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = valueElem) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = ArrayType(IntegerType), + valueSparkType = ArrayType(StringType), + key = keyArraySpec, + value = valueArraySpec) + val expr = Size( + BoundReference(0, MapType(ArrayType(IntegerType), ArrayType(StringType)), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + // Full chain of nested classes should appear: top-level map view, the key/value array + // views, and the inner array classes for each complex key/value element. + Seq( + "class InputMap_col0 ", + "class InputArray_col0_k ", + "class InputArray_col0_v ", + "class InputArray_col0_k_e ", + "class InputArray_col0_v_e ").foreach { marker => + assert(src.contains(marker), s"expected $marker in emission; got:\n$src") + } + } + + // ============================================================================================ + // Null-guard emission for nested reference-typed getters. 
Spark's + // `CodeGenerator.setArrayElement` only emits an `isNullAt` check before `update(i, getX(j))` + // for primitive elements. For reference types (Decimal / String / Binary / Struct / Array / + // Map) it relies on the source's `getX` to return null on null positions itself. The emitter + // honors this by prepending `if (isNullAt(...)) return null;` to those getters when the + // element / field is nullable, eliding the guard otherwise. + // + // Runtime regression coverage for the leaf reference types lives in + // `CometCodegenDispatchSmokeSuite` (Binary / String / Decimal short / Decimal long REPROs). + // The complex types (Struct / Array / Map) can't be runtime-tested without HOFs (see + // TODO(hof-lambdas) on `CometBatchKernelCodegen.canHandle`), so they live here. + // ============================================================================================ + + private val nullableIntStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + private val nullableIntStructType = + StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + + private val nullableIntArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + + private val nullableIntStrMap = MapColumnSpec( + nullable = true, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = ScalarColumnSpec(varCharVectorClass, nullable = true)) + + test("nested array of nullable Struct emits null guard before allocating InputStruct view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = nullableIntStructType, + element = 
nullableIntStruct) + val expr = Size(BoundReference(0, ArrayType(nullableIntStructType), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputStruct_col0_e(startIndex + i)"), + s"expected null guard and InputStruct alloc on nullable Struct element; got:\n$src") + } + + test("nested array of non-nullable Struct elides null guard") { + // Fully non-nullable inner spec: outer struct nullable=false AND inner Int field + // nullable=false. Without the inner field also being non-nullable the inner + // primitive-Int getter wouldn't emit a guard anyway (we only guard reference types), but + // making everything non-nullable means the broad `!src.contains("if (isNullAt(...))")` + // assertion verifies "no guards anywhere" rather than passing because the inner happens + // to be a primitive we don't guard. + val nonNullableInner = StructColumnSpec( + nullable = false, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = false, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)))) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = nullableIntStructType, + element = nonNullableInner) + val expr = Size(BoundReference(0, ArrayType(nullableIntStructType), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputStruct_col0_e(startIndex + i)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;") && + !src.contains("if (isNullAt(0)) return null;"), + s"expected no null guard anywhere on fully non-nullable Struct element; got:\n$src") + } + + test( + "nested array of nullable inner Array emits null guard before allocating InputArray view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = nullableIntArray) + val expr = Size(BoundReference(0, 
ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputArray_col0_e(__s, __e - __s)"), + s"expected null guard and InputArray alloc on nullable Array element; got:\n$src") + } + + test("nested array of non-nullable inner Array elides null guard") { + val nonNullableInner = ArrayColumnSpec( + nullable = false, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = nonNullableInner) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputArray_col0_e(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard on non-nullable inner Array element; got:\n$src") + } + + test("nested array of nullable Map emits null guard before allocating InputMap view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(IntegerType, StringType), + element = nullableIntStrMap) + val expr = + Size(BoundReference(0, ArrayType(MapType(IntegerType, StringType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputMap_col0_e(__s, __e - __s)"), + s"expected null guard and InputMap alloc on nullable Map element; got:\n$src") + } + + test("nested array of non-nullable Map elides null guard") { + val nonNullableMap = MapColumnSpec( + nullable = false, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = 
ScalarColumnSpec(varCharVectorClass, nullable = false)) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(IntegerType, StringType), + element = nonNullableMap) + val expr = + Size(BoundReference(0, ArrayType(MapType(IntegerType, StringType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputMap_col0_e(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard on non-nullable Map element; got:\n$src") + } + + test("struct with nullable struct field emits null guard in getStruct(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("s", nullableIntStructType, nullable = true, nullableIntStruct))) + val outerType = + StructType(Seq(StructField("s", nullableIntStructType, nullable = true)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputStruct_col0_f0(this.rowIdx)"), + s"expected null guard and InputStruct alloc for nullable struct field; got:\n$src") + } + + test("struct with non-nullable struct field elides null guard") { + val nonNullableInner = StructColumnSpec( + nullable = false, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = false, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)))) + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("s", nullableIntStructType, nullable = false, nonNullableInner))) + val outerType = + StructType(Seq(StructField("s", nullableIntStructType, nullable = false)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, 
Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputStruct_col0_f0(this.rowIdx)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;") && + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard anywhere on fully non-nullable struct field; got:\n$src") + } + + test("struct with nullable array field emits null guard in getArray(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = true, nullableIntArray))) + val outerType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = true)).toArray) + val expr = + Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputArray_col0_f0(__s, __e - __s)"), + s"expected null guard and InputArray alloc for nullable array field; got:\n$src") + } + + test("struct with non-nullable array field elides null guard") { + val nonNullableInner = ArrayColumnSpec( + nullable = false, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = false, nonNullableInner))) + val outerType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = false)).toArray) + val expr = + Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputArray_col0_f0(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;") && + !src.contains("if 
(isNullAt(i)) return null;"), + s"expected no null guard anywhere on fully non-nullable array field; got:\n$src") + } + + test("struct with nullable map field emits null guard in getMap(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "m", + MapType(IntegerType, StringType), + nullable = true, + nullableIntStrMap))) + val outerType = + StructType(Seq(StructField("m", MapType(IntegerType, StringType), nullable = true)).toArray) + val expr = Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("m"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputMap_col0_f0(__s, __e - __s)"), + s"expected null guard and InputMap alloc for nullable map field; got:\n$src") + } + + test("struct with non-nullable map field elides null guard") { + val nonNullableMap = MapColumnSpec( + nullable = false, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = ScalarColumnSpec(varCharVectorClass, nullable = false)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec("m", MapType(IntegerType, StringType), nullable = false, nonNullableMap))) + val outerType = StructType( + Seq(StructField("m", MapType(IntegerType, StringType), nullable = false)).toArray) + val expr = Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("m"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputMap_col0_f0(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;"), + s"expected no null guard on non-nullable map field; got:\n$src") + } +} + +/** + * Minimal fake expressions for the `canHandle` rejection tests. 
Each opts into one of the marker + * traits whose presence forces a serde-level fallback. Bodies are unreachable; `canHandle` walks + * the tree structurally. + */ +private case class FakeCodegenFallback(child: Expression) + extends Expression + with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = null + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = copy(child = newChildren.head) +} + +private case class FakeNondeterministic() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = true + + override def dataType: DataType = IntegerType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Any = 0 + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("test fake; never reaches codegen") +} + +private case class FakeUnevaluable() extends LeafExpression with Unevaluable { + override def nullable: Boolean = true + + override def dataType: DataType = IntegerType +} diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala index 9622960932..4a8629a71e 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala @@ -65,11 +65,11 @@ class CometIcebergRewriteActionSuite extends CometTestBase with CometIcebergTest } // Single-column zOrder is bit-pattern-equivalent to a natural sort (no second dimension to - // interleave with), so we expect the same ascending output as the sort test. 
The shuffle here - // is CometColumnarExchange rather than CometExchange because the z-value column is computed - // by a Spark Project (Iceberg's INTERLEAVE_BYTES / INT_ORDERED_BYTES are not recognised by - // Comet), so the path crosses a JVM-row boundary before the shuffle. - test("single-column zOrder rewrite runs scan, columnar exchange, and sort natively in Comet") { + // interleave with), so we expect the same ascending output as the sort test. Iceberg's + // `INT_ORDERED_BYTES` / `INTERLEAVE_BYTES` are `ScalaUDF`s that route through Comet's codegen + // dispatcher, so the project stays native and the shuffle picks `CometExchange` / + // `CometNativeShuffle` rather than the columnar-row roundtrip path. + test("single-column zOrder rewrite runs scan, native exchange, and sort natively in Comet") { runRewriteTest( RewriteCase( table = s"$catalog.db.zorder_test", @@ -77,7 +77,7 @@ class CometIcebergRewriteActionSuite extends CometTestBase with CometIcebergTest verifyDataAfter = assertSortedById, verifyPlans = { rewritePlans => assertReadsAreComet(rewritePlans) - assertOperator(rewritePlans, "CometColumnarExchange") + assertOperator(rewritePlans, "CometExchange") assertOperator(rewritePlans, "CometSort") })) } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala new file mode 100644 index 0000000000..a5c40c7b25 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +/** + * Benchmark user-registered ScalaUDFs composed in trees, comparing the codegen dispatcher to the + * "feature off" baseline (where a user UDF forces the containing operator to Spark) and to + * Comet's native built-ins that are functionally equivalent. + * + * Four modes per composition: + * + * - '''Spark''': all Comet disabled. + * - '''Comet (native built-ins)''': the composition rewritten using Comet-native Spark + * built-ins (`upper`, `lower`, `reverse`, `concat`, `length`). Ceiling for what pure native + * can do. + * - '''Comet (user UDFs, dispatcher disabled)''': user UDFs with + * `codegenDispatch.mode=disabled`. `CometScalaUDF.convert` returns `None`, the ScalaUDF's + * Project falls back to Spark. This is the state before the dispatcher landed: any user UDF + * loses Comet acceleration on the whole hosting operator. + * - '''Comet (user UDFs, codegen dispatch)''': user UDFs with the dispatcher forced on. One + * Janino-compiled kernel per (tree, input schema) handles the whole composition in one JNI + * hop. + * + * Story the numbers should tell: dispatcher (mode 4) tracks native (mode 2) and beats + * dispatcher-disabled (mode 3) by the cost of the Spark fallback / ColumnarToRow hand-off. 
+ * + * To run: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 \ + * make benchmark-org.apache.spark.sql.benchmark.CometScalaUDFCompositionBenchmark + * }}} + */ +object CometScalaUDFCompositionBenchmark extends CometBenchmarkBase { + + private def registerThreeLevelUdfs(): Unit = { + spark.udf.register("lvl1_upper", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lvl2_reverse", (s: String) => if (s == null) null else s.reverse) + spark.udf.register("lvl3_length", (s: String) => if (s == null) -1 else s.length) + } + + private def registerMultiColUdfs(): Unit = { + spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase) + spark.udf.register( + "joinU", + (a: String, b: String) => if (a == null || b == null) null else s"$a-$b") + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("scalaudf composition", 1024 * 1024) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + registerThreeLevelUdfs() + runBenchmark("three-level composition: length(reverse(upper(c1)))") { + runModes( + name = "three-level", + cardinality = v, + nativeQuery = "SELECT length(reverse(upper(c1))) FROM parquetV1Table", + udfQuery = "SELECT lvl3_length(lvl2_reverse(lvl1_upper(c1))) FROM parquetV1Table") + } + } + } + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql( + "SELECT REPEAT(CAST(value AS STRING), 10) AS c1, " + + s"CAST(value AS STRING) AS c2 FROM $tbl")) + + registerMultiColUdfs() + runBenchmark("multi-col composition: concat(upper(c1), '-', lower(c2))") { + runModes( + name = "multi-col", + cardinality = v, + nativeQuery = "SELECT concat(upper(c1), '-', lower(c2)) FROM parquetV1Table", + udfQuery = "SELECT joinU(upperU(c1), lowerU(c2)) FROM 
parquetV1Table") + } + } + } + + // Aggregate shape: SUM over the composition output. Picks up the cost of "dispatcher + // disabled" breaking the columnar pipeline around an aggregate, not just the Project + // itself. When the dispatcher is off, the Project falls back to Spark, which typically + // drags the surrounding HashAggregate off Comet's columnar path too (ColumnarToRow hand-off + // plus Spark's row-based aggregate). When the dispatcher is on, scan -> project -> agg + // stays columnar end to end. + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + registerThreeLevelUdfs() + runBenchmark("agg over composition: SUM(length(reverse(upper(c1))))") { + runModes( + name = "agg-over-composition", + cardinality = v, + nativeQuery = "SELECT SUM(length(reverse(upper(c1)))) FROM parquetV1Table", + udfQuery = + "SELECT SUM(lvl3_length(lvl2_reverse(lvl1_upper(c1)))) FROM parquetV1Table") + } + } + } + } + } + + private def runModes( + name: String, + cardinality: Long, + nativeQuery: String, + udfQuery: String): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase("Spark") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(udfQuery).noop() + } + } + + // Pure Comet-native rewrite of the composition using built-ins. Ceiling for native perf. + // Case conversion is enabled because upper/lower are in the tree. + benchmark.addCase("Comet (native built-ins)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") { + spark.sql(nativeQuery).noop() + } + } + + // User UDFs with dispatcher disabled. The ScalaUDF serde returns None, the hosting Project + // falls back to Spark. State of the world before the dispatcher landed: any ScalaUDF in a + // query sinks the containing operator. 
+ benchmark.addCase("Comet (user UDFs, dispatcher disabled)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { + spark.sql(udfQuery).noop() + } + } + + // User UDFs through the codegen dispatcher. One Janino-compiled kernel for the whole tree, + // one JNI hop per batch. + benchmark.addCase("Comet (user UDFs, codegen dispatch)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") { + spark.sql(udfQuery).noop() + } + } + + benchmark.run() + } +}