diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java index 31ef246eb328..152ed32469da 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java @@ -51,13 +51,20 @@ public class AggregationOperator extends BaseOperator { private final long _numTotalDocs; private int _numDocsScanned = 0; + protected final Object[] _preAggregatedResults; public AggregationOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs) { + this(queryContext, aggregationInfo, numTotalDocs, null); + } + + public AggregationOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs, + Object[] preAggregatedResults) { _queryContext = queryContext; _aggregationFunctions = queryContext.getAggregationFunctions(); _projectOperator = aggregationInfo.getProjectOperator(); _useStarTree = aggregationInfo.isUseStarTree(); _numTotalDocs = numTotalDocs; + _preAggregatedResults = preAggregatedResults; } @Override @@ -67,7 +74,7 @@ protected AggregationResultsBlock getNextBlock() { if (_useStarTree) { aggregationExecutor = new StarTreeAggregationExecutor(_aggregationFunctions); } else { - aggregationExecutor = new DefaultAggregationExecutor(_aggregationFunctions); + aggregationExecutor = new DefaultAggregationExecutor(_aggregationFunctions, _preAggregatedResults); } ValueBlock valueBlock; while ((valueBlock = _projectOperator.nextBlock()) != null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java index 783fd85129c8..4a34d600f14f 100644 --- 
a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java @@ -18,43 +18,17 @@ */ package org.apache.pinot.core.operator.query; -import com.clearspring.analytics.stream.cardinality.HyperLogLog; -import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus; -import com.dynatrace.hash4j.distinctcount.UltraLogLog; -import com.google.common.base.Preconditions; -import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; -import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; -import it.unimi.dsi.fastutil.ints.IntOpenHashSet; -import it.unimi.dsi.fastutil.longs.LongOpenHashSet; -import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Objects; -import java.util.Set; -import org.apache.pinot.core.common.ObjectSerDeUtils; import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.ExecutionStatistics; import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountHLLAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountHLLPlusAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountOffHeapAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLPlusAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountSmartHLLAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountSmartHLLPlusAggregationFunction; -import 
org.apache.pinot.core.query.aggregation.function.DistinctCountSmartULLAggregationFunction; -import org.apache.pinot.core.query.aggregation.function.DistinctCountULLAggregationFunction; +import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.request.context.QueryContext; -import org.apache.pinot.segment.local.customobject.MinMaxRangePair; -import org.apache.pinot.segment.local.utils.UltraLogLogUtils; import org.apache.pinot.segment.spi.datasource.DataSource; -import org.apache.pinot.segment.spi.index.reader.Dictionary; -import org.apache.pinot.spi.data.FieldSpec.DataType; -import org.apache.pinot.spi.query.QueryThreadContext; -import org.apache.pinot.spi.utils.ByteArray; /** @@ -74,7 +48,7 @@ */ @SuppressWarnings("rawtypes") public class NonScanBasedAggregationOperator extends BaseOperator { - private static final String EXPLAIN_NAME = "AGGREGATE_NO_SCAN"; + public static final String EXPLAIN_NAME = "AGGREGATE_NO_SCAN"; private final QueryContext _queryContext; private final AggregationFunction[] _aggregationFunctions; @@ -95,97 +69,8 @@ protected AggregationResultsBlock getNextBlock() { AggregationFunction aggregationFunction = _aggregationFunctions[i]; // note that dataSource will be null for COUNT, sp do not interact with it until it's known this isn't a COUNT DataSource dataSource = _dataSources[i]; - Object result; - switch (aggregationFunction.getType()) { - case COUNT: - result = (long) _numTotalDocs; - break; - case MIN: - case MINMV: - result = getMinValueNumeric(dataSource); - break; - case MINLONG: - result = getMinValueLong(dataSource); - break; - case MINSTRING: - assert dataSource.getDictionary() != null; - result = dataSource.getDictionary().getMinVal(); - break; - case MAX: - case MAXMV: - result = getMaxValueNumeric(dataSource); - break; - case MAXLONG: - result = getMaxValueLong(dataSource); - break; - case MAXSTRING: - assert dataSource.getDictionary() != null; - result = 
dataSource.getDictionary().getMaxVal(); - break; - case MINMAXRANGE: - case MINMAXRANGEMV: - result = new MinMaxRangePair(getMinValueNumeric(dataSource), getMaxValueNumeric(dataSource)); - break; - case DISTINCTCOUNT: - case DISTINCTSUM: - case DISTINCTAVG: - case DISTINCTCOUNTMV: - case DISTINCTSUMMV: - case DISTINCTAVGMV: - result = getDistinctValueSet(Objects.requireNonNull(dataSource.getDictionary())); - break; - case DISTINCTCOUNTOFFHEAP: - result = ((DistinctCountOffHeapAggregationFunction) aggregationFunction).extractAggregationResult( - Objects.requireNonNull(dataSource.getDictionary())); - break; - case DISTINCTCOUNTHLL: - case DISTINCTCOUNTHLLMV: - result = getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountHLLAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTRAWHLL: - case DISTINCTCOUNTRAWHLLMV: - result = getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()), - ((DistinctCountRawHLLAggregationFunction) aggregationFunction).getDistinctCountHLLAggregationFunction()); - break; - case DISTINCTCOUNTHLLPLUS: - case DISTINCTCOUNTHLLPLUSMV: - result = getDistinctCountHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountHLLPlusAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTRAWHLLPLUS: - case DISTINCTCOUNTRAWHLLPLUSMV: - result = getDistinctCountHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), - ((DistinctCountRawHLLPlusAggregationFunction) aggregationFunction) - .getDistinctCountHLLPlusAggregationFunction()); - break; - case SEGMENTPARTITIONEDDISTINCTCOUNT: - result = (long) Objects.requireNonNull(dataSource.getDictionary()).length(); - break; - case DISTINCTCOUNTSMARTHLL: - result = getDistinctCountSmartHLLResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountSmartHLLAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTSMARTHLLPLUS: - result = 
getDistinctCountSmartHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountSmartHLLPlusAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTULL: - result = getDistinctCountULLResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountULLAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTSMARTULL: - result = getDistinctCountSmartULLResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountSmartULLAggregationFunction) aggregationFunction); - break; - case DISTINCTCOUNTRAWULL: - result = getDistinctCountULLResult(Objects.requireNonNull(dataSource.getDictionary()), - (DistinctCountULLAggregationFunction) aggregationFunction); - break; - default: - throw new IllegalStateException( - "Non-scan based aggregation operator does not support function type: " + aggregationFunction.getType()); - } + Object result = AggregationFunctionUtils.getAggregationResult(aggregationFunction, dataSource, + _numTotalDocs, EXPLAIN_NAME); aggregationResults.add(result); } @@ -193,229 +78,6 @@ protected AggregationResultsBlock getNextBlock() { return new AggregationResultsBlock(_aggregationFunctions, aggregationResults, _queryContext); } - private static Double getMinValueNumeric(DataSource dataSource) { - Dictionary dictionary = dataSource.getDictionary(); - if (dictionary != null) { - return toDouble(dictionary.getMinVal()); - } - return toDouble(dataSource.getDataSourceMetadata().getMinValue()); - } - - private static Long getMinValueLong(DataSource dataSource) { - DataType dataType = dataSource.getDataSourceMetadata().getDataType().getStoredType(); - Preconditions.checkArgument( - dataType == DataType.LONG || dataType == DataType.INT, - "MINLONG aggregation function can only be applied to columns of integer types"); - Dictionary dictionary = dataSource.getDictionary(); - if (dictionary != null) { - return ((Number) dictionary.getMinVal()).longValue(); - } - return ((Number) 
dataSource.getDataSourceMetadata().getMinValue()).longValue(); - } - - private static Double getMaxValueNumeric(DataSource dataSource) { - Dictionary dictionary = dataSource.getDictionary(); - if (dictionary != null) { - return toDouble(dictionary.getMaxVal()); - } - return toDouble(dataSource.getDataSourceMetadata().getMaxValue()); - } - - private static Long getMaxValueLong(DataSource dataSource) { - DataType dataType = dataSource.getDataSourceMetadata().getDataType().getStoredType(); - Preconditions.checkArgument( - dataType == DataType.LONG || dataType == DataType.INT, - "MAXLONG aggregation function can only be applied to columns of integer types"); - Dictionary dictionary = dataSource.getDictionary(); - if (dictionary != null) { - return ((Number) dictionary.getMaxVal()).longValue(); - } - return ((Number) dataSource.getDataSourceMetadata().getMaxValue()).longValue(); - } - - private static Double toDouble(Comparable value) { - if (value instanceof Double) { - return (Double) value; - } else if (value instanceof Number) { - return ((Number) value).doubleValue(); - } else { - return Double.parseDouble(value.toString()); - } - } - - private static Set getDistinctValueSet(Dictionary dictionary) { - int dictionarySize = dictionary.length(); - switch (dictionary.getValueType()) { - case INT: - IntOpenHashSet intSet = new IntOpenHashSet(dictionarySize); - for (int dictId = 0; dictId < dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - intSet.add(dictionary.getIntValue(dictId)); - } - return intSet; - case LONG: - LongOpenHashSet longSet = new LongOpenHashSet(dictionarySize); - for (int dictId = 0; dictId < dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - longSet.add(dictionary.getLongValue(dictId)); - } - return longSet; - case FLOAT: - FloatOpenHashSet floatSet = new FloatOpenHashSet(dictionarySize); - for (int dictId = 0; dictId < 
dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - floatSet.add(dictionary.getFloatValue(dictId)); - } - return floatSet; - case DOUBLE: - DoubleOpenHashSet doubleSet = new DoubleOpenHashSet(dictionarySize); - for (int dictId = 0; dictId < dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - doubleSet.add(dictionary.getDoubleValue(dictId)); - } - return doubleSet; - case STRING: - ObjectOpenHashSet stringSet = new ObjectOpenHashSet<>(dictionarySize); - for (int dictId = 0; dictId < dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - stringSet.add(dictionary.getStringValue(dictId)); - } - return stringSet; - case BYTES: - ObjectOpenHashSet bytesSet = new ObjectOpenHashSet<>(dictionarySize); - for (int dictId = 0; dictId < dictionarySize; dictId++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, EXPLAIN_NAME); - bytesSet.add(new ByteArray(dictionary.getBytesValue(dictId))); - } - return bytesSet; - default: - throw new IllegalStateException(); - } - } - - private static HyperLogLog getDistinctValueHLL(Dictionary dictionary, int log2m) { - HyperLogLog hll = new HyperLogLog(log2m); - int length = dictionary.length(); - for (int i = 0; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - hll.offer(dictionary.get(i)); - } - return hll; - } - - private static UltraLogLog getDistinctValueULL(Dictionary dictionary, int p) { - UltraLogLog ull = UltraLogLog.create(p); - int length = dictionary.length(); - for (int i = 0; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - Object value = dictionary.get(i); - UltraLogLogUtils.hashObject(value).ifPresent(ull::add); - } - return ull; - } - - private static HyperLogLogPlus getDistinctValueHLLPlus(Dictionary 
dictionary, int p, int sp) { - HyperLogLogPlus hllPlus = new HyperLogLogPlus(p, sp); - int length = dictionary.length(); - for (int i = 0; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - hllPlus.offer(dictionary.get(i)); - } - return hllPlus; - } - - private static HyperLogLog getDistinctCountHLLResult(Dictionary dictionary, - DistinctCountHLLAggregationFunction function) { - if (dictionary.getValueType() == DataType.BYTES) { - // Treat BYTES value as serialized HyperLogLog - try { - QueryThreadContext.checkTerminationAndSampleUsage(EXPLAIN_NAME); - HyperLogLog hll = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(dictionary.getBytesValue(0)); - int length = dictionary.length(); - for (int i = 1; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - hll.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(dictionary.getBytesValue(i))); - } - return hll; - } catch (Exception e) { - throw new RuntimeException("Caught exception while merging HyperLogLogs", e); - } - } else { - return getDistinctValueHLL(dictionary, function.getLog2m()); - } - } - - private static HyperLogLogPlus getDistinctCountHLLPlusResult(Dictionary dictionary, - DistinctCountHLLPlusAggregationFunction function) { - if (dictionary.getValueType() == DataType.BYTES) { - // Treat BYTES value as serialized HyperLogLogPlus - try { - QueryThreadContext.checkTerminationAndSampleUsage(EXPLAIN_NAME); - HyperLogLogPlus hllplus = ObjectSerDeUtils.HYPER_LOG_LOG_PLUS_SER_DE.deserialize(dictionary.getBytesValue(0)); - int length = dictionary.length(); - for (int i = 1; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - hllplus.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_PLUS_SER_DE.deserialize(dictionary.getBytesValue(i))); - } - return hllplus; - } catch (Exception e) { - throw new RuntimeException("Caught exception while merging HyperLogLogPluses", e); - } 
- } else { - return getDistinctValueHLLPlus(dictionary, function.getP(), function.getSp()); - } - } - - private static Object getDistinctCountSmartHLLResult(Dictionary dictionary, - DistinctCountSmartHLLAggregationFunction function) { - if (dictionary.length() > function.getThreshold()) { - // Store values into a HLL when the dictionary size exceeds the conversion threshold - return getDistinctValueHLL(dictionary, function.getLog2m()); - } else { - return getDistinctValueSet(dictionary); - } - } - - private static Object getDistinctCountSmartHLLPlusResult(Dictionary dictionary, - DistinctCountSmartHLLPlusAggregationFunction function) { - if (dictionary.length() > function.getThreshold()) { - // Store values into a HLLPlus when the dictionary size exceeds the conversion threshold - return getDistinctValueHLLPlus(dictionary, function.getP(), function.getSp()); - } else { - return getDistinctValueSet(dictionary); - } - } - - private static UltraLogLog getDistinctCountULLResult(Dictionary dictionary, - DistinctCountULLAggregationFunction function) { - if (dictionary.getValueType() == DataType.BYTES) { - // Treat BYTES value as serialized UltraLogLog and merge - try { - QueryThreadContext.checkTerminationAndSampleUsage(EXPLAIN_NAME); - UltraLogLog ull = ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize(dictionary.getBytesValue(0)); - int length = dictionary.length(); - for (int i = 1; i < length; i++) { - QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, EXPLAIN_NAME); - ull.add(ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize(dictionary.getBytesValue(i))); - } - return ull; - } catch (Exception e) { - throw new RuntimeException("Caught exception while merging UltraLogLogs", e); - } - } else { - return getDistinctValueULL(dictionary, function.getP()); - } - } - - private static Object getDistinctCountSmartULLResult(Dictionary dictionary, - DistinctCountSmartULLAggregationFunction function) { - if (dictionary.length() > 
function.getThreshold()) { - return getDistinctValueULL(dictionary, function.getP()); - } else { - return getDistinctValueSet(dictionary); - } - } - @Override public String toExplainString() { return EXPLAIN_NAME; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index 055a366749be..9527c6901721 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -115,17 +115,38 @@ public Operator buildNonFilteredAggOperator() { boolean hasNullValues = _queryContext.isNullHandlingEnabled() && hasNullValues(aggregationFunctions); if (!hasNullValues) { + DataSource[] dataSources = new DataSource[aggregationFunctions.length]; + for (int i = 0; i < aggregationFunctions.length; i++) { + List inputExpressions = aggregationFunctions[i].getInputExpressions(); + if (!inputExpressions.isEmpty()) { + String column = ((ExpressionContext) inputExpressions.get(0)).getIdentifier(); + dataSources[i] = _indexSegment.getDataSource(column, _queryContext.getSchema()); + } + } + // Priority 2: Check if non-scan based aggregation is feasible if (filterOperator.isResultMatchingAll() && isFitForNonScanBasedPlan()) { - DataSource[] dataSources = new DataSource[aggregationFunctions.length]; + return new NonScanBasedAggregationOperator(_queryContext, dataSources, numTotalDocs); + } + + if (filterOperator.isResultMatchingAll()) { + boolean anyResolved = false; + Object[] preAggregatedResults = new Object[aggregationFunctions.length]; for (int i = 0; i < aggregationFunctions.length; i++) { - List inputExpressions = aggregationFunctions[i].getInputExpressions(); - if (!inputExpressions.isEmpty()) { - String column = ((ExpressionContext) inputExpressions.get(0)).getIdentifier(); - dataSources[i] = _indexSegment.getDataSource(column, _queryContext.getSchema()); + Object 
resolved = AggregationFunctionUtils.getAggregationResult(aggregationFunctions[i], + dataSources[i], numTotalDocs, NonScanBasedAggregationOperator.EXPLAIN_NAME); + if (resolved != null) { + preAggregatedResults[i] = resolved; + anyResolved = true; } } - return new NonScanBasedAggregationOperator(_queryContext, dataSources, numTotalDocs); + + if (anyResolved) { + // build aggregation info for all functions (including those that cannot be resolved from metadata) + aggregationInfo = AggregationFunctionUtils.buildAggregationInfoWithoutStarTree(_segmentContext, _queryContext, + aggregationFunctions, filterOperator); + return new AggregationOperator(_queryContext, aggregationInfo, numTotalDocs, preAggregatedResults); + } } // Priority 3: Check if fast filtered count can be used diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java index d6c64c1cfb68..e0e471a4fee0 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java @@ -29,8 +29,10 @@ public class DefaultAggregationExecutor implements AggregationExecutor { protected final AggregationFunction[] _aggregationFunctions; protected final AggregationResultHolder[] _aggregationResultHolders; + protected final Object[] _preAggregatedResults; - public DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions) { + public DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions, Object[] preAggregatedResults) { + _preAggregatedResults = preAggregatedResults; _aggregationFunctions = aggregationFunctions; int numAggregationFunctions = aggregationFunctions.length; _aggregationResultHolders = new AggregationResultHolder[numAggregationFunctions]; @@ -39,11 +41,18 @@ public 
DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions) { } } + public DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions) { + this(aggregationFunctions, null); + } + @Override public void aggregate(ValueBlock valueBlock) { int numAggregationFunctions = _aggregationFunctions.length; int length = valueBlock.getNumDocs(); for (int i = 0; i < numAggregationFunctions; i++) { + if (_preAggregatedResults != null && _preAggregatedResults[i] != null) { + continue; // skip — already resolved from metadata + } AggregationFunction aggregationFunction = _aggregationFunctions[i]; aggregationFunction.aggregate(length, _aggregationResultHolders[i], AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, valueBlock)); @@ -55,7 +64,11 @@ public List getResult() { int numFunctions = _aggregationFunctions.length; List aggregationResults = new ArrayList<>(numFunctions); for (int i = 0; i < numFunctions; i++) { - aggregationResults.add(_aggregationFunctions[i].extractAggregationResult(_aggregationResultHolders[i])); + if (_preAggregatedResults != null && _preAggregatedResults[i] != null) { + aggregationResults.add(_preAggregatedResults[i]); + } else { + aggregationResults.add(_aggregationFunctions[i].extractAggregationResult(_aggregationResultHolders[i])); + } } return aggregationResults; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java index 79913301a3df..6b05c1036230 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java @@ -18,11 +18,20 @@ */ package org.apache.pinot.core.query.aggregation.function; +import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import 
com.clearspring.analytics.stream.cardinality.HyperLogLogPlus; +import com.dynatrace.hash4j.distinctcount.UltraLogLog; +import com.google.common.base.Preconditions; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; import it.unimi.dsi.fastutil.floats.FloatArrayList; +import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.longs.LongArrayList; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.objects.ObjectArrayList; +import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Collections; @@ -30,6 +39,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.Pair; @@ -41,6 +51,7 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.config.QueryOptionsUtils; import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.common.ObjectSerDeUtils; import org.apache.pinot.core.operator.BaseProjectOperator; import org.apache.pinot.core.operator.blocks.ValueBlock; import org.apache.pinot.core.operator.filter.BaseFilterOperator; @@ -51,9 +62,16 @@ import org.apache.pinot.core.plan.ProjectPlanNode; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.startree.StarTreeUtils; +import org.apache.pinot.segment.local.customobject.MinMaxRangePair; +import org.apache.pinot.segment.local.utils.UltraLogLogUtils; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.segment.spi.SegmentContext; +import org.apache.pinot.segment.spi.datasource.DataSource; +import org.apache.pinot.segment.spi.index.reader.Dictionary; import 
org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.query.QueryThreadContext; +import org.apache.pinot.spi.utils.ByteArray; /** @@ -61,6 +79,7 @@ */ @SuppressWarnings({"rawtypes", "unchecked"}) public class AggregationFunctionUtils { + private AggregationFunctionUtils() { } @@ -450,4 +469,333 @@ public static String getResultColumnName(AggregationFunction aggregationFunction } return columnName; } + + /** + * Gets the aggregation result without scanning the segment. + * This is used for non-scan based aggregation operator. + * @param aggregationFunction + * @param dataSource + * @param numTotalDocs + * @return + */ + public static Object getAggregationResult(AggregationFunction aggregationFunction, DataSource dataSource, + int numTotalDocs, String explainName) { + Object result; + switch (aggregationFunction.getType()) { + case COUNT: + result = (long) numTotalDocs; + break; + case MIN: + case MINMV: + result = getMinValueNumeric(dataSource); + break; + case MINLONG: + result = getMinValueLong(dataSource); + break; + case MINSTRING: + assert dataSource.getDictionary() != null; + result = dataSource.getDictionary().getMinVal(); + break; + case MAX: + case MAXMV: + result = getMaxValueNumeric(dataSource); + break; + case MAXLONG: + result = getMaxValueLong(dataSource); + break; + case MAXSTRING: + assert dataSource.getDictionary() != null; + result = dataSource.getDictionary().getMaxVal(); + break; + case MINMAXRANGE: + case MINMAXRANGEMV: + result = new MinMaxRangePair(getMinValueNumeric(dataSource), getMaxValueNumeric(dataSource)); + break; + case DISTINCTCOUNT: + case DISTINCTSUM: + case DISTINCTAVG: + case DISTINCTCOUNTMV: + case DISTINCTSUMMV: + case DISTINCTAVGMV: + result = getDistinctValueSet(Objects.requireNonNull(dataSource.getDictionary()), explainName); + break; + case DISTINCTCOUNTOFFHEAP: + result = ((DistinctCountOffHeapAggregationFunction) 
aggregationFunction).extractAggregationResult( + Objects.requireNonNull(dataSource.getDictionary())); + break; + case DISTINCTCOUNTHLL: + case DISTINCTCOUNTHLLMV: + result = getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountHLLAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTRAWHLL: + case DISTINCTCOUNTRAWHLLMV: + result = getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()), + ((DistinctCountRawHLLAggregationFunction) aggregationFunction).getDistinctCountHLLAggregationFunction(), + explainName); + break; + case DISTINCTCOUNTHLLPLUS: + case DISTINCTCOUNTHLLPLUSMV: + result = getDistinctCountHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountHLLPlusAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTRAWHLLPLUS: + case DISTINCTCOUNTRAWHLLPLUSMV: + result = getDistinctCountHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), + ((DistinctCountRawHLLPlusAggregationFunction) aggregationFunction) + .getDistinctCountHLLPlusAggregationFunction(), explainName); + break; + case SEGMENTPARTITIONEDDISTINCTCOUNT: + result = (long) Objects.requireNonNull(dataSource.getDictionary()).length(); + break; + case DISTINCTCOUNTSMARTHLL: + result = getDistinctCountSmartHLLResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountSmartHLLAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTSMARTHLLPLUS: + result = getDistinctCountSmartHLLPlusResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountSmartHLLPlusAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTULL: + result = getDistinctCountULLResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountULLAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTSMARTULL: + result = 
getDistinctCountSmartULLResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountSmartULLAggregationFunction) aggregationFunction, explainName); + break; + case DISTINCTCOUNTRAWULL: + result = getDistinctCountULLResult(Objects.requireNonNull(dataSource.getDictionary()), + (DistinctCountULLAggregationFunction) aggregationFunction, explainName); + break; + default: + throw new IllegalStateException( + "Non-scan based aggregation operator does not support function type: " + aggregationFunction.getType()); + } + + return result; + } + + private static Double getMinValueNumeric(DataSource dataSource) { + Dictionary dictionary = dataSource.getDictionary(); + if (dictionary != null) { + return toDouble(dictionary.getMinVal()); + } + return toDouble(dataSource.getDataSourceMetadata().getMinValue()); + } + + private static Long getMinValueLong(DataSource dataSource) { + FieldSpec.DataType dataType = dataSource.getDataSourceMetadata().getDataType().getStoredType(); + Preconditions.checkArgument( + dataType == FieldSpec.DataType.LONG || dataType == FieldSpec.DataType.INT, + "MINLONG aggregation function can only be applied to columns of integer types"); + Dictionary dictionary = dataSource.getDictionary(); + if (dictionary != null) { + return ((Number) dictionary.getMinVal()).longValue(); + } + return ((Number) dataSource.getDataSourceMetadata().getMinValue()).longValue(); + } + + private static Double getMaxValueNumeric(DataSource dataSource) { + Dictionary dictionary = dataSource.getDictionary(); + if (dictionary != null) { + return toDouble(dictionary.getMaxVal()); + } + return toDouble(dataSource.getDataSourceMetadata().getMaxValue()); + } + + private static Long getMaxValueLong(DataSource dataSource) { + FieldSpec.DataType dataType = dataSource.getDataSourceMetadata().getDataType().getStoredType(); + Preconditions.checkArgument( + dataType == FieldSpec.DataType.LONG || dataType == FieldSpec.DataType.INT, + "MAXLONG aggregation function can only be 
applied to columns of integer types");
    // NOTE(review): the lines above/below complete a method whose signature precedes this chunk —
    // presumably a max-value lookup for an integer-typed column; confirm against the full file.
    Dictionary dictionary = dataSource.getDictionary();
    if (dictionary != null) {
      // Dictionary-encoded column: the dictionary itself tracks the max value.
      return ((Number) dictionary.getMaxVal()).longValue();
    }
    // No dictionary: fall back to the data-source metadata's max value.
    return ((Number) dataSource.getDataSourceMetadata().getMaxValue()).longValue();
  }

  /**
   * Converts a {@link Comparable} value (from dictionary or metadata) to a {@link Double}.
   * Fast paths for {@code Double} and other {@code Number} subclasses; any other type is
   * converted via {@code Double.parseDouble} on its string form.
   */
  private static Double toDouble(Comparable value) {
    if (value instanceof Double) {
      return (Double) value;
    } else if (value instanceof Number) {
      return ((Number) value).doubleValue();
    } else {
      return Double.parseDouble(value.toString());
    }
  }

  /**
   * Collects every value of the dictionary into a hash set specialized for the dictionary's
   * stored type (primitive fastutil sets for INT/LONG/FLOAT/DOUBLE, object sets for
   * STRING/BYTES). Checks query termination periodically while iterating dict ids.
   *
   * @throws IllegalStateException for any dictionary value type not handled above
   */
  private static Set getDistinctValueSet(Dictionary dictionary, String explainName) {
    int dictionarySize = dictionary.length();
    switch (dictionary.getValueType()) {
      case INT:
        IntOpenHashSet intSet = new IntOpenHashSet(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          intSet.add(dictionary.getIntValue(dictId));
        }
        return intSet;
      case LONG:
        LongOpenHashSet longSet = new LongOpenHashSet(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          longSet.add(dictionary.getLongValue(dictId));
        }
        return longSet;
      case FLOAT:
        FloatOpenHashSet floatSet = new FloatOpenHashSet(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          floatSet.add(dictionary.getFloatValue(dictId));
        }
        return floatSet;
      case DOUBLE:
        DoubleOpenHashSet doubleSet = new DoubleOpenHashSet(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          doubleSet.add(dictionary.getDoubleValue(dictId));
        }
        return doubleSet;
      case STRING:
        ObjectOpenHashSet stringSet = new ObjectOpenHashSet<>(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          stringSet.add(dictionary.getStringValue(dictId));
        }
        return stringSet;
      case BYTES:
        // Wrap raw byte[] in ByteArray so set membership uses value equality, not identity.
        ObjectOpenHashSet bytesSet = new ObjectOpenHashSet<>(dictionarySize);
        for (int dictId = 0; dictId < dictionarySize; dictId++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(dictId, explainName);
          bytesSet.add(new ByteArray(dictionary.getBytesValue(dictId)));
        }
        return bytesSet;
      default:
        throw new IllegalStateException();
    }
  }

  /**
   * Builds a {@link HyperLogLog} of the given log2m by offering every dictionary value.
   * Checks query termination periodically while iterating.
   */
  private static HyperLogLog getDistinctValueHLL(Dictionary dictionary, int log2m, String explainName) {
    HyperLogLog hll = new HyperLogLog(log2m);
    int length = dictionary.length();
    for (int i = 0; i < length; i++) {
      QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
      hll.offer(dictionary.get(i));
    }
    return hll;
  }

  /**
   * Builds an {@link UltraLogLog} of precision {@code p} over every dictionary value.
   * Values whose hash is absent (hashObject returns empty) are silently skipped.
   */
  private static UltraLogLog getDistinctValueULL(Dictionary dictionary, int p, String explainName) {
    UltraLogLog ull = UltraLogLog.create(p);
    int length = dictionary.length();
    for (int i = 0; i < length; i++) {
      QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
      Object value = dictionary.get(i);
      UltraLogLogUtils.hashObject(value).ifPresent(ull::add);
    }
    return ull;
  }

  /**
   * Builds a {@link HyperLogLogPlus} with precisions {@code p}/{@code sp} by offering every
   * dictionary value. Checks query termination periodically while iterating.
   */
  private static HyperLogLogPlus getDistinctValueHLLPlus(Dictionary dictionary, int p, int sp, String explainName) {
    HyperLogLogPlus hllPlus = new HyperLogLogPlus(p, sp);
    int length = dictionary.length();
    for (int i = 0; i < length; i++) {
      QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
      hllPlus.offer(dictionary.get(i));
    }
    return hllPlus;
  }

  /**
   * Computes the DISTINCT_COUNT_HLL intermediate result from the dictionary. For a BYTES
   * column every entry is deserialized as a HyperLogLog and merged via addAll; otherwise a
   * fresh HLL is built from the raw values.
   *
   * <p>NOTE(review): dict id 0 is read unconditionally in the BYTES branch — assumes a
   * non-empty dictionary; confirm callers guarantee this.
   *
   * @throws RuntimeException wrapping any deserialization/merge failure
   */
  private static HyperLogLog getDistinctCountHLLResult(Dictionary dictionary,
      DistinctCountHLLAggregationFunction function, String explainName) {
    if (dictionary.getValueType() == FieldSpec.DataType.BYTES) {
      // Treat BYTES value as serialized HyperLogLog
      try {
        QueryThreadContext.checkTerminationAndSampleUsage(explainName);
        HyperLogLog hll = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(dictionary.getBytesValue(0));
        int length = dictionary.length();
        for (int i = 1; i < length; i++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
          hll.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(dictionary.getBytesValue(i)));
        }
        return hll;
      } catch (Exception e) {
        throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
      }
    } else {
      return getDistinctValueHLL(dictionary, function.getLog2m(), explainName);
    }
  }

  /**
   * Computes the DISTINCT_COUNT_HLL_PLUS intermediate result from the dictionary. BYTES
   * entries are deserialized as HyperLogLogPlus and merged; other types build a fresh
   * HLL++ from the raw values. Mirrors {@code getDistinctCountHLLResult}.
   *
   * @throws RuntimeException wrapping any deserialization/merge failure
   */
  private static HyperLogLogPlus getDistinctCountHLLPlusResult(Dictionary dictionary,
      DistinctCountHLLPlusAggregationFunction function, String explainName) {
    if (dictionary.getValueType() == FieldSpec.DataType.BYTES) {
      // Treat BYTES value as serialized HyperLogLogPlus
      try {
        QueryThreadContext.checkTerminationAndSampleUsage(explainName);
        HyperLogLogPlus hllplus = ObjectSerDeUtils.HYPER_LOG_LOG_PLUS_SER_DE.deserialize(dictionary.getBytesValue(0));
        int length = dictionary.length();
        for (int i = 1; i < length; i++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
          hllplus.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_PLUS_SER_DE.deserialize(dictionary.getBytesValue(i)));
        }
        return hllplus;
      } catch (Exception e) {
        throw new RuntimeException("Caught exception while merging HyperLogLogPluses", e);
      }
    } else {
      return getDistinctValueHLLPlus(dictionary, function.getP(), function.getSp(), explainName);
    }
  }

  /**
   * DISTINCT_COUNT_SMART_HLL: exact set below the function's threshold, approximate HLL above it.
   */
  private static Object getDistinctCountSmartHLLResult(Dictionary dictionary,
      DistinctCountSmartHLLAggregationFunction function, String explainPlanName) {
    if (dictionary.length() > function.getThreshold()) {
      // Store values into a HLL when the dictionary size exceeds the conversion threshold
      return getDistinctValueHLL(dictionary, function.getLog2m(), explainPlanName);
    } else {
      return getDistinctValueSet(dictionary, explainPlanName);
    }
  }

  /**
   * DISTINCT_COUNT_SMART_HLL_PLUS: exact set below the threshold, approximate HLL++ above it.
   */
  private static Object getDistinctCountSmartHLLPlusResult(Dictionary dictionary,
      DistinctCountSmartHLLPlusAggregationFunction function, String explainName) {
    if (dictionary.length() > function.getThreshold()) {
      // Store values into a HLLPlus when the dictionary size exceeds the conversion threshold
      return getDistinctValueHLLPlus(dictionary, function.getP(), function.getSp(), explainName);
    } else {
      return getDistinctValueSet(dictionary, explainName);
    }
  }

  /**
   * Computes the DISTINCT_COUNT_ULL intermediate result. BYTES entries are deserialized as
   * UltraLogLog sketches and merged into the first; other types build a fresh ULL from the
   * raw values.
   *
   * <p>NOTE(review): like the HLL variant, index 0 is read unconditionally in the BYTES
   * branch — assumes a non-empty dictionary.
   *
   * @throws RuntimeException wrapping any deserialization/merge failure
   */
  private static UltraLogLog getDistinctCountULLResult(Dictionary dictionary,
      DistinctCountULLAggregationFunction function, String explainName) {
    if (dictionary.getValueType() == FieldSpec.DataType.BYTES) {
      // Treat BYTES value as serialized UltraLogLog and merge
      try {
        QueryThreadContext.checkTerminationAndSampleUsage(explainName);
        UltraLogLog ull = ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize(dictionary.getBytesValue(0));
        int length = dictionary.length();
        for (int i = 1; i < length; i++) {
          QueryThreadContext.checkTerminationAndSampleUsagePeriodically(i, explainName);
          ull.add(ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize(dictionary.getBytesValue(i)));
        }
        return ull;
      } catch (Exception e) {
        throw new RuntimeException("Caught exception while merging UltraLogLogs", e);
      }
    } else {
      return getDistinctValueULL(dictionary, function.getP(), explainName);
    }
  }

  /**
   * DISTINCT_COUNT_SMART_ULL: exact set below the threshold, approximate ULL above it.
   */
  private static Object getDistinctCountSmartULLResult(Dictionary dictionary,
      DistinctCountSmartULLAggregationFunction function, String explainName) {
    if (dictionary.length() > function.getThreshold()) {
      return getDistinctValueULL(dictionary, function.getP(), explainName);
    } else {
      return getDistinctValueSet(dictionary, explainName);
    }
  }
}