Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.iceberg.BaseMetadataTable;
import org.apache.iceberg.BaseTable;
Expand All @@ -28,6 +29,8 @@
import org.apache.iceberg.IncrementalAppendScan;
import org.apache.iceberg.MetricsConfig;
import org.apache.iceberg.MetricsModes;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.ScanTask;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
Expand All @@ -41,12 +44,15 @@
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkAggregates;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.spark.TimeTravel;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.Pair;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
Expand Down Expand Up @@ -157,6 +163,10 @@ public boolean pushAggregation(Aggregation aggregation) {
return false;
}

if (aggregation.groupByExpressions().length > 0) {
return pushGroupByAggregation(aggregation, expressions);
}

try (CloseableIterable<FileScanTask> fileScanTasks = planFilesWithStats()) {
for (FileScanTask task : fileScanTasks) {
if (!task.deletes().isEmpty()) {
Expand Down Expand Up @@ -186,6 +196,214 @@ public boolean pushAggregation(Aggregation aggregation) {
return true;
}

/**
* Push down aggregation with GROUP BY on identity partition columns. When all GROUP BY columns
* are identity partition fields, aggregates can be computed from file metadata grouped by
* partition values, avoiding reading any data files.
*/
/**
 * Push down aggregation with GROUP BY on identity partition columns. When all GROUP BY columns
 * are identity partition fields, aggregates can be computed from file metadata grouped by
 * partition values, avoiding reading any data files.
 */
private boolean pushGroupByAggregation(
    Aggregation aggregation, List<BoundAggregate<?, ?>> boundAggregates) {
  Schema schema = table().schema();

  // Bind each GROUP BY column to its source field id rather than a spec position, so the
  // grouping survives partition-spec evolution across files.
  List<Integer> sourceIds = Lists.newArrayList();
  List<Types.NestedField> fields = Lists.newArrayList();
  boolean resolved = resolveGroupByFields(aggregation, schema, sourceIds, fields);
  if (!resolved) {
    return false;
  }

  // Fold file-level stats into one evaluator per distinct partition-value tuple; a null result
  // means pushdown is not possible (deletes, missing stats, or spec mismatch).
  Map<List<Object>, AggregateEvaluator> grouped = groupFilesByPartition(sourceIds, boundAggregates);
  if (grouped == null) {
    return false;
  }

  this.localScan = buildGroupedLocalScan(fields, grouped);
  return this.localScan != null;
}

/**
 * Resolves each GROUP BY expression to a table column that is an identity partition field of the
 * current partition spec, accumulating the matched source field ids and schema fields into the
 * supplied output lists.
 *
 * @param aggregation the aggregation whose GROUP BY expressions are resolved
 * @param tableSchema schema used to look up columns by name
 * @param groupBySourceIds output list receiving the source field id of each GROUP BY column
 * @param groupByFields output list receiving the schema field of each GROUP BY column
 * @return true when every GROUP BY expression resolved to an identity partition column; false
 *     (with the output lists possibly partially filled) when pushdown must be skipped
 */
private boolean resolveGroupByFields(
    Aggregation aggregation,
    Schema tableSchema,
    List<Integer> groupBySourceIds,
    List<Types.NestedField> groupByFields) {
  PartitionSpec currentSpec = table().spec();
  for (org.apache.spark.sql.connector.expressions.Expression groupByExpr :
      aggregation.groupByExpressions()) {
    // guard the cast: GROUP BY may contain non-column expressions (e.g. function calls),
    // which cannot be mapped to a partition field
    if (!(groupByExpr instanceof org.apache.spark.sql.connector.expressions.NamedReference)) {
      LOG.info("Skipping grouped aggregate pushdown: {} is not a column reference", groupByExpr);
      return false;
    }

    String colName =
        SparkUtil.toColumnName(
            (org.apache.spark.sql.connector.expressions.NamedReference) groupByExpr);
    Types.NestedField sourceField = tableSchema.findField(colName);
    if (sourceField == null) {
      LOG.info("Skipping grouped aggregate pushdown: cannot find field {}", colName);
      return false;
    }

    // verify the field is an identity partition in the current spec
    if (findIdentityPartitionPosition(currentSpec, sourceField.fieldId()) < 0) {
      LOG.info(
          "Skipping grouped aggregate pushdown: {} is not an identity partition field", colName);
      return false;
    }

    groupBySourceIds.add(sourceField.fieldId());
    groupByFields.add(sourceField);
  }

  return true;
}

/**
 * Scans file metadata (with column stats) and folds each data file into one
 * {@code AggregateEvaluator} per distinct tuple of GROUP BY partition values.
 *
 * @param groupBySourceIds source field ids of the GROUP BY columns, in GROUP BY order
 * @param boundAggregates the aggregates each per-partition evaluator computes
 * @return map from partition-value tuple to its evaluator, in first-seen file order; {@code null}
 *     when pushdown must be abandoned (row-level deletes, a file spec missing a GROUP BY field,
 *     an IO failure, no matching files, or stats insufficient for some aggregate)
 */
private Map<List<Object>, AggregateEvaluator> groupFilesByPartition(
    List<Integer> groupBySourceIds, List<BoundAggregate<?, ?>> boundAggregates) {
  // LinkedHashMap keeps result rows in deterministic first-seen order
  Map<List<Object>, AggregateEvaluator> evaluatorsByPartition = Maps.newLinkedHashMap();

  try (CloseableIterable<FileScanTask> fileScanTasks = planFilesWithStats()) {
    for (FileScanTask task : fileScanTasks) {
      // row-level deletes invalidate file-level stats, so metadata-only aggregation is unsafe
      if (!task.deletes().isEmpty()) {
        LOG.info("Skipping grouped aggregate pushdown: detected row level deletes");
        return null;
      }

      // resolve partition values using the file's own spec (handles spec evolution)
      PartitionSpec fileSpec = table().specs().get(task.file().specId());
      StructLike partition = task.file().partition();
      List<Object> key = Lists.newArrayListWithCapacity(groupBySourceIds.size());

      for (int sourceId : groupBySourceIds) {
        // position of this GROUP BY column within this file's partition tuple
        int pos = findIdentityPartitionPosition(fileSpec, sourceId);
        if (pos < 0) {
          // file written under a spec that doesn't identity-partition this column; bail out
          LOG.info(
              "Skipping grouped aggregate pushdown: field {} not in spec {}",
              sourceId,
              fileSpec.specId());
          return null;
        }
        key.add(partition.get(pos, Object.class));
      }

      // lazily create one evaluator per partition key, then fold in this file's stats
      evaluatorsByPartition
          .computeIfAbsent(key, k -> AggregateEvaluator.create(boundAggregates))
          .update(task.file());
    }
  } catch (IOException e) {
    // planning failure: fall back to a regular scan rather than propagate
    LOG.info("Skipping grouped aggregate pushdown: ", e);
    return null;
  }

  // no files matched the filter: nothing to push down
  if (evaluatorsByPartition.isEmpty()) {
    return null;
  }

  // every evaluator must have complete stats; a single invalid aggregate aborts pushdown
  for (AggregateEvaluator evaluator : evaluatorsByPartition.values()) {
    if (!evaluator.allAggregatorsValid()) {
      return null;
    }
  }

  return evaluatorsByPartition;
}

/**
 * Builds a metadata-only local scan whose rows are the grouped aggregate results: each output
 * row is the GROUP BY partition values followed by that partition's aggregate values.
 */
private SparkLocalScan buildGroupedLocalScan(
    List<Types.NestedField> groupByFields,
    Map<List<Object>, AggregateEvaluator> evaluatorsByPartition) {
  // All evaluators share one result type; take it from any of them.
  AggregateEvaluator representative = evaluatorsByPartition.values().iterator().next();

  // Output schema: GROUP BY columns first, then aggregate columns, with synthetic field ids.
  List<Types.NestedField> outputFields = Lists.newArrayList();
  int nextId = 0;
  for (Types.NestedField groupByField : groupByFields) {
    outputFields.add(
        Types.NestedField.optional(nextId, groupByField.name(), groupByField.type()));
    nextId += 1;
  }
  for (Types.NestedField aggField : representative.resultType().fields()) {
    outputFields.add(Types.NestedField.optional(nextId, aggField.name(), aggField.type()));
    nextId += 1;
  }

  Types.StructType outputType = Types.StructType.of(outputFields);
  int groupByCount = groupByFields.size();
  List<InternalRow> rows = Lists.newArrayList();

  for (Map.Entry<List<Object>, AggregateEvaluator> entry : evaluatorsByPartition.entrySet()) {
    List<Object> partitionValues = entry.getKey();
    StructLike aggregates = entry.getValue().result();

    // Concatenate partition values and aggregate values into one flat row.
    Object[] rowValues = new Object[outputFields.size()];
    int idx = 0;
    while (idx < groupByCount) {
      rowValues[idx] = partitionValues.get(idx);
      idx += 1;
    }
    for (int aggIdx = 0; aggIdx < aggregates.size(); aggIdx++) {
      rowValues[groupByCount + aggIdx] = aggregates.get(aggIdx, Object.class);
    }

    rows.add(new StructInternalRow(outputType).setStruct(new ArrayStructLike(rowValues)));
  }

  StructType pushedSchema = SparkSchemaUtil.convert(new Schema(outputFields));
  return new SparkLocalScan(table(), pushedSchema, rows.toArray(new InternalRow[0]), filters());
}

/**
 * Returns the position within {@code spec}'s partition tuple of the identity partition field
 * sourced from {@code sourceFieldId}, or -1 when the spec has no such identity field.
 */
private int findIdentityPartitionPosition(PartitionSpec spec, int sourceFieldId) {
  int position = 0;
  for (PartitionField partitionField : spec.fields()) {
    boolean sameSource = partitionField.sourceId() == sourceFieldId;
    if (sameSource && partitionField.transform().isIdentity()) {
      return position;
    }
    position++;
  }

  return -1;
}

private boolean allGroupByAreIdentityPartitionFields(Aggregation aggregation) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allGroupByAreIdentityPartitionFields() and resolveGroupByFields() look very similar except

  • allGroupByAreIdentityPartitionFields additionally checks instanceof NamedReference
  • resolveGroupByFields additionally collects field IDs and fields into output lists
    Can we merge these two?

Or maybe let canPushDownAggregation() allow group by and then have the checks in this merged method? What do you think?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Merged allGroupByAreIdentityPartitionFields into resolveGroupByFields, removed the separate method, and simplified canPushDownAggregation accordingly.

PartitionSpec spec = table().spec();
Schema tableSchema = table().schema();

for (org.apache.spark.sql.connector.expressions.Expression groupByExpr :
aggregation.groupByExpressions()) {
if (!(groupByExpr instanceof org.apache.spark.sql.connector.expressions.NamedReference)) {
return false;
}

String colName =
SparkUtil.toColumnName(
(org.apache.spark.sql.connector.expressions.NamedReference) groupByExpr);
Types.NestedField sourceField = tableSchema.findField(colName);
if (sourceField == null) {
return false;
}

if (findIdentityPartitionPosition(spec, sourceField.fieldId()) < 0) {
return false;
}
}

return true;
}

private static class ArrayStructLike implements StructLike {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use AggregateEvaluator.ArrayStructLike instead? May have to make it package-private.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AggregateEvaluator.ArrayStructLike is private static in the api module. Since SparkScanBuilder is in spark module, I assume even package-private wouldn't help, we'd need to make it public. Kept the changes same to avoid API surface changes. Happy to follow up separately if preferred.

private final Object[] values;

ArrayStructLike(Object[] values) {
this.values = values;
}

@Override
public int size() {
return values.length;
}

@Override
@SuppressWarnings("unchecked")
public <T> T get(int pos, Class<T> javaClass) {
return (T) values[pos];
}

@Override
public <T> void set(int pos, T value) {
values[pos] = value;
}
}

private boolean canPushDownAggregation(Aggregation aggregation) {
if (!isMainTable()) {
return false;
Expand All @@ -195,12 +413,12 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
return false;
}

// If group by expression is the same as the partition, the statistics information can still
// be used to calculate min/max/count, will enable aggregate push down in next phase.
// TODO: enable aggregate push down for partition col group by expression
if (aggregation.groupByExpressions().length > 0) {
LOG.info("Skipping aggregate pushdown: group by aggregation push down is not supported");
return false;
if (!allGroupByAreIdentityPartitionFields(aggregation)) {
LOG.info(
"Skipping aggregate pushdown: group by columns must all be identity partition fields");
return false;
}
}

return true;
Expand Down
Loading
Loading