diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index 6423ee4076d9..64d6a9b3d7df 100644
--- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import org.apache.iceberg.BaseMetadataTable;
 import org.apache.iceberg.BaseTable;
@@ -28,6 +29,8 @@
 import org.apache.iceberg.IncrementalAppendScan;
 import org.apache.iceberg.MetricsConfig;
 import org.apache.iceberg.MetricsModes;
+import org.apache.iceberg.PartitionField;
+import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.ScanTask;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Snapshot;
@@ -41,12 +44,15 @@
 import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.spark.Spark3Util;
 import org.apache.iceberg.spark.SparkAggregates;
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.spark.SparkTableUtil;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.spark.TimeTravel;
 import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.Pair;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -157,6 +163,10 @@ public boolean pushAggregation(Aggregation aggregation) {
       return false;
     }
 
+    if (aggregation.groupByExpressions().length > 0) {
+      return pushGroupByAggregation(aggregation, expressions);
+    }
+
     try (CloseableIterable<FileScanTask> fileScanTasks = planFilesWithStats()) {
       for (FileScanTask task : fileScanTasks) {
         if (!task.deletes().isEmpty()) {
@@ -186,6 +196,214 @@ public boolean pushAggregation(Aggregation aggregation) {
     return true;
   }
 
+  /**
+   * Push down aggregation with GROUP BY on identity partition columns. When all GROUP BY columns
+   * are identity partition fields, aggregates can be computed from file metadata grouped by
+   * partition values, avoiding reading any data files.
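+   *
+   * <p>For example, {@code SELECT category, count(*) FROM t GROUP BY category} on a table
+   * partitioned by {@code category} yields one row per distinct partition value, computed
+   * entirely from per-file statistics.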
+   */
+  private boolean pushGroupByAggregation(
+      Aggregation aggregation, List<BoundAggregate<?, ?>> boundAggregates) {
+    Schema tableSchema = table().schema();
+
+    // resolve GROUP BY columns to source field IDs (not positions, for spec evolution safety)
+    List<Integer> groupBySourceIds = Lists.newArrayList();
+    List<Types.NestedField> groupByFields = Lists.newArrayList();
+    if (!resolveGroupByFields(aggregation, tableSchema, groupBySourceIds, groupByFields)) {
+      return false;
+    }
+
+    Map<List<Object>, AggregateEvaluator> evaluatorsByPartition =
+        groupFilesByPartition(groupBySourceIds, boundAggregates);
+    if (evaluatorsByPartition == null) {
+      return false;
+    }
+
+    localScan = buildGroupedLocalScan(groupByFields, evaluatorsByPartition);
+    return localScan != null;
+  }
+
+  private boolean resolveGroupByFields(
+      Aggregation aggregation,
+      Schema tableSchema,
+      List<Integer> groupBySourceIds,
+      List<Types.NestedField> groupByFields) {
+    PartitionSpec currentSpec = table().spec();
+    for (org.apache.spark.sql.connector.expressions.Expression groupByExpr :
+        aggregation.groupByExpressions()) {
+      // safe cast: canPushDownAggregation already verified these are named references
+      String colName =
+          SparkUtil.toColumnName(
+              (org.apache.spark.sql.connector.expressions.NamedReference) groupByExpr);
+      Types.NestedField sourceField = tableSchema.findField(colName);
+      if (sourceField == null) {
+        LOG.info("Skipping grouped aggregate pushdown: cannot find field {}", colName);
+        return false;
+      }
+
+      // verify the field is an identity partition in the current spec
+      if (findIdentityPartitionPosition(currentSpec, sourceField.fieldId()) < 0) {
+        LOG.info(
+            "Skipping grouped aggregate pushdown: {} is not an identity partition field", colName);
+        return false;
+      }
+
+      groupBySourceIds.add(sourceField.fieldId());
+      groupByFields.add(sourceField);
+    }
+
+    return true;
+  }
+
+  private Map<List<Object>, AggregateEvaluator> groupFilesByPartition(
+      List<Integer> groupBySourceIds, List<BoundAggregate<?, ?>> boundAggregates) {
+    Map<List<Object>, AggregateEvaluator> evaluatorsByPartition = Maps.newLinkedHashMap();
+
+    try (CloseableIterable<FileScanTask> fileScanTasks = planFilesWithStats()) {
+      for (FileScanTask task : fileScanTasks) {
+        if (!task.deletes().isEmpty()) {
+          LOG.info("Skipping grouped aggregate pushdown: detected row level deletes");
+          return null;
+        }
+
+        // resolve partition values using the file's own spec (handles spec evolution)
+        PartitionSpec fileSpec = table().specs().get(task.file().specId());
+        StructLike partition = task.file().partition();
+        List<Object> key = Lists.newArrayListWithCapacity(groupBySourceIds.size());
+
+        for (int sourceId : groupBySourceIds) {
+          int pos = findIdentityPartitionPosition(fileSpec, sourceId);
+          if (pos < 0) {
+            LOG.info(
+                "Skipping grouped aggregate pushdown: field {} not in spec {}",
+                sourceId,
+                fileSpec.specId());
+            return null;
+          }
+          key.add(partition.get(pos, Object.class));
+        }
+
+        evaluatorsByPartition
+            .computeIfAbsent(key, k -> AggregateEvaluator.create(boundAggregates))
+            .update(task.file());
+      }
+    } catch (IOException e) {
+      LOG.info("Skipping grouped aggregate pushdown: ", e);
+      return null;
+    }
+
+    if (evaluatorsByPartition.isEmpty()) {
+      return null;
+    }
+
+    for (AggregateEvaluator evaluator : evaluatorsByPartition.values()) {
+      if (!evaluator.allAggregatorsValid()) {
+        return null;
+      }
+    }
+
+    return evaluatorsByPartition;
+  }
+
+  private SparkLocalScan buildGroupedLocalScan(
+      List<Types.NestedField> groupByFields,
+      Map<List<Object>, AggregateEvaluator> evaluatorsByPartition) {
+    AggregateEvaluator firstEvaluator = evaluatorsByPartition.values().iterator().next();
+    List<Types.NestedField> resultFields = Lists.newArrayList();
+    int fieldId = 0;
+
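+    // result row layout: group-by columns first, then the aggregate result columns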
+    for (Types.NestedField field : groupByFields) {
+      resultFields.add(Types.NestedField.optional(fieldId++, field.name(), field.type()));
+    }
+
+    for (Types.NestedField field : firstEvaluator.resultType().fields()) {
+      resultFields.add(Types.NestedField.optional(fieldId++, field.name(), field.type()));
+    }
+
+    Types.StructType resultType = Types.StructType.of(resultFields);
+    List<InternalRow> resultRows = Lists.newArrayList();
+
+    for (Map.Entry<List<Object>, AggregateEvaluator> entry : evaluatorsByPartition.entrySet()) {
+      List<Object> partitionValues = entry.getKey();
+      StructLike aggResult = entry.getValue().result();
+
+      Object[] combined = new Object[resultFields.size()];
+      for (int i = 0; i < partitionValues.size(); i++) {
+        combined[i] = partitionValues.get(i);
+      }
+
+      for (int i = 0; i < aggResult.size(); i++) {
+        combined[partitionValues.size() + i] = aggResult.get(i, Object.class);
+      }
+
+      resultRows.add(new StructInternalRow(resultType).setStruct(new ArrayStructLike(combined)));
+    }
+
+    StructType pushedSchema = SparkSchemaUtil.convert(new Schema(resultFields));
+    return new SparkLocalScan(
+        table(), pushedSchema, resultRows.toArray(new InternalRow[0]), filters());
+  }
+
+  private int findIdentityPartitionPosition(PartitionSpec spec, int sourceFieldId) {
+    List<PartitionField> fields = spec.fields();
+    for (int i = 0; i < fields.size(); i++) {
+      PartitionField field = fields.get(i);
+      if (field.sourceId() == sourceFieldId && field.transform().isIdentity()) {
+        return i;
+      }
+    }
+
+    return -1;
+  }
+
+  private boolean allGroupByAreIdentityPartitionFields(Aggregation aggregation) {
+    PartitionSpec spec = table().spec();
+    Schema tableSchema = table().schema();
+
+    for (org.apache.spark.sql.connector.expressions.Expression groupByExpr :
+        aggregation.groupByExpressions()) {
+      if (!(groupByExpr instanceof org.apache.spark.sql.connector.expressions.NamedReference)) {
+        return false;
+      }
+
+      String colName =
+          SparkUtil.toColumnName(
+              (org.apache.spark.sql.connector.expressions.NamedReference) groupByExpr);
+      Types.NestedField sourceField = tableSchema.findField(colName);
+      if (sourceField == null) {
+        return false;
+      }
+
+      if (findIdentityPartitionPosition(spec, sourceField.fieldId()) < 0) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
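+  // minimal StructLike view over an Object[]; adapts the combined group-by and aggregate
+  // values for StructInternalRow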
+  private static class ArrayStructLike implements StructLike {
+    private final Object[] values;
+
+    ArrayStructLike(Object[] values) {
+      this.values = values;
+    }
+
+    @Override
+    public int size() {
+      return values.length;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <T> T get(int pos, Class<T> javaClass) {
+      return (T) values[pos];
+    }
+
+    @Override
+    public <T> void set(int pos, T value) {
+      values[pos] = value;
+    }
+  }
+
   private boolean canPushDownAggregation(Aggregation aggregation) {
     if (!isMainTable()) {
       return false;
@@ -195,12 +413,12 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
       return false;
     }
 
-    // If group by expression is the same as the partition, the statistics information can still
-    // be used to calculate min/max/count, will enable aggregate push down in next phase.
-    // TODO: enable aggregate push down for partition col group by expression
     if (aggregation.groupByExpressions().length > 0) {
-      LOG.info("Skipping aggregate pushdown: group by aggregation push down is not supported");
-      return false;
+      if (!allGroupByAreIdentityPartitionFields(aggregation)) {
+        LOG.info(
+            "Skipping aggregate pushdown: group by columns must all be identity partition fields");
+        return false;
+      }
     }
 
     return true;
diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 6eac5474afde..38c63a5444db 100644
--- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -568,11 +568,9 @@ public void testAggregationPushdownOnBucketedColumn() {
     sql(
         "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1: INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
         tableName);
-
     sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
     sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName);
     sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
-
     String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
     String aggField = "id";
     assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L);
@@ -909,4 +907,183 @@ public void testAggregatePushDownForIncrementalScan() {
     assertEquals(
         "min/max/count push down", expected2, rowsToJava(unboundedPushdownDs.collectAsList()));
   }
+
+  @TestTemplate
+  public void testGroupByIdentityPartitionColumnCountPushDown() {
+    sql(
+        "CREATE TABLE %s (id LONG, data STRING, category STRING) USING iceberg PARTITIONED BY (category)",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES "
+            + "(1, 'a', 'fruit'), (2, 'b', 'fruit'), (3, 'c', 'fruit'),"
+            + "(4, 'd', 'veggie'), (5, 'e', 'veggie'),"
+            + "(6, 'f', 'dairy')",
+        tableName);
+
+    String select = "SELECT category, count(*) FROM %s GROUP BY category";
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(3);
+
+    actual.sort((a, b) -> ((String) a[0]).compareTo((String) b[0]));
+
+    assertEquals(
+        "group by partition count push down",
+        Lists.newArrayList(
+            new Object[] {"dairy", 1L}, new Object[] {"fruit", 3L}, new Object[] {"veggie", 2L}),
+        actual);
+  }
+
+  @TestTemplate
+  public void testGroupByIdentityPartitionColumnWithMinMax() {
+    sql(
+        "CREATE TABLE %s (id LONG, price DOUBLE, region STRING) USING iceberg PARTITIONED BY (region)",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES "
+            + "(1, 10.0, 'east'), (2, 20.0, 'east'), (3, 30.0, 'east'),"
+            + "(4, 5.0, 'west'), (5, 50.0, 'west'),"
+            + "(6, 100.0, 'north')",
+        tableName);
+
+    String select = "SELECT region, count(*), max(price), min(price) FROM %s GROUP BY region";
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(3);
+
+    actual.sort((a, b) -> ((String) a[0]).compareTo((String) b[0]));
+
+    assertEquals(
+        "group by partition with min/max push down",
+        Lists.newArrayList(
+            new Object[] {"east", 3L, 30.0, 10.0},
+            new Object[] {"north", 1L, 100.0, 100.0},
+            new Object[] {"west", 2L, 50.0, 5.0}),
+        actual);
+  }
+
+  @TestTemplate
+  public void testGroupByNonPartitionColumnNoPushDown() {
+    sql(
+        "CREATE TABLE %s (id LONG, data STRING, category STRING) USING iceberg PARTITIONED BY (category)",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, 'a', 'x'), (2, 'b', 'x'), (3, 'a', 'y')", tableName);
+
+    String select = "SELECT data, count(*) FROM %s GROUP BY data";
+
+    List<Object[]> explain = sql("EXPLAIN " + select, tableName);
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+
+    assertThat(explainString)
+        .as("non-partition group by should not be pushed down to LocalTableScan")
+        .doesNotContain("localtablescan");
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(2);
+  }
+
+  @TestTemplate
+  public void testGroupByPartitionColumnMultipleInserts() {
+    sql(
+        "CREATE TABLE %s (id LONG, amount DOUBLE, city STRING) USING iceberg PARTITIONED BY (city)",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, 100.0, 'NYC'), (2, 200.0, 'NYC')", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, 50.0, 'NYC'), (4, 300.0, 'SF')", tableName);
+    sql("INSERT INTO TABLE %s VALUES (5, 150.0, 'SF'), (6, 75.0, 'LA')", tableName);
+
+    String select = "SELECT city, count(*), max(amount), min(amount) FROM %s GROUP BY city";
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(3);
+
+    actual.sort((a, b) -> ((String) a[0]).compareTo((String) b[0]));
+
+    assertEquals(
+        "group by partition with multiple files per partition",
+        Lists.newArrayList(
+            new Object[] {"LA", 1L, 75.0, 75.0},
+            new Object[] {"NYC", 3L, 200.0, 50.0},
+            new Object[] {"SF", 2L, 300.0, 150.0}),
+        actual);
+  }
+
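+  // spec evolution: group-by columns are matched to partition fields by source field id, not
+  // by position, so files written under the old and new specs fall into the same groups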
+  @TestTemplate
+  public void testGroupByPartitionColumnAfterSpecEvolution() {
+    sql(
+        "CREATE TABLE %s (id LONG, amount DOUBLE, region STRING) USING iceberg PARTITIONED BY (region)",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, 10.0, 'east'), (2, 20.0, 'west')", tableName);
+
+    validationCatalog.loadTable(tableIdent).updateSpec().addField("id").commit();
+
+    sql("INSERT INTO TABLE %s VALUES (3, 30.0, 'east'), (4, 40.0, 'west')", tableName);
+
+    String select = "SELECT region, count(*), max(amount) FROM %s GROUP BY region";
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(2);
+
+    actual.sort((a, b) -> ((String) a[0]).compareTo((String) b[0]));
+
+    assertEquals(
+        "group by partition after spec evolution",
+        Lists.newArrayList(new Object[] {"east", 2L, 30.0}, new Object[] {"west", 2L, 40.0}),
+        actual);
+  }
+
+  @TestTemplate
+  public void testGroupByPartitionRemovedInNewSpecFallsBack() {
+    sql(
+        "CREATE TABLE %s (id LONG, amount DOUBLE, region STRING) USING iceberg PARTITIONED BY (region)",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, 10.0, 'east'), (2, 20.0, 'west')", tableName);
+
+    validationCatalog
+        .loadTable(tableIdent)
+        .updateSpec()
+        .removeField("region")
+        .addField("id")
+        .commit();
+
+    sql("INSERT INTO TABLE %s VALUES (3, 30.0, 'east'), (4, 40.0, 'west')", tableName);
+
+    String select = "SELECT region, count(*) FROM %s GROUP BY region";
+
+    List<Object[]> actual = sql(select, tableName);
+    actual.sort((a, b) -> ((String) a[0]).compareTo((String) b[0]));
+    assertEquals(
+        "group by with removed partition field",
+        Lists.newArrayList(new Object[] {"east", 2L}, new Object[] {"west", 2L}),
+        actual);
+  }
+
+  @TestTemplate
+  public void testGroupByPartitionColumnWithNullValues() {
+    sql(
+        "CREATE TABLE %s (id LONG, amount DOUBLE, region STRING) USING iceberg PARTITIONED BY (region)",
+        tableName);
+
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 10.0, 'east'), (2, 20.0, null), (3, 30.0, 'east'), (4, 40.0, null)",
+        tableName);
+
+    String select = "SELECT region, count(*), max(amount) FROM %s GROUP BY region";
+
+    List<Object[]> actual = sql(select, tableName);
+    assertThat(actual).hasSize(2);
+
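+    // null partition values sort first so the expected row order is deterministic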
+    actual.sort(
+        (a, b) -> {
+          if (a[0] == null) {
+            return -1;
+          }
+
+          if (b[0] == null) {
+            return 1;
+          }
+
+          return ((String) a[0]).compareTo((String) b[0]);
+        });
+
+    assertEquals(
+        "group by partition with null values",
+        Lists.newArrayList(new Object[] {null, 2L, 40.0}, new Object[] {"east", 2L, 30.0}),
+        actual);
+  }
 }