diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoder.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoder.java index 297e04638299..bf402deb1bc3 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoder.java +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoder.java @@ -18,11 +18,13 @@ */ package org.apache.pinot.plugin.inputformat.arrow; - +import com.google.common.base.Preconditions; import java.io.ByteArrayInputStream; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Set; import javax.annotation.Nullable; @@ -30,40 +32,61 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.data.readers.RecordExtractorConfig; +import org.apache.pinot.spi.plugin.PluginManager; import org.apache.pinot.spi.stream.StreamMessageDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * ArrowMessageDecoder is used to decode Apache Arrow IPC format messages into Pinot GenericRow. - * This decoder handles Arrow streaming format and converts Arrow data to Pinot's columnar format. - */ + +/// Decodes Apache Arrow IPC stream-format messages into Pinot [GenericRow]s. The output shape depends on the Arrow +/// batch's row count: +/// - 0 row → returns `null` (nothing to ingest). +/// - 1 row → the single row's fields are populated directly into the destination [GenericRow]. +/// - multiple rows → the rows are wrapped in a `List` stored under [GenericRow#MULTIPLE_RECORDS_KEY] +/// on the destination. 
public class ArrowMessageDecoder implements StreamMessageDecoder { public static final String ARROW_ALLOCATOR_LIMIT = "arrow.allocator.limit"; public static final String DEFAULT_ALLOCATOR_LIMIT = "268435456"; // 256MB default - private static final Logger logger = LoggerFactory.getLogger(ArrowMessageDecoder.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowMessageDecoder.class); + private final ArrowRecordExtractor.Record _record = new ArrowRecordExtractor.Record(); + private ArrowRecordExtractor _extractor; private String _streamTopicName; private RootAllocator _allocator; - private ArrowToGenericRowConverter _converter; @Override public void init(Map props, Set fieldsToRead, String topicName) throws Exception { + // Resolve the extractor + config classes from props. Defaults to `ArrowRecordExtractor` / + // `ArrowRecordExtractorConfig`; user-supplied extractors must extend [ArrowRecordExtractor] + // so the per-batch `setReader` hook is honored. + String extractorClass = props.get(RECORD_EXTRACTOR_CONFIG_KEY); + String configClass = props.get(RECORD_EXTRACTOR_CONFIG_CONFIG_KEY); + if (extractorClass == null) { + extractorClass = ArrowRecordExtractor.class.getName(); + configClass = ArrowRecordExtractorConfig.class.getName(); + } + RecordExtractorConfig extractorConfig = null; + if (configClass != null) { + extractorConfig = PluginManager.get().createInstance(configClass); + extractorConfig.init(props); + } + // Validate the extractor extends ArrowRecordExtractor: the decoder loop calls `setReader` per batch, + // and that hook only exists on this base class. 
+ Object extractor = PluginManager.get().createInstance(extractorClass); + Preconditions.checkState(extractor instanceof ArrowRecordExtractor, + "Record extractor class %s must extend ArrowRecordExtractor", extractorClass); + _extractor = (ArrowRecordExtractor) extractor; + _extractor.init(fieldsToRead, extractorConfig); _streamTopicName = topicName; // Initialize Arrow allocator with configurable memory limit - long allocatorLimit = - Long.parseLong(props.getOrDefault(ARROW_ALLOCATOR_LIMIT, DEFAULT_ALLOCATOR_LIMIT)); + long allocatorLimit = Long.parseLong(props.getOrDefault(ARROW_ALLOCATOR_LIMIT, DEFAULT_ALLOCATOR_LIMIT)); _allocator = new RootAllocator(allocatorLimit); - // Initialize Arrow to GenericRow converter (processes all fields) - _converter = new ArrowToGenericRowConverter(); - - logger.info( - "Initialized ArrowMessageDecoder for topic: {} with allocator limit: {} bytes", - topicName, + LOGGER.info("Initialized ArrowMessageDecoder for topic: {} with allocator limit: {} bytes", topicName, allocatorLimit); } @@ -73,24 +96,42 @@ public GenericRow decode(byte[] payload, GenericRow destination) { try (ByteArrayInputStream inputStream = new ByteArrayInputStream(payload); ReadableByteChannel channel = Channels.newChannel(inputStream); ArrowStreamReader reader = new ArrowStreamReader(channel, _allocator)) { + if (!reader.loadNextBatch()) { + LOGGER.warn("No data found in Arrow message for topic: {}", _streamTopicName); + return null; + } - // Read the Arrow schema and data VectorSchemaRoot root = reader.getVectorSchemaRoot(); - if (!reader.loadNextBatch()) { - logger.warn("No data found in Arrow message for topic: {}", _streamTopicName); + int rowCount = root.getRowCount(); + if (rowCount == 0) { return null; } - // Convert Arrow data to GenericRow using converter - GenericRow row = _converter.convert(reader, root, destination); + if (destination == null) { + destination = new GenericRow(); + } + _extractor.setReader(reader); + _extractor.prepareBatch(_record); 
- return row; + if (rowCount == 1) { + // Single row — fill destination directly (the GenericRow is the row). + _record.setRowId(0); + _extractor.extract(_record, destination); + return destination; + } + + // Multiple rows — wrap them under MULTIPLE_RECORDS_KEY. + List rows = new ArrayList<>(rowCount); + for (int rowId = 0; rowId < rowCount; rowId++) { + _record.setRowId(rowId); + GenericRow row = new GenericRow(); + _extractor.extract(_record, row); + rows.add(row); + } + destination.putValue(GenericRow.MULTIPLE_RECORDS_KEY, rows); + return destination; } catch (Exception e) { - logger.error( - "Error decoding Arrow message for stream topic {} : {}", - _streamTopicName, - Arrays.toString(payload), - e); + LOGGER.error("Error decoding Arrow message for stream topic {} ({} bytes)", _streamTopicName, payload.length, e); return null; } } @@ -103,11 +144,12 @@ public GenericRow decode(byte[] payload, int offset, int length, GenericRow dest /** Clean up resources */ public void close() { + _record.close(); if (_allocator != null) { try { _allocator.close(); } catch (Exception e) { - logger.warn("Error closing Arrow allocator", e); + LOGGER.warn("Error closing Arrow allocator", e); } } } diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java new file mode 100644 index 000000000000..4bdfadeb25b1 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractor.java @@ -0,0 +1,454 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.plugin.inputformat.arrow; + +import com.google.common.collect.Maps; +import java.io.IOException; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.pinot.spi.data.readers.BaseRecordExtractor; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.data.readers.RecordExtractorConfig; +import org.apache.pinot.spi.utils.TimestampUtils; + + +/// Extracts a single Arrow row into a [GenericRow]. Reader-scoped state ([VectorSchemaRoot] + +/// dictionary map) is bound once via [#setReader]; per-row [#extract] calls take a [Record] holding +/// only the row index. 
Dispatch is schema-driven — each column is walked using its [Field], so the +/// logical type drives the conversion rather than the runtime Java type of the value. +/// +/// **Scalars** (Arrow type → Java output): +/// - `Bool` → `Boolean` +/// - `Int(8/16)` → `Integer` (widened from `Byte` / `Short`) +/// - `Int(32)` → `Integer` +/// - `Int(64)` → `Long` +/// - `FloatingPoint(SINGLE)` → `Float` +/// - `FloatingPoint(DOUBLE)` → `Double` +/// - `Decimal` → `BigDecimal` +/// - `Utf8` / `LargeUtf8` → `String` (via `Text.toString()`) +/// - `Binary` / `LargeBinary` / `FixedSizeBinary` → `byte[]` +/// - `Null` → `null` (every row is null by definition) +/// +/// **Temporal** (per the schema's `DateUnit` / `TimeUnit`): +/// - `Timestamp` no-TZ → [Timestamp] (Arrow surfaces all four units as `LocalDateTime`; interpreted +/// as a UTC instant) +/// - `Timestamp` with-TZ → [Timestamp] (Arrow surfaces all four units as `Long` epoch; constructed +/// per the schema's `TimeUnit`, sub-millisecond precision preserved via [TimestampUtils]) +/// - `Date` → [LocalDate] (`DateDayVector` surfaces as `Integer` raw days; `DateMilliVector` as +/// `LocalDateTime` at UTC midnight — both reduce to a calendar date) +/// - `Time` → [LocalTime] (`TimeSecVector` as `Integer`, `TimeMilliVector` as `LocalDateTime`, +/// `TimeMicroVector` / `TimeNanoVector` as `Long` — all collapse onto nanoseconds-since-midnight) +/// - `Interval` / `Duration` → ISO-8601 `String` via `value.toString()` — `java.time.Period` / +/// `java.time.Duration` / `PeriodDuration` all have meaningful toString (e.g. `"P1Y2M"`, +/// `"PT5H30M"`, `"P1Y2M3D PT4H5M6S"`) +/// +/// With `extractRawTimeValues = true` ([ArrowRecordExtractorConfig]) the `Date` / `Time` / +/// `Timestamp` cases bypass the contract conversion: `Date` → `int` days-since-epoch (regardless of +/// `DateUnit` — `DateMilli` is always UTC midnight, so reducing to days is lossless); `Time` / +/// `Timestamp` → raw `int` / `long` in the schema's `TimeUnit`. 
`Interval` / `Duration` are +/// unaffected. Temporal values that surface inside a `Union` branch don't see the bypass either — +/// the chosen branch's [Field] isn't visible from the value alone, so we can't pick a unit; they +/// always coerce to `Timestamp` UTC. +/// +/// **Complex** (recurse with the [Field]'s child fields): +/// - `List` / `LargeList` / `FixedSizeList` → `Object[]` +/// - `Struct` → `Map` +/// - `Map` → `Map` (Arrow's `List>` entry list is flattened; +/// keys are stringified per [BaseRecordExtractor#stringifyMapKey]) +/// - `Union` → recursively dispatched by the value's runtime Java type (the chosen branch isn't +/// visible from the value alone — nested complex sub-branches fall back to `value.toString()`) +/// +/// **Other**: +/// - dictionary-encoded vector → decoded against the bound dictionary, then dispatched on the +/// decoded vector's [Field] (so the logical type — e.g. `Utf8` — drives conversion, not the +/// dictionary's index type) +/// +/// Unrecognized types (`NONE` / future Arrow additions) throw [IllegalStateException]. +/// +/// **Quirks worth knowing:** +/// - `UInt2Vector` (unsigned 16-bit) returns `Character`, not a `Number` — Arrow's Java bindings +/// use `char` as the only natively unsigned primitive. We widen to `int` per the contract. +/// - `DateDayVector` / `DateMilliVector` return *different* Java types (`Integer` vs +/// `LocalDateTime`) for the same logical `DATE` type — historical asymmetry in Arrow's API. +public class ArrowRecordExtractor extends BaseRecordExtractor { + + /// Per-batch context used by [#extract]. Holds the row index plus the active vectors for the + /// current batch — decoded copies for dictionary-encoded columns, the raw [FieldVector] otherwise + /// (parallel to the extractor's `_fieldVectors`). Decoded vectors are owned by this Record and + /// closed via [#close]; closing is also implicit on the next [#prepareBatch] call. 
+ /// + /// Lifecycle: caller invokes [#prepareBatch] once at the start of a batch, then [#setRowId] + + /// [#extract] per row. Reusable across batches/files; close once at end of iteration. + public static final class Record implements AutoCloseable { + private int _rowId; + private ValueVector[] _activeVectors; + private boolean[] _ownsVector; + + public void setRowId(int rowId) { + _rowId = rowId; + } + + @Override + public void close() { + if (_activeVectors != null) { + for (int i = 0; i < _activeVectors.length; i++) { + if (_ownsVector[i]) { + _activeVectors[i].close(); + } + } + _activeVectors = null; + _ownsVector = null; + } + } + } + + private boolean _extractRawTimeValues; + + // Reader-scoped state — initialized in [#setReader], read by per-row [#extract]. The dictionary map + // is held so per-row lookups don't re-traverse the reader; the field vectors are pre-resolved against + // the include list so the per-row loop is a flat array walk (names are read inline from each field + // vector — `Field#getName` is a plain getter). + private Map _dictionaries; + private FieldVector[] _fieldVectors; + + @Override + protected void initConfig(@Nullable RecordExtractorConfig config) { + if (config instanceof ArrowRecordExtractorConfig) { + _extractRawTimeValues = ((ArrowRecordExtractorConfig) config).isExtractRawTimeValues(); + } + } + + /// Binds the extractor to `reader` for the upcoming run of [#extract] calls. Must be called before + /// [#prepareBatch] — once per file (`ArrowRecordReader`) or per `decode()` call (`ArrowMessageDecoder`). + /// Resolves the include list against the reader's [VectorSchemaRoot] and stashes the dictionary map. 
+ public void setReader(ArrowReader reader) + throws IOException { + _dictionaries = reader.getDictionaryVectors(); + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + List fieldVectors = root.getFieldVectors(); + if (_extractAll) { + _fieldVectors = fieldVectors.toArray(new FieldVector[0]); + } else { + List matched = new ArrayList<>(_fields.size()); + for (FieldVector fieldVector : fieldVectors) { + if (_fields.contains(fieldVector.getField().getName())) { + matched.add(fieldVector); + } + } + _fieldVectors = matched.toArray(new FieldVector[0]); + } + } + + /// Prepares `record` for the current batch: closes any prior decoded vectors and decodes each + /// dictionary-encoded column from `_fieldVectors` once into `record._activeVectors[i]`. + /// Non-dictionary columns share the raw [FieldVector] reference (no copy). Must be called after + /// each `loadNextBatch` and before the per-row [#extract] loop. + public void prepareBatch(Record record) { + record.close(); + int numFields = _fieldVectors.length; + ValueVector[] activeVectors = new ValueVector[numFields]; + boolean[] ownsVector = new boolean[numFields]; + for (int i = 0; i < numFields; i++) { + FieldVector fieldVector = _fieldVectors[i]; + DictionaryEncoding encoding = fieldVector.getField().getDictionary(); + if (encoding != null) { + activeVectors[i] = DictionaryEncoder.decode(fieldVector, _dictionaries.get(encoding.getId())); + ownsVector[i] = true; + } else { + activeVectors[i] = fieldVector; + } + } + record._activeVectors = activeVectors; + record._ownsVector = ownsVector; + } + + /// Reads each included column at `from._rowId` and dispatches by the active vector's [Field]'s + /// logical type. The active vector is the dictionary-decoded vector for dictionary-encoded + /// columns (so dispatch sees the value type — e.g. `Utf8` — not the dictionary's index type), or + /// the raw [FieldVector] otherwise. 
+ @Override + public GenericRow extract(Record from, GenericRow to) { + FieldVector[] fieldVectors = _fieldVectors; + ValueVector[] activeVectors = from._activeVectors; + for (int i = 0; i < fieldVectors.length; i++) { + ValueVector activeVector = activeVectors[i]; + Object rawValue = activeVector.getObject(from._rowId); + to.putValue(fieldVectors[i].getField().getName(), + rawValue != null ? convert(activeVector.getField(), rawValue) : null); + } + return to; + } + + /// Schema-driven dispatch — one branch per [ArrowType.ArrowTypeID]; complex types recurse with + /// their child [Field]s, scalars normalize per the contract. + @Nullable + private Object convert(Field field, Object value) { + ArrowType type = field.getType(); + switch (type.getTypeID()) { + // Pass-through — Arrow boxes these directly into the contract output type. + case Bool: // Boolean + case FloatingPoint: // Float / Double + case Decimal: // BigDecimal + case Binary: // byte[] + case LargeBinary: // byte[] + case FixedSizeBinary: // byte[] + return value; + // toString — `Utf8` / `LargeUtf8` produce `String`; `Interval` / `Duration` produce ISO-8601 + // (`java.time.Period` / `java.time.Duration` / `PeriodDuration` all have meaningful toString). + case Utf8: + case LargeUtf8: + case Interval: + case Duration: + return value.toString(); + // Integer — `Byte` widens to `Integer` per contract (sign-extended for signed `TinyIntVector`, + // zero-extended via `& 0xFF` for unsigned `UInt1Vector`); `Short` (signed `SmallIntVector`) + // sign-extends; `Character` (unsigned 16, from `UInt2Vector`) widens to its `int` code point; + // `Integer` / `Long` pass through. + case Int: + if (value instanceof Byte) { + int v = (Byte) value; + return ((ArrowType.Int) type).getIsSigned() ? 
v : v & 0xFF; + } + if (value instanceof Short) { + return ((Short) value).intValue(); + } + if (value instanceof Character) { + return (int) (Character) value; + } + return value; + // Null — NullVector.getObject always returns null; extractValue short-circuits on null, so + // this branch is unreachable in practice. Defensive return. + case Null: + return null; + // Logical temporal — schema's `TimeUnit` drives the conversion. + case Timestamp: + return convertTimestamp((ArrowType.Timestamp) type, value); + case Date: + return convertDate((ArrowType.Date) type, value); + case Time: + return convertTime((ArrowType.Time) type, value); + // Multi-value — `List` (and primitive-array lists) → `Object[]`. + case List: + case LargeList: + case FixedSizeList: + return convertList(field.getChildren().get(0), (List) value); + // Map / nested complex types. + case Map: + // The Map field's children are [entriesStruct]; the entries struct's children are + // [keyField, valueField] (named per MapVector.KEY_NAME / VALUE_NAME). + Field entriesField = field.getChildren().get(0); + return convertMap(entriesField.getChildren().get(0), entriesField.getChildren().get(1), (List) value); + case Struct: + return convertStruct(field.getChildren(), (Map) value); + case Union: + // The chosen branch isn't visible from the resolved value alone — dispatch by the value's + // runtime Java type. Nested complex sub-branches fall back to `value.toString()`. + return convertByRuntimeType(value); + default: + // `NONE` is a placeholder; any other ID is a future Arrow addition. + throw new IllegalStateException("Unsupported Arrow type: " + type + " for field: " + field.getName()); + } + } + + /// Constructs a [Timestamp] from an Arrow `Timestamp` value. No-TZ vectors surface as + /// `LocalDateTime` (interpreted as UTC); with-TZ vectors surface as `Long` epoch counted in the + /// schema's `TimeUnit`. Sub-millisecond precision is preserved via [TimestampUtils]. 
+ /// With [#_extractRawTimeValues] the raw `long` epoch in the schema's `TimeUnit` is returned. + private Object convertTimestamp(ArrowType.Timestamp type, Object value) { + if (_extractRawTimeValues) { + if (value instanceof LocalDateTime) { + // No-TZ vector — convert the LocalDateTime back to an epoch `long` in the declared unit. + Instant instant = ((LocalDateTime) value).toInstant(ZoneOffset.UTC); + return toEpochInUnit(instant, type.getUnit()); + } + // With-TZ vector — already raw `long` in the declared unit. + return value; + } + if (value instanceof LocalDateTime) { + return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); + } + long raw = ((Number) value).longValue(); + switch (type.getUnit()) { + case SECOND: + return new Timestamp(raw * 1000L); + case MILLISECOND: + return new Timestamp(raw); + case MICROSECOND: + return TimestampUtils.fromMicrosSinceEpoch(raw); + case NANOSECOND: + return TimestampUtils.fromNanosSinceEpoch(raw); + default: + throw new IllegalStateException("Unsupported Timestamp unit: " + type.getUnit()); + } + } + + private static long toEpochInUnit(Instant instant, org.apache.arrow.vector.types.TimeUnit unit) { + switch (unit) { + case SECOND: + return instant.getEpochSecond(); + case MILLISECOND: + return instant.toEpochMilli(); + case MICROSECOND: + return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000L), instant.getNano() / 1_000L); + case NANOSECOND: + return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000_000L), instant.getNano()); + default: + throw new IllegalStateException("Unsupported Timestamp unit: " + unit); + } + } + + /// Reduces an Arrow `Date` value to its contract Java type ([LocalDate]), or to `int` + /// days-since-epoch when [#_extractRawTimeValues] is set. `DateDayVector` surfaces as `Integer` + /// raw days; `DateMilliVector` surfaces as `LocalDateTime` at UTC midnight. 
+ private Object convertDate(ArrowType.Date type, Object value) { + int days; + switch (type.getUnit()) { + case DAY: + days = (Integer) value; + break; + case MILLISECOND: + days = (int) ((LocalDateTime) value).toLocalDate().toEpochDay(); + break; + default: + throw new IllegalStateException("Unsupported Date unit: " + type.getUnit()); + } + return _extractRawTimeValues ? days : LocalDate.ofEpochDay(days); + } + + /// Constructs a [LocalTime] from an Arrow `Time` value, dispatched by the schema's `TimeUnit`: + /// `TimeMilliVector` surfaces as `LocalDateTime`; `TimeSecVector` as `Integer`; + /// `TimeMicroVector` / `TimeNanoVector` as `Long`. All collapse onto nanoseconds-since-midnight. + /// With [#_extractRawTimeValues] the raw count in the schema's `TimeUnit` is returned instead. + private Object convertTime(ArrowType.Time type, Object value) { + if (_extractRawTimeValues) { + if (value instanceof LocalDateTime) { + // `TimeMilliVector` surfaces as `LocalDateTime`; raw is `int` ms since midnight. + return (int) (((LocalDateTime) value).toLocalTime().toNanoOfDay() / 1_000_000L); + } + // `TimeSecVector` (Integer) / `TimeMicroVector` / `TimeNanoVector` (Long) — already raw. + return value; + } + if (value instanceof LocalDateTime) { + return ((LocalDateTime) value).toLocalTime(); + } + long raw = ((Number) value).longValue(); + switch (type.getUnit()) { + case SECOND: + return LocalTime.ofSecondOfDay(raw); + case MILLISECOND: + return LocalTime.ofNanoOfDay(raw * 1_000_000L); + case MICROSECOND: + return LocalTime.ofNanoOfDay(raw * 1_000L); + case NANOSECOND: + return LocalTime.ofNanoOfDay(raw); + default: + throw new IllegalStateException("Unsupported Time unit: " + type.getUnit()); + } + } + + private Object[] convertList(Field elementField, List list) { + int size = list.size(); + Object[] result = new Object[size]; + int i = 0; + for (Object element : list) { + result[i++] = element != null ? 
convert(elementField, element) : null; + } + return result; + } + + /// Flattens an Arrow `Map` column's entry list (`List>`) into a + /// `Map`, recursing into each value via [#convert] and stringifying each key via + /// [BaseRecordExtractor#stringifyMapKey] per the contract. Entries with a `null` key (input or + /// post-conversion) are dropped. + private Map convertMap(Field keyField, Field valueField, List entries) { + Map result = Maps.newLinkedHashMapWithExpectedSize(entries.size()); + for (Object entry : entries) { + if (entry == null) { + continue; + } + Map entryMap = (Map) entry; + Object rawKey = entryMap.get(MapVector.KEY_NAME); + if (rawKey == null) { + continue; + } + Object convertedKey = convert(keyField, rawKey); + if (convertedKey == null) { + continue; + } + Object rawValue = entryMap.get(MapVector.VALUE_NAME); + result.put(stringifyMapKey(convertedKey), rawValue != null ? convert(valueField, rawValue) : null); + } + return result; + } + + private Map convertStruct(List childFields, Map value) { + Map result = Maps.newHashMapWithExpectedSize(childFields.size()); + for (Field childField : childFields) { + String name = childField.getName(); + Object rawValue = value.get(name); + result.put(name, rawValue != null ? convert(childField, rawValue) : null); + } + return result; + } + + /// Runtime-type dispatch used by the `Union` case (where the chosen branch isn't accessible + /// from the resolved value). Mirrors the scalar handling of [#convert] for the common Arrow + /// boxed types; nested complex types fall back to `value.toString()` because their child + /// [Field]s aren't reachable from here. 
+ private static Object convertByRuntimeType(Object value) { + if (value instanceof Number) { + if (value instanceof Byte || value instanceof Short) { + return ((Number) value).intValue(); + } + return value; + } + if (value instanceof Boolean || value instanceof byte[]) { + return value; + } + if (value instanceof Character) { + // `UInt2Vector` surfaces as `Character`; widen to `int` per the Int(16) contract. + return (int) (Character) value; + } + if (value instanceof LocalDateTime) { + // Ambiguous between Timestamp / Date / Time — best-effort: treat as Timestamp UTC. + return Timestamp.from(((LocalDateTime) value).toInstant(ZoneOffset.UTC)); + } + // `Text` (Utf8 / LargeUtf8), `Period` / `Duration` / `PeriodDuration` (Interval / Duration), and + // anything unrecognized fall through to `toString()`. + return value.toString(); + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorConfig.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorConfig.java new file mode 100644 index 000000000000..72debecdb502 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorConfig.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.util.Map; +import org.apache.pinot.spi.data.readers.RecordExtractorConfig; + + +/// Config for [ArrowRecordExtractor]. One flag: +/// - `extractRawTimeValues` (default `false`) — opt out of `Date` / `Time` / `Timestamp` conversion +/// at the extractor boundary, surfacing the raw integer in the column's declared `TimeUnit` +/// (`DateUnit` for `Date`) instead of the contract Java type. See [ArrowRecordExtractor] for +/// the per-type matrix. +public class ArrowRecordExtractorConfig implements RecordExtractorConfig { + public static final String EXTRACT_RAW_TIME_VALUES = "extractRawTimeValues"; + + private boolean _extractRawTimeValues; + + @Override + public void init(Map props) { + _extractRawTimeValues = Boolean.parseBoolean(props.get(EXTRACT_RAW_TIME_VALUES)); + } + + public boolean isExtractRawTimeValues() { + return _extractRawTimeValues; + } + + public void setExtractRawTimeValues(boolean extractRawTimeValues) { + _extractRawTimeValues = extractRawTimeValues; + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReader.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReader.java index 246c519b2c34..9d8695aae1ab 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReader.java +++ 
b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReader.java @@ -35,8 +35,9 @@ * Record reader for Apache Arrow IPC file format. */ public class ArrowRecordReader implements RecordReader { + private final ArrowRecordExtractor _extractor = new ArrowRecordExtractor(); + private final ArrowRecordExtractor.Record _record = new ArrowRecordExtractor.Record(); private File _dataFile; - private ArrowToGenericRowConverter _converter; private RootAllocator _allocator; private FileInputStream _fileInputStream; private ArrowFileReader _arrowFileReader; @@ -45,19 +46,18 @@ public class ArrowRecordReader implements RecordReader { private int _nextRowId; private int _currentBatchRowCount; - public ArrowRecordReader() { - } - @Override - public void init(File dataFile, @Nullable Set fieldsToRead, - @Nullable RecordReaderConfig recordReaderConfig) + public void init(File dataFile, @Nullable Set fieldsToRead, @Nullable RecordReaderConfig recordReaderConfig) throws IOException { - _dataFile = dataFile; - _converter = new ArrowToGenericRowConverter(fieldsToRead); + ArrowRecordExtractorConfig extractorConfig = new ArrowRecordExtractorConfig(); long allocatorLimit = ArrowRecordReaderConfig.DEFAULT_ALLOCATOR_LIMIT; if (recordReaderConfig instanceof ArrowRecordReaderConfig) { - allocatorLimit = ((ArrowRecordReaderConfig) recordReaderConfig).getAllocatorLimit(); + ArrowRecordReaderConfig arrowReaderConfig = (ArrowRecordReaderConfig) recordReaderConfig; + allocatorLimit = arrowReaderConfig.getAllocatorLimit(); + extractorConfig.setExtractRawTimeValues(arrowReaderConfig.isExtractRawTimeValues()); } + _extractor.init(fieldsToRead, extractorConfig); + _dataFile = dataFile; _allocator = new RootAllocator(allocatorLimit); openFile(); } @@ -67,6 +67,7 @@ private void openFile() _fileInputStream = new FileInputStream(_dataFile); _arrowFileReader = new ArrowFileReader(_fileInputStream.getChannel(), _allocator); _root = 
_arrowFileReader.getVectorSchemaRoot(); + _extractor.setReader(_arrowFileReader); _nextRowId = 0; _currentBatchRowCount = 0; advanceToNonEmptyBatch(); @@ -80,6 +81,7 @@ private void advanceToNonEmptyBatch() _hasNextBatch = true; _currentBatchRowCount = rowCount; _nextRowId = 0; + _extractor.prepareBatch(_record); return; } } @@ -94,7 +96,9 @@ public boolean hasNext() { @Override public GenericRow next(GenericRow reuse) throws IOException { - _converter.convertSingleRow(_arrowFileReader, _root, _nextRowId, reuse); + reuse.clear(); + _record.setRowId(_nextRowId); + _extractor.extract(_record, reuse); _nextRowId++; if (_nextRowId >= _currentBatchRowCount) { advanceToNonEmptyBatch(); @@ -126,6 +130,8 @@ private void closeFile() throws IOException { IOException exception = null; + _record.close(); + if (_arrowFileReader != null) { try { _arrowFileReader.close(); diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderConfig.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderConfig.java index 706354f1f340..a316e410a98e 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderConfig.java +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderConfig.java @@ -21,16 +21,14 @@ import org.apache.pinot.spi.data.readers.RecordReaderConfig; -/** - * Config for {@link ArrowRecordReader}. - */ +/// Config for [ArrowRecordReader]. Carries the Arrow allocator limit plus the +/// [ArrowRecordExtractorConfig] `extractRawTimeValues` flag so the reader can construct the +/// extractor's config at init time. 
public class ArrowRecordReaderConfig implements RecordReaderConfig { public static final long DEFAULT_ALLOCATOR_LIMIT = 268435456L; // 256MB private long _allocatorLimit = DEFAULT_ALLOCATOR_LIMIT; - - public ArrowRecordReaderConfig() { - } + private boolean _extractRawTimeValues; public long getAllocatorLimit() { return _allocatorLimit; @@ -39,4 +37,12 @@ public long getAllocatorLimit() { public void setAllocatorLimit(long allocatorLimit) { _allocatorLimit = allocatorLimit; } + + public boolean isExtractRawTimeValues() { + return _extractRawTimeValues; + } + + public void setExtractRawTimeValues(boolean extractRawTimeValues) { + _extractRawTimeValues = extractRawTimeValues; + } } diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToGenericRowConverter.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToGenericRowConverter.java deleted file mode 100644 index 1b7b7d00ba5f..000000000000 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/main/java/org/apache/pinot/plugin/inputformat/arrow/ArrowToGenericRowConverter.java +++ /dev/null @@ -1,264 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.plugin.inputformat.arrow; - -import java.sql.Timestamp; -import java.time.LocalDateTime; -import java.time.ZoneOffset; -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import javax.annotation.Nullable; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.MapVector; -import org.apache.arrow.vector.dictionary.DictionaryEncoder; -import org.apache.arrow.vector.ipc.ArrowReader; -import org.apache.arrow.vector.util.Text; -import org.apache.pinot.spi.data.readers.BaseRecordExtractor; -import org.apache.pinot.spi.data.readers.GenericRow; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * Utility class for converting Apache Arrow VectorSchemaRoot to Pinot {@code GenericRow}. Processes - * all fields and handles multiple rows from Arrow batch. - */ -public class ArrowToGenericRowConverter { - private static final Logger logger = LoggerFactory.getLogger(ArrowToGenericRowConverter.class); - - @Nullable - private final Set _fieldsToRead; - - /** Default constructor that processes all fields from Arrow batch. */ - public ArrowToGenericRowConverter() { - this(null); - } - - /** - * Constructor that processes only specified fields from Arrow batch. - * - * @param fieldsToRead Set of field names to read. If null or empty, reads all fields. - */ - public ArrowToGenericRowConverter(@Nullable Set fieldsToRead) { - _fieldsToRead = (fieldsToRead == null || fieldsToRead.isEmpty()) ? null : Set.copyOf(fieldsToRead); - } - - /** - * Converts an Arrow VectorSchemaRoot to a Pinot {@code GenericRow}. Processes ALL rows from the - * Arrow batch and stores them as a list using MULTIPLE_RECORDS_KEY. 
- * - * @param reader ArrowReader containing the data - * @param root Arrow VectorSchemaRoot containing the data - * @param destination Optional destination {@code GenericRow}, will create new if null - * @return {@code GenericRow} containing {@code List} with all converted rows, or null - * if no data available - */ - @Nullable - public GenericRow convert( - ArrowReader reader, VectorSchemaRoot root, GenericRow destination) { - if (root == null) { - logger.warn("Cannot convert null VectorSchemaRoot"); - return null; - } - - if (destination == null) { - destination = new GenericRow(); - } - - int rowCount = root.getRowCount(); - if (rowCount == 0) { - logger.warn("No rows found in Arrow data"); - return destination; - } - - List rows = new ArrayList<>(rowCount); - - // Process all rows from the Arrow batch - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - GenericRow row = convertSingleRow(reader, root, rowIndex); - if (row != null) { - rows.add(row); - } - } - - if (!rows.isEmpty()) { - // Use Pinot's MULTIPLE_RECORDS_KEY to store the list of rows - destination.putValue(GenericRow.MULTIPLE_RECORDS_KEY, rows); - logger.debug("Converted {} rows from Arrow batch", rows.size()); - } - - return destination; - } - - /** - * Converts a single row from Arrow VectorSchemaRoot. 
- * - * @param reader ArrowReader containing the data - * @param root Arrow VectorSchemaRoot containing the data - * @param rowIndex Index of the row to convert (0-based) - * @return {@code GenericRow} with converted data, or null if row index is invalid - */ - @Nullable - public GenericRow convertSingleRow( - ArrowReader reader, VectorSchemaRoot root, int rowIndex) { - int rowCount = root.getRowCount(); - if (rowIndex < 0 || rowIndex >= rowCount) { - logger.warn("Row index {} is out of bounds [0, {}) for Arrow batch", rowIndex, rowCount); - return null; - } - return convertSingleRow(reader, root, rowIndex, new GenericRow()); - } - - /** - * Converts a single row from Arrow VectorSchemaRoot into the given {@code GenericRow}. - * - * @param reader ArrowReader containing the data - * @param root Arrow VectorSchemaRoot containing the data - * @param rowIndex Index of the row to convert (0-based) - * @param reuse GenericRow to populate with converted data - * @return the populated {@code GenericRow} - */ - public GenericRow convertSingleRow( - ArrowReader reader, VectorSchemaRoot root, int rowIndex, GenericRow reuse) { - reuse.clear(); - int convertedFields = 0; - - // Process all fields in the Arrow schema - for (int i = 0; i < root.getFieldVectors().size(); i++) { - Object value; - - FieldVector fieldVector = root.getFieldVectors().get(i); - String fieldName = fieldVector.getField().getName(); - if (_fieldsToRead != null && !_fieldsToRead.contains(fieldName)) { - continue; - } - try { - if (fieldVector.getField().getDictionary() != null) { - long dictionaryId = fieldVector.getField().getDictionary().getId(); - try (ValueVector realFieldVector = - DictionaryEncoder.decode( - fieldVector, reader.getDictionaryVectors().get(dictionaryId))) { - value = realFieldVector.getObject(rowIndex); - } - } else { - value = fieldVector.getObject(rowIndex); - } - if (value != null) { - // Convert Arrow-specific types to Pinot-compatible types - Object pinotCompatibleValue = 
convertArrowTypeToPinotCompatible(value); - reuse.putValue(fieldName, pinotCompatibleValue); - convertedFields++; - } else { - reuse.putValue(fieldName, null); - } - } catch (Exception e) { - logger.error("Error extracting value for field: {} at row {}", fieldName, rowIndex, e); - } - } - - logger.debug("Converted {} fields from Arrow row {} to GenericRow", convertedFields, rowIndex); - return reuse; - } - - /** - * Converts Arrow-specific data types to Pinot-compatible types. This method handles the - * incompatibility issues between Arrow's native data types and what Pinot expects. - * - * @param value The raw value from Arrow fieldVector.getObject() - * @return A Pinot-compatible version of the value - */ - @Nullable - private Object convertArrowTypeToPinotCompatible(@Nullable Object value) { - if (value == null) { - return null; - } - - // Handle nested List and Map values, including Arrow MapVector's representation - if (value instanceof List) { - List originalList = (List) value; - if (!originalList.isEmpty()) { - boolean looksLikeMapEntries = true; - boolean sawNonNull = false; - for (Object entryObj : originalList) { - if (entryObj == null) { - continue; - } - sawNonNull = true; - if (!(entryObj instanceof Map)) { - looksLikeMapEntries = false; - break; - } - @SuppressWarnings("unchecked") - Map entryMap = (Map) entryObj; - if (!entryMap.containsKey(MapVector.KEY_NAME)) { - looksLikeMapEntries = false; - break; - } - } - if (looksLikeMapEntries && sawNonNull) { - Map flattened = new LinkedHashMap<>(originalList.size()); - for (Object entryObj : originalList) { - if (entryObj == null) { - continue; - } - @SuppressWarnings("unchecked") - Map entryMap = (Map) entryObj; - Object rawKey = entryMap.get(MapVector.KEY_NAME); - Object rawVal = entryMap.get(MapVector.VALUE_NAME); - Object convertedKey = convertArrowTypeToPinotCompatible(rawKey); - Object convertedVal = convertArrowTypeToPinotCompatible(rawVal); - 
flattened.put(BaseRecordExtractor.stringifyMapKey(convertedKey), convertedVal); - } - return flattened; - } - } - - List convertedList = new ArrayList<>(originalList.size()); - for (Object element : originalList) { - convertedList.add(convertArrowTypeToPinotCompatible(element)); - } - return convertedList; - } - - // Handle Arrow Text type -> String conversion - if (value instanceof Text) { - // Arrow VarCharVector.getObject() returns Text objects, but Pinot expects String - return value.toString(); - } - - // Handle Arrow LocalDateTime -> java.sql.Timestamp conversion - if (value instanceof LocalDateTime) { - // Arrow TimeStampMilliVector.getObject() returns LocalDateTime, but Pinot expects - // java.sql.Timestamp objects for proper timestamp handling and native support - LocalDateTime dateTime = (LocalDateTime) value; - return Timestamp.from(dateTime.toInstant(ZoneOffset.UTC)); - } - - // Handle other potential Arrow-specific types that might cause issues - - // For primitive types (Integer, Double, Boolean) and other Java standard types, - // Arrow returns standard Java objects that are already Pinot-compatible - return value; - } -} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoderTest.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoderTest.java index a6b8adc9ca9c..32b8e0cc8ad6 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoderTest.java +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowMessageDecoderTest.java @@ -18,745 +18,120 @@ */ package org.apache.pinot.plugin.inputformat.arrow; -import com.google.common.collect.Sets; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import 
org.apache.pinot.plugin.inputformat.arrow.util.ArrowTestDataUtil; import org.apache.pinot.spi.data.readers.GenericRow; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.testng.annotations.Test; -import static org.testng.Assert.*; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertSame; +/// Tests [ArrowMessageDecoder] — decoder-specific concerns: lifecycle (init / re-init / close / +/// allocator config), error handling for malformed payloads, the empty-batch edge case (returns +/// `null`), single-row batches (fields populated directly into the destination), and multi-row +/// batches ([GenericRow#MULTIPLE_RECORDS_KEY] wrapper). Per-type extraction is covered by +/// [ArrowRecordExtractorTest]. public class ArrowMessageDecoderTest { - private static final Logger LOGGER = LoggerFactory.getLogger(ArrowMessageDecoderTest.class); @Test - public void testArrowMessageDecoderWithDifferentAllocatorLimits() - throws Exception { + public void testAllocatorLimitConfig() throws Exception { + // Custom allocator limit via the props map. 
ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - // Test with custom allocator limit Map props = new HashMap<>(); props.put(ArrowMessageDecoder.ARROW_ALLOCATOR_LIMIT, "67108864"); // 64MB - - Set fieldsToRead = Sets.newHashSet("field1"); - String topicName = "test-topic-custom"; - - decoder.init(props, fieldsToRead, topicName); + decoder.init(props, Set.of("field1"), "test-topic-custom"); decoder.close(); - // Test with default allocator limit - ArrowMessageDecoder decoder2 = new ArrowMessageDecoder(); - Map props2 = new HashMap<>(); // No allocator limit set - - decoder2.init(props2, fieldsToRead, topicName); - decoder2.close(); - } - - @Test - public void testArrowMessageDecoderMultipleInits() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id"); - String topicName = "test-multiple-init"; - - // Test multiple initializations (should work without issues) - decoder.init(props, fieldsToRead, topicName); - decoder.init(props, fieldsToRead, topicName); - - decoder.close(); + // Default allocator limit when the prop is absent. 
+ ArrowMessageDecoder defaultDecoder = new ArrowMessageDecoder(); + defaultDecoder.init(new HashMap<>(), Set.of("field1"), "test-topic-default"); + defaultDecoder.close(); } @Test - public void testArrowMessageDecodingWithInvalidData() - throws Exception { + public void testReInit() throws Exception { ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "name", "age"); - String topicName = "test-arrow-topic"; - - decoder.init(props, fieldsToRead, topicName); - - // Test various invalid data scenarios - byte[] invalidData1 = "invalid arrow data".getBytes(); - byte[] invalidData2 = new byte[]{1, 2, 3, 4, 5}; - byte[] emptyData = new byte[0]; - - GenericRow destination = new GenericRow(); - - // Should return null for all invalid data types and null - assertNull(decoder.decode(null, destination)); - assertNull(decoder.decode(invalidData1, destination)); - assertNull(decoder.decode(invalidData2, destination)); - assertNull(decoder.decode(emptyData, destination)); - - // Test with null destination - assertNull(decoder.decode(invalidData1, null)); - - // Clean up + decoder.init(props, Set.of("id"), "topic-1"); + decoder.init(props, Set.of("id"), "topic-1"); // re-init must not throw decoder.close(); } @Test - public void testArrowMessageDecoderCloseMultipleTimes() - throws Exception { + public void testCloseIsIdempotent() throws Exception { ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id"); - String topicName = "test-multiple-close"; - - decoder.init(props, fieldsToRead, topicName); - - // Close multiple times should not cause issues + decoder.init(new HashMap<>(), Set.of("id"), "topic-close"); decoder.close(); decoder.close(); decoder.close(); } @Test - public void testArrowMessageDecoderWithArrowDataAndDestination() - throws Exception { + public void testInvalidPayloadReturnsNull() throws Exception { 
ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "name"); - String topicName = "test-real-arrow-with-destination"; - - decoder.init(props, fieldsToRead, topicName); - - // Create real Arrow IPC data - byte[] realArrowData = ArrowTestDataUtil.createValidArrowIpcData(1); - - // Test with provided destination containing existing data + decoder.init(new HashMap<>(), Set.of("id"), "topic-invalid"); GenericRow destination = new GenericRow(); - destination.putValue("existing_field", "existing_value"); - - GenericRow result = decoder.decode(realArrowData, destination); - - // Should return the same destination object (testing ArrowToGenericRowConverter destination - // handling) - assertSame(destination, result); - - // Should preserve existing data - assertEquals("existing_value", result.getValue("existing_field")); - - // Should contain new converted Arrow data - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(1, rows.size()); - assertEquals(1, rows.get(0).getValue("id")); - assertEquals("name_1", rows.get(0).getValue("name")); - - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithEmptyData() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "name"); - String topicName = "test-empty-arrow-data"; - - decoder.init(props, fieldsToRead, topicName); - - // Test with empty Arrow data (zero batches) - byte[] emptyArrowData = ArrowTestDataUtil.createEmptyArrowIpcData(); - GenericRow result = decoder.decode(emptyArrowData, null); - - // Should handle empty data gracefully - might return null or empty result - // This tests the edge case of zero batches - if (result != null) { - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - 
if (rows != null) { - assertEquals(0, rows.size()); - } - } - + assertNull(decoder.decode("not arrow ipc".getBytes(), destination)); + assertNull(decoder.decode(new byte[]{1, 2, 3, 4, 5}, destination)); + assertNull(decoder.decode(new byte[0], destination)); + assertNull(decoder.decode("not arrow ipc".getBytes(), null)); decoder.close(); } @Test - public void testArrowMessageDecoderWithMultipleDataTypes() + public void testEmptyBatchReturnsNull() throws Exception { + // A stream with the schema header but no record batches → `decode` returns null. ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "name", "price", "active", "timestamp"); - String topicName = "test-multi-type-arrow-data"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with multiple data types - byte[] multiTypeArrowData = ArrowTestDataUtil.createMultiTypeArrowIpcData(3); - GenericRow result = decoder.decode(multiTypeArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(3, rows.size()); - - // Verify different data types are correctly handled - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - assertEquals("product_1", row0.getValue("name").toString()); - assertEquals(10.99, (Double) row0.getValue("price"), 0.01); - assertEquals(true, row0.getValue("active")); // BitVector returns boolean - assertNotNull(row0.getValue("timestamp")); // Timestamp should be present - - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - assertEquals("product_2", row1.getValue("name").toString()); - assertEquals(15.99, (Double) row1.getValue("price"), 0.01); - assertEquals(false, row1.getValue("active")); - - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithBatchContainingMultipleRows() - throws Exception { - 
ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "batch_num", "value"); - String topicName = "test-multi-batch-arrow-data"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with multiple batches - but note: ArrowMessageDecoder processes one batch - // per decode() call - // So we test with a single batch containing multiple rows instead - byte[] multiBatchArrowData = - ArrowTestDataUtil.createMultiBatchArrowIpcData(1, 3); // 1 batch, 3 rows - GenericRow result = decoder.decode(multiBatchArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(3, rows.size()); // 1 batch × 3 rows = 3 total rows - - // Verify data from the batch - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - assertEquals(0, row0.getValue("batch_num")); - assertEquals("batch_0_row_0", row0.getValue("value").toString()); - - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - assertEquals(0, row1.getValue("batch_num")); - assertEquals("batch_0_row_1", row1.getValue("value").toString()); - - GenericRow row2 = rows.get(2); - assertEquals(3, row2.getValue("id")); - assertEquals(0, row2.getValue("batch_num")); - assertEquals("batch_0_row_2", row2.getValue("value").toString()); - - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithDictionaryEncodedData() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "category", "price"); - String topicName = "test-dictionary-encoded-arrow-data"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with real dictionary encoding - byte[] dictionaryArrowData = ArrowTestDataUtil.createDictionaryEncodedArrowIpcData(8); - GenericRow result = 
decoder.decode(dictionaryArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(8, rows.size()); - - // Verify dictionary-encoded values are properly decoded by ArrowToGenericRowConverter - // Dictionary: id=1 -> "Electronics", id=2 -> "Books", id=3 -> "Clothing", id=4 -> "Home" - // Data cycles through indices 0,1,2,3,0,1,2,3 which should be resolved to string values - - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - assertEquals("Electronics", row0.getValue("category")); - assertEquals(19.99, (Double) row0.getValue("price"), 0.01); - - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - assertEquals("Books", row1.getValue("category")); - assertEquals(29.99, (Double) row1.getValue("price"), 0.01); - - GenericRow row2 = rows.get(2); - assertEquals(3, row2.getValue("id")); - assertEquals("Clothing", row2.getValue("category")); - assertEquals(39.99, (Double) row2.getValue("price"), 0.01); - - GenericRow row3 = rows.get(3); - assertEquals(4, row3.getValue("id")); - assertEquals("Home", row3.getValue("category")); - assertEquals(49.99, (Double) row3.getValue("price"), 0.01); - - // Verify cycling continues - row 4 should have same category as row 0 - GenericRow row4 = rows.get(4); - assertEquals(5, row4.getValue("id")); - assertEquals("Electronics", row4.getValue("category")); - assertEquals(59.99, (Double) row4.getValue("price"), 0.01); - + decoder.init(new HashMap<>(), Set.of("id", "name"), "topic-empty"); + assertNull(decoder.decode(ArrowTestDataUtils.createEmptyArrowIpcData(), null)); decoder.close(); } @Test - public void testArrowDataTypeCompatibility() + public void testSingleRowFillsDestinationDirectly() throws Exception { + // 1-row batch → fields populated directly into the destination, no `MULTIPLE_RECORDS_KEY` + // wrapper. Pre-existing values on the destination are preserved. 
ArrowMessageDecoder decoder = new ArrowMessageDecoder(); + decoder.init(new HashMap<>(), Set.of("id", "name"), "topic-single-row"); - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "name", "price", "active", "timestamp"); - String topicName = "test-data-type-compatibility"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with multiple data types to verify compatibility - byte[] multiTypeArrowData = ArrowTestDataUtil.createMultiTypeArrowIpcData(3); - GenericRow result = decoder.decode(multiTypeArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(3, rows.size()); - - // Check the actual data types returned by Arrow and verify Pinot compatibility - GenericRow row0 = rows.get(0); - - // Verify each field's type and compatibility - Object idValue = row0.getValue("id"); - assertNotNull(idValue, "ID should not be null"); - assertTrue(idValue instanceof Integer, "ID should be Integer compatible"); - - Object nameValue = row0.getValue("name"); - assertNotNull(nameValue, "Name should not be null"); - // After conversion, Arrow Text should be converted to String for Pinot compatibility - assertTrue(nameValue instanceof String, "Name should be String after conversion"); - assertEquals("product_1", nameValue); - LOGGER.info("Arrow name field successfully converted to String: {}", nameValue); - - Object priceValue = row0.getValue("price"); - assertNotNull(priceValue, "Price should not be null"); - assertTrue(priceValue instanceof Double, "Price should be Double compatible"); - - Object activeValue = row0.getValue("active"); - assertNotNull(activeValue, "Active should not be null"); - // BitVector.getObject() returns Boolean - assertTrue(activeValue instanceof Boolean, "Active should be Boolean compatible"); - - Object timestampValue = row0.getValue("timestamp"); - assertNotNull(timestampValue, 
"Timestamp should not be null"); - // After conversion, Arrow LocalDateTime should be converted to java.sql.Timestamp for Pinot - // compatibility - assertTrue( - timestampValue instanceof java.sql.Timestamp, - "Timestamp should be java.sql.Timestamp after conversion"); - java.sql.Timestamp ts = (java.sql.Timestamp) timestampValue; - assertTrue(ts.getTime() > 0, "Timestamp should be a positive value"); - LOGGER.info( - "Arrow timestamp field successfully converted to java.sql.Timestamp: {}", timestampValue); - - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithListVectors() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "numbers", "tags"); - String topicName = "test-list-vectors"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with List vectors - byte[] listArrowData = ArrowTestDataUtil.createListArrowIpcData(3); - GenericRow result = decoder.decode(listArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(3, rows.size()); - - // Verify first row - should have 1 number and 2 tags - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - Object numbersValue0 = row0.getValue("numbers"); - assertNotNull(numbersValue0, "Numbers should not be null"); - assertTrue(numbersValue0 instanceof List); - @SuppressWarnings("unchecked") - List numbersList0 = (List) numbersValue0; - assertEquals(1, numbersList0.size()); - assertEquals(10, numbersList0.get(0)); - - Object tagsValue0 = row0.getValue("tags"); - assertNotNull(tagsValue0, "Tags should not be null"); - assertTrue(tagsValue0 instanceof List); - @SuppressWarnings("unchecked") - List tagsList0 = (List) tagsValue0; - assertEquals(2, tagsList0.size()); - assertEquals("tag_0_0", tagsList0.get(0).toString()); - 
assertEquals("tag_0_1", tagsList0.get(1).toString()); - - // Verify second row - should have 2 numbers and 2 tags - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - Object numbersValue1 = row1.getValue("numbers"); - assertNotNull(numbersValue1); - @SuppressWarnings("unchecked") - List numbersList1 = (List) numbersValue1; - assertEquals(2, numbersList1.size()); - assertEquals(20, numbersList1.get(0)); - assertEquals(21, numbersList1.get(1)); - - // Verify third row - should have 3 numbers - GenericRow row2 = rows.get(2); - assertEquals(3, row2.getValue("id")); - Object numbersValue2 = row2.getValue("numbers"); - @SuppressWarnings("unchecked") - List numbersList2 = (List) numbersValue2; - assertEquals(3, numbersList2.size()); - assertEquals(30, numbersList2.get(0)); - assertEquals(31, numbersList2.get(1)); - assertEquals(32, numbersList2.get(2)); - - LOGGER.info("List vector test completed successfully with {} rows", rows.size()); - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithStructVectors() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "person"); - String topicName = "test-struct-vectors"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with Struct vectors - byte[] structArrowData = ArrowTestDataUtil.createStructArrowIpcData(2); - GenericRow result = decoder.decode(structArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(2, rows.size()); - - // Verify first row with nested struct - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - Object personValue0 = row0.getValue("person"); - assertNotNull(personValue0); - assertTrue(personValue0 instanceof Map); - @SuppressWarnings("unchecked") - Map personMap0 = (Map) personValue0; - 
assertEquals("Person_1", personMap0.get("name").toString()); - assertEquals(25, personMap0.get("age")); - @SuppressWarnings("unchecked") - Map address0 = (Map) personMap0.get("address"); - assertEquals("1 Main St", address0.get("street").toString()); - assertEquals("City_1", address0.get("city").toString()); - - // Verify second row - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - Object personValue1 = row1.getValue("person"); - assertNotNull(personValue1); - assertTrue(personValue1 instanceof Map); - @SuppressWarnings("unchecked") - Map personMap1 = (Map) personValue1; - assertEquals("Person_2", personMap1.get("name").toString()); - assertEquals(26, personMap1.get("age")); - @SuppressWarnings("unchecked") - Map address1 = (Map) personMap1.get("address"); - assertEquals("2 Main St", address1.get("street").toString()); - assertEquals("City_2", address1.get("city").toString()); - - LOGGER.info("Struct vector test completed successfully with {} rows", rows.size()); - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithMapVectors() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "metadata"); - String topicName = "test-map-vectors"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with Map vectors - byte[] mapArrowData = ArrowTestDataUtil.createMapArrowIpcData(2); - GenericRow result = decoder.decode(mapArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(2, rows.size()); - - // Verify first row with map data - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - Object metadataValue0 = row0.getValue("metadata"); - assertNotNull(metadataValue0); - assertTrue(metadataValue0 instanceof Map); - @SuppressWarnings("unchecked") - Map meta0 = (Map) 
metadataValue0; - assertTrue(meta0.values().contains(100)); - assertTrue(meta0.values().contains(101)); - - // Verify second row - should have 3 entries (2 + (1%2) = 3) - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - Object metadataValue1 = row1.getValue("metadata"); - assertNotNull(metadataValue1); - assertTrue(metadataValue1 instanceof Map); - @SuppressWarnings("unchecked") - Map meta1 = (Map) metadataValue1; - assertTrue(meta1.values().contains(200)); - assertTrue(meta1.values().contains(201)); - assertTrue(meta1.values().contains(202)); - - LOGGER.info("Map vector test completed successfully with {} rows", rows.size()); + GenericRow destination = new GenericRow(); + destination.putValue("existing_field", "existing_value"); + GenericRow result = decoder.decode(ArrowTestDataUtils.createValidArrowIpcData(1), destination); + assertSame(result, destination); + assertEquals(result.getValue("existing_field"), "existing_value"); + assertEquals(result.getValue("id"), 1); + assertEquals(result.getValue("name"), "name_1"); + assertNull(result.getValue(GenericRow.MULTIPLE_RECORDS_KEY)); decoder.close(); } @Test - public void testArrowMessageDecoderWithNestedMapValues() + public void testMultiRowBatch() throws Exception { + // A single Arrow batch with N rows produces N `GenericRow`s under `MULTIPLE_RECORDS_KEY`. 
ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "metadata"); - String topicName = "test-nested-map-values"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with Map values that are themselves Maps - byte[] nestedMapArrowData = ArrowTestDataUtil.createNestedMapArrowIpcData(2); - GenericRow result = decoder.decode(nestedMapArrowData, null); + decoder.init(new HashMap<>(), Set.of("id", "batch_num", "value"), "topic-multi-row"); + GenericRow result = + decoder.decode(ArrowTestDataUtils.createMultiBatchArrowIpcData(1, 3), null); assertNotNull(result); - @SuppressWarnings("unchecked") + //noinspection unchecked List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); assertNotNull(rows); - assertEquals(2, rows.size()); - - // Verify first row: metadata is a Map> - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - Object metadataValue0 = row0.getValue("metadata"); - assertNotNull(metadataValue0); - assertTrue(metadataValue0 instanceof Map); - @SuppressWarnings("unchecked") - Map outer0 = (Map) metadataValue0; - assertTrue(outer0.size() >= 2); - for (Object innerMapObj : outer0.values()) { - assertTrue(innerMapObj instanceof Map); - @SuppressWarnings("unchecked") - Map inner = (Map) innerMapObj; - assertTrue(inner.size() >= 2); - // Values should be integers from generator - for (Object v : inner.values()) { - assertTrue(v instanceof Integer); - } - } - - // Verify second row similarly - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - Object metadataValue1 = row1.getValue("metadata"); - assertNotNull(metadataValue1); - assertTrue(metadataValue1 instanceof Map); - @SuppressWarnings("unchecked") - Map outer1 = (Map) metadataValue1; - assertTrue(outer1.size() >= 2); - boolean sawThreeInner = false; - for (Object innerMapObj : outer1.values()) { - assertTrue(innerMapObj instanceof Map); - 
@SuppressWarnings("unchecked") - Map inner = (Map) innerMapObj; - if (inner.size() == 3) { - sawThreeInner = true; - } + assertEquals(rows.size(), 3); + for (int i = 0; i < 3; i++) { + GenericRow row = rows.get(i); + assertEquals(row.getValue("id"), i + 1); + assertEquals(row.getValue("batch_num"), 0); + assertEquals(row.getValue("value"), "batch_0_row_" + i); } - assertTrue(sawThreeInner); - - decoder.close(); - } - - @Test - public void testArrowMessageDecoderWithNestedListStruct() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "items"); - String topicName = "test-nested-list-struct"; - - decoder.init(props, fieldsToRead, topicName); - - // Create Arrow data with nested List of Structs - byte[] nestedArrowData = ArrowTestDataUtil.createNestedListStructArrowIpcData(3); - GenericRow result = decoder.decode(nestedArrowData, null); - - assertNotNull(result); - @SuppressWarnings("unchecked") - List rows = (List) result.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(rows); - assertEquals(3, rows.size()); - - // Verify first row - should have 1 item (1 + (0%3) = 1) - GenericRow row0 = rows.get(0); - assertEquals(1, row0.getValue("id")); - Object itemsValue0 = row0.getValue("items"); - assertNotNull(itemsValue0); - assertTrue(itemsValue0 instanceof List); - @SuppressWarnings("unchecked") - List items0 = (List) itemsValue0; - assertEquals(1, items0.size()); - @SuppressWarnings("unchecked") - Map item00 = (Map) items0.get(0); - assertEquals("item_0_0", item00.get("item_name").toString()); - assertEquals(10.99, (Double) item00.get("item_price"), 0.01); - - // Verify second row - should have 2 items (1 + (1%3) = 2) - GenericRow row1 = rows.get(1); - assertEquals(2, row1.getValue("id")); - Object itemsValue1 = row1.getValue("items"); - assertNotNull(itemsValue1); - @SuppressWarnings("unchecked") - List items1 = (List) itemsValue1; - assertEquals(2, 
items1.size()); - @SuppressWarnings("unchecked") - Map item10 = (Map) items1.get(0); - assertEquals("item_1_0", item10.get("item_name").toString()); - assertEquals(15.99, (Double) item10.get("item_price"), 0.01); - @SuppressWarnings("unchecked") - Map item11 = (Map) items1.get(1); - assertEquals("item_1_1", item11.get("item_name").toString()); - assertEquals(16.99, (Double) item11.get("item_price"), 0.01); - - // Verify third row - should have 3 items (1 + (2%3) = 3) - GenericRow row2 = rows.get(2); - assertEquals(3, row2.getValue("id")); - Object itemsValue2 = row2.getValue("items"); - assertNotNull(itemsValue2); - @SuppressWarnings("unchecked") - List items2 = (List) itemsValue2; - assertEquals(3, items2.size()); - - LOGGER.info("Nested List-Struct test completed successfully with {} rows", rows.size()); - decoder.close(); - } - - @Test - public void testArrowNestedStructureCompatibilityWithPinot() - throws Exception { - ArrowMessageDecoder decoder = new ArrowMessageDecoder(); - - Map props = new HashMap<>(); - Set fieldsToRead = Sets.newHashSet("id", "numbers", "person", "metadata", "items"); - String topicName = "test-nested-compatibility"; - - decoder.init(props, fieldsToRead, topicName); - - // Test each nested structure type individually for compatibility - // Test List compatibility - byte[] listData = ArrowTestDataUtil.createListArrowIpcData(1); - GenericRow listResult = decoder.decode(listData, null); - assertNotNull(listResult, "List data should be decodable"); - - // Test Struct compatibility - byte[] structData = ArrowTestDataUtil.createStructArrowIpcData(1); - GenericRow structResult = decoder.decode(structData, null); - assertNotNull(structResult, "Struct data should be decodable"); - - // Test Map compatibility - byte[] mapData = ArrowTestDataUtil.createMapArrowIpcData(1); - GenericRow mapResult = decoder.decode(mapData, null); - assertNotNull(mapResult, "Map data should be decodable"); - - // Test complex nested structures - byte[] nestedData = 
ArrowTestDataUtil.createNestedListStructArrowIpcData(1); - GenericRow nestedResult = decoder.decode(nestedData, null); - assertNotNull(nestedResult, "Nested List-Struct data should be decodable"); - - // Verify that all simulated nested structures produce valid GenericRow objects - @SuppressWarnings("unchecked") - List listRows = - (List) listResult.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(listRows, "List result should contain rows"); - assertTrue(listRows.size() > 0, "List result should have at least one row"); - - // Verify nested list data is accessible - GenericRow firstListRow = listRows.get(0); - assertNotNull(firstListRow.getValue("numbers"), "List row should have numbers"); - - @SuppressWarnings("unchecked") - List structRows = - (List) structResult.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(structRows, "Struct result should contain rows"); - assertTrue(structRows.size() > 0, "Struct result should have at least one row"); - - // Verify struct data is accessible - GenericRow firstStructRow = structRows.get(0); - assertNotNull(firstStructRow.getValue("person"), "Struct row should have person"); - - @SuppressWarnings("unchecked") - List mapRows = - (List) mapResult.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(mapRows, "Map result should contain rows"); - assertTrue(mapRows.size() > 0, "Map result should have at least one row"); - - // Verify map data is accessible - GenericRow firstMapRow = mapRows.get(0); - assertNotNull(firstMapRow.getValue("metadata"), "Map row should have metadata"); - - @SuppressWarnings("unchecked") - List nestedRows = - (List) nestedResult.getValue(GenericRow.MULTIPLE_RECORDS_KEY); - assertNotNull(nestedRows, "Nested result should contain rows"); - assertTrue(nestedRows.size() > 0, "Nested result should have at least one row"); - - // Verify nested list-struct data is accessible - GenericRow firstNestedRow = nestedRows.get(0); - assertNotNull(firstNestedRow.getValue("items"), "Nested row 
should have items"); - - LOGGER.info( - "All nested structure types are compatible with ArrowMessageDecoder and produce valid GenericRow objects"); decoder.close(); } } diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorTest.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorTest.java new file mode 100644 index 000000000000..cf45f2a61d81 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordExtractorTest.java @@ -0,0 +1,824 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.sql.Timestamp; +import java.time.Duration; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.Period; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import javax.annotation.Nullable; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; 
+import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertSame; + + +/// Tests [ArrowRecordExtractor] — see its class Javadoc for the per-type contract. Each test builds a +/// single-column [VectorSchemaRoot], roundtrips it through an Arrow IPC stream so a real +/// [ArrowStreamReader] drives [ArrowRecordExtractor#setReader], extracts row 0, and asserts the value. 
+public class ArrowRecordExtractorTest { + + private static final String COLUMN = "col"; + private RootAllocator _allocator; + + @BeforeMethod + public void setUp() { + _allocator = new RootAllocator(); + } + + @AfterMethod + public void tearDown() { + _allocator.close(); + } + + // === Scalars (default contract mode) === + + @Test + public void testBoolean() throws IOException { + Field field = field(new ArrowType.Bool()); + assertEquals(extract(field, v -> { + ((BitVector) v).setSafe(0, 1); + v.setValueCount(1); + }), true); + } + + @Test + public void testTinyIntWidenedToInteger() throws IOException { + Field field = field(new ArrowType.Int(8, true)); + assertEquals(extract(field, v -> { + ((TinyIntVector) v).setSafe(0, 42); + v.setValueCount(1); + }), 42); + } + + @Test + public void testSmallIntWidenedToInteger() throws IOException { + Field field = field(new ArrowType.Int(16, true)); + assertEquals(extract(field, v -> { + ((SmallIntVector) v).setSafe(0, 1234); + v.setValueCount(1); + }), 1234); + } + + @Test + public void testIntPreserved() throws IOException { + Field field = field(new ArrowType.Int(32, true)); + assertEquals(extract(field, v -> { + ((IntVector) v).setSafe(0, 100_000); + v.setValueCount(1); + }), 100_000); + } + + @Test + public void testBigIntPreserved() throws IOException { + Field field = field(new ArrowType.Int(64, true)); + assertEquals(extract(field, v -> { + ((BigIntVector) v).setSafe(0, 1_588_469_340_000L); + v.setValueCount(1); + }), 1_588_469_340_000L); + } + + @Test + public void testUInt1WidenedToInteger() throws IOException { + Field field = field(new ArrowType.Int(8, false)); + assertEquals(extract(field, v -> { + ((UInt1Vector) v).setSafe(0, 200); + v.setValueCount(1); + }), 200); + } + + @Test + public void testUInt2CharacterWidenedToInteger() throws IOException { + Field field = field(new ArrowType.Int(16, false)); + Object result = extract(field, v -> { + ((UInt2Vector) v).setSafe(0, 50_000); + v.setValueCount(1); + }); + 
assertEquals(result, 50_000); + assertSame(result.getClass(), Integer.class); + } + + @Test + public void testUInt4Preserved() throws IOException { + Field field = field(new ArrowType.Int(32, false)); + assertEquals(extract(field, v -> { + ((UInt4Vector) v).setSafe(0, 100_000); + v.setValueCount(1); + }), 100_000); + } + + @Test + public void testUInt8Preserved() throws IOException { + Field field = field(new ArrowType.Int(64, false)); + assertEquals(extract(field, v -> { + ((UInt8Vector) v).setSafe(0, 1234567890L); + v.setValueCount(1); + }), 1234567890L); + } + + @Test + public void testFloat() throws IOException { + Field field = field(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); + assertEquals(extract(field, v -> { + ((Float4Vector) v).setSafe(0, 0.5f); + v.setValueCount(1); + }), 0.5f); + } + + @Test + public void testDouble() throws IOException { + Field field = field(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); + assertEquals(extract(field, v -> { + ((Float8Vector) v).setSafe(0, 1.5); + v.setValueCount(1); + }), 1.5); + } + + @Test + public void testDecimal() throws IOException { + Field field = field(new ArrowType.Decimal(10, 2, 128)); + BigDecimal expected = new BigDecimal("123.45"); + assertEquals(extract(field, v -> { + ((DecimalVector) v).setSafe(0, expected); + v.setValueCount(1); + }), expected); + } + + @Test + public void testUtf8() throws IOException { + Field field = field(new ArrowType.Utf8()); + Object result = extract(field, v -> { + ((VarCharVector) v).setSafe(0, "hello".getBytes()); + v.setValueCount(1); + }); + assertEquals(result, "hello"); + assertSame(result.getClass(), String.class); + } + + @Test + public void testLargeUtf8() throws IOException { + Field field = field(new ArrowType.LargeUtf8()); + Object result = extract(field, v -> { + ((LargeVarCharVector) v).setSafe(0, "world".getBytes()); + v.setValueCount(1); + }); + assertEquals(result, "world"); + assertSame(result.getClass(), String.class); + } + 
+ @Test + public void testBinary() throws IOException { + Field field = field(new ArrowType.Binary()); + byte[] bytes = {1, 2, 3, 4}; + Object result = extract(field, v -> { + ((VarBinaryVector) v).setSafe(0, bytes); + v.setValueCount(1); + }); + assertEquals((byte[]) result, bytes); + } + + @Test + public void testLargeBinary() throws IOException { + Field field = field(new ArrowType.LargeBinary()); + byte[] bytes = {5, 6, 7}; + Object result = extract(field, v -> { + ((LargeVarBinaryVector) v).setSafe(0, bytes); + v.setValueCount(1); + }); + assertEquals((byte[]) result, bytes); + } + + @Test + public void testFixedSizeBinary() throws IOException { + Field field = field(new ArrowType.FixedSizeBinary(4)); + byte[] bytes = {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}; + Object result = extract(field, v -> { + ((FixedSizeBinaryVector) v).setSafe(0, bytes); + v.setValueCount(1); + }); + assertEquals((byte[]) result, bytes); + } + + @Test + public void testNullType() throws IOException { + Field field = field(new ArrowType.Null()); + assertNull(extract(field, v -> { + ((NullVector) v).setValueCount(1); + })); + } + + // === Temporal (default contract mode) === + + @Test + public void testDateDay() throws IOException { + Field field = field(new ArrowType.Date(DateUnit.DAY)); + assertEquals(extract(field, v -> { + ((DateDayVector) v).setSafe(0, 19_000); // 2022-01-08 + v.setValueCount(1); + }), LocalDate.ofEpochDay(19_000)); + } + + @Test + public void testDateMilli() throws IOException { + Field field = field(new ArrowType.Date(DateUnit.MILLISECOND)); + long midnightMillis = 19_000L * 86_400_000L; + assertEquals(extract(field, v -> { + ((DateMilliVector) v).setSafe(0, midnightMillis); + v.setValueCount(1); + }), LocalDate.ofEpochDay(19_000)); + } + + @Test + public void testTimeSecond() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.SECOND, 32)); + assertEquals(extract(field, v -> { + ((TimeSecVector) v).setSafe(0, 3661); // 01:01:01 + 
v.setValueCount(1); + }), LocalTime.of(1, 1, 1)); + } + + @Test + public void testTimeMilli() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + int midnightOffsetMillis = 3_661_500; + assertEquals(extract(field, v -> { + ((TimeMilliVector) v).setSafe(0, midnightOffsetMillis); + v.setValueCount(1); + }), LocalTime.of(1, 1, 1, 500_000_000)); + } + + @Test + public void testTimeMicro() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.MICROSECOND, 64)); + long micros = 3_661_000_500L; // 01:01:01.0005 + assertEquals(extract(field, v -> { + ((TimeMicroVector) v).setSafe(0, micros); + v.setValueCount(1); + }), LocalTime.of(1, 1, 1, 500_000)); + } + + @Test + public void testTimeNano() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.NANOSECOND, 64)); + long nanos = 3_661_000_000_007L; // 01:01:01.000000007 + assertEquals(extract(field, v -> { + ((TimeNanoVector) v).setSafe(0, nanos); + v.setValueCount(1); + }), LocalTime.of(1, 1, 1, 7)); + } + + @Test + public void testTimestampSecondNoTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.SECOND, null)); + long sec = 1_700_000_000L; + assertEquals(extract(field, v -> { + ((TimeStampSecVector) v).setSafe(0, sec); + v.setValueCount(1); + }), new Timestamp(sec * 1000L)); + } + + @Test + public void testTimestampMilliNoTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)); + long ms = 1_700_000_000_500L; + assertEquals(extract(field, v -> { + ((TimeStampMilliVector) v).setSafe(0, ms); + v.setValueCount(1); + }), new Timestamp(ms)); + } + + @Test + public void testTimestampMicroNoTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); + long micros = 1_700_000_000_000_007L; + Timestamp expected = new Timestamp(1_700_000_000_000L); + expected.setNanos(7_000); + assertEquals(extract(field, v -> { + ((TimeStampMicroVector) 
v).setSafe(0, micros); + v.setValueCount(1); + }), expected); + } + + @Test + public void testTimestampNanoNoTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)); + long nanos = 1_700_000_000_000_000_007L; + Timestamp expected = new Timestamp(1_700_000_000_000L); + expected.setNanos(7); + assertEquals(extract(field, v -> { + ((TimeStampNanoVector) v).setSafe(0, nanos); + v.setValueCount(1); + }), expected); + } + + @Test + public void testTimestampSecondWithTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.SECOND, "UTC")); + long sec = 1_700_000_000L; + assertEquals(extract(field, v -> { + ((TimeStampSecTZVector) v).setSafe(0, sec); + v.setValueCount(1); + }), new Timestamp(sec * 1000L)); + } + + @Test + public void testTimestampMilliWithTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")); + long ms = 1_700_000_000_500L; + assertEquals(extract(field, v -> { + ((TimeStampMilliTZVector) v).setSafe(0, ms); + v.setValueCount(1); + }), new Timestamp(ms)); + } + + @Test + public void testTimestampMicroWithTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")); + long micros = 1_700_000_000_000_007L; + Timestamp expected = new Timestamp(1_700_000_000_000L); + expected.setNanos(7_000); + assertEquals(extract(field, v -> { + ((TimeStampMicroTZVector) v).setSafe(0, micros); + v.setValueCount(1); + }), expected); + } + + @Test + public void testTimestampNanoWithTZ() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")); + long nanos = 1_700_000_000_000_000_007L; + Timestamp expected = new Timestamp(1_700_000_000_000L); + expected.setNanos(7); + assertEquals(extract(field, v -> { + ((TimeStampNanoTZVector) v).setSafe(0, nanos); + v.setValueCount(1); + }), expected); + } + + // === Temporal (raw mode — extractRawTimeValues = true) === + + @Test + public void 
testDateDayRaw() throws IOException { + Field field = field(new ArrowType.Date(DateUnit.DAY)); + assertEquals(extractRaw(field, v -> { + ((DateDayVector) v).setSafe(0, 19_000); + v.setValueCount(1); + }), 19_000); + } + + @Test + public void testDateMilliRaw() throws IOException { + // Raw mode normalizes to int days regardless of underlying unit. + Field field = field(new ArrowType.Date(DateUnit.MILLISECOND)); + long midnightMillis = 19_000L * 86_400_000L; + assertEquals(extractRaw(field, v -> { + ((DateMilliVector) v).setSafe(0, midnightMillis); + v.setValueCount(1); + }), 19_000); + } + + @Test + public void testTimeSecondRaw() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.SECOND, 32)); + assertEquals(extractRaw(field, v -> { + ((TimeSecVector) v).setSafe(0, 3661); + v.setValueCount(1); + }), 3661); + } + + @Test + public void testTimeMilliRaw() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + int millis = 3_661_500; + assertEquals(extractRaw(field, v -> { + ((TimeMilliVector) v).setSafe(0, millis); + v.setValueCount(1); + }), millis); + } + + @Test + public void testTimeMicroRaw() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.MICROSECOND, 64)); + long micros = 3_661_000_500L; + assertEquals(extractRaw(field, v -> { + ((TimeMicroVector) v).setSafe(0, micros); + v.setValueCount(1); + }), micros); + } + + @Test + public void testTimeNanoRaw() throws IOException { + Field field = field(new ArrowType.Time(TimeUnit.NANOSECOND, 64)); + long nanos = 3_661_000_000_007L; + assertEquals(extractRaw(field, v -> { + ((TimeNanoVector) v).setSafe(0, nanos); + v.setValueCount(1); + }), nanos); + } + + @Test + public void testTimestampSecondNoTZRaw() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.SECOND, null)); + long sec = 1_700_000_000L; + assertEquals(extractRaw(field, v -> { + ((TimeStampSecVector) v).setSafe(0, sec); + v.setValueCount(1); + }), sec); + 
} + + @Test + public void testTimestampMilliNoTZRaw() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)); + long ms = 1_700_000_000_500L; + assertEquals(extractRaw(field, v -> { + ((TimeStampMilliVector) v).setSafe(0, ms); + v.setValueCount(1); + }), ms); + } + + @Test + public void testTimestampMicroNoTZRaw() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); + long micros = 1_700_000_000_000_007L; + assertEquals(extractRaw(field, v -> { + ((TimeStampMicroVector) v).setSafe(0, micros); + v.setValueCount(1); + }), micros); + } + + @Test + public void testTimestampNanoNoTZRaw() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)); + long nanos = 1_700_000_000_000_000_007L; + assertEquals(extractRaw(field, v -> { + ((TimeStampNanoVector) v).setSafe(0, nanos); + v.setValueCount(1); + }), nanos); + } + + @Test + public void testTimestampMilliWithTZRaw() throws IOException { + Field field = field(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")); + long ms = 1_700_000_000_500L; + assertEquals(extractRaw(field, v -> { + ((TimeStampMilliTZVector) v).setSafe(0, ms); + v.setValueCount(1); + }), ms); + } + + // === Interval / Duration === + + @Test + public void testIntervalDay() throws IOException { + Field field = field(new ArrowType.Interval(IntervalUnit.DAY_TIME)); + Object result = extract(field, v -> { + ((IntervalDayVector) v).setSafe(0, 1, 5_000); // 1 day + 5s + v.setValueCount(1); + }); + assertEquals(result, Duration.ofDays(1).plusSeconds(5).toString()); + } + + @Test + public void testIntervalYear() throws IOException { + Field field = field(new ArrowType.Interval(IntervalUnit.YEAR_MONTH)); + Object result = extract(field, v -> { + ((IntervalYearVector) v).setSafe(0, 14); // 14 months — Arrow returns Period.ofMonths (un-normalized) + v.setValueCount(1); + }); + assertEquals(result, Period.ofMonths(14).toString()); + } + + @Test 
+ public void testDuration() throws IOException { + Field field = field(new ArrowType.Duration(TimeUnit.MILLISECOND)); + Object result = extract(field, v -> { + ((DurationVector) v).setSafe(0, 90_000L); // 1m30s + v.setValueCount(1); + }); + assertEquals(result, Duration.ofSeconds(90).toString()); + } + + // === Complex === + + @Test + public void testListOfInt() throws IOException { + Field elementField = new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field listField = new Field(COLUMN, FieldType.nullable(new ArrowType.List()), List.of(elementField)); + Object result = extract(listField, v -> { + ListVector lv = (ListVector) v; + lv.allocateNew(); + IntVector child = (IntVector) lv.getDataVector(); + lv.startNewValue(0); + child.setSafe(0, 10); + child.setSafe(1, 20); + child.setSafe(2, 30); + lv.endValue(0, 3); + child.setValueCount(3); + lv.setValueCount(1); + }); + assertEquals((Object[]) result, new Object[]{10, 20, 30}); + } + + @Test + public void testListOfString() throws IOException { + Field elementField = new Field("$data$", FieldType.nullable(new ArrowType.Utf8()), null); + Field listField = new Field(COLUMN, FieldType.nullable(new ArrowType.List()), List.of(elementField)); + Object result = extract(listField, v -> { + ListVector lv = (ListVector) v; + lv.allocateNew(); + VarCharVector child = (VarCharVector) lv.getDataVector(); + lv.startNewValue(0); + child.setSafe(0, "a".getBytes()); + child.setSafe(1, "b".getBytes()); + lv.endValue(0, 2); + child.setValueCount(2); + lv.setValueCount(1); + }); + assertEquals((Object[]) result, new Object[]{"a", "b"}); + } + + @Test + public void testStruct() throws IOException { + Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); + Field ageField = new Field("age", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field structField = new Field(COLUMN, FieldType.nullable(new ArrowType.Struct()), List.of(nameField, ageField)); + Object result = 
extract(structField, v -> { + StructVector sv = (StructVector) v; + sv.allocateNew(); + ((VarCharVector) sv.getChild("name")).setSafe(0, "Alice".getBytes()); + ((IntVector) sv.getChild("age")).setSafe(0, 30); + sv.setIndexDefined(0); + sv.setValueCount(1); + }); + @SuppressWarnings("unchecked") + Map map = (Map) result; + assertEquals(map.get("name"), "Alice"); + assertEquals(map.get("age"), 30); + } + + @Test + public void testMap() throws IOException { + Field keyField = new Field(MapVector.KEY_NAME, FieldType.notNullable(new ArrowType.Utf8()), null); + Field valField = new Field(MapVector.VALUE_NAME, FieldType.nullable(new ArrowType.Int(32, true)), null); + Field entriesField = + new Field(MapVector.DATA_VECTOR_NAME, FieldType.notNullable(new ArrowType.Struct()), + List.of(keyField, valField)); + Field mapField = + new Field(COLUMN, FieldType.nullable(new ArrowType.Map(false)), List.of(entriesField)); + Object result = extract(mapField, v -> { + MapVector mv = (MapVector) v; + mv.allocateNew(); + StructVector entries = (StructVector) mv.getDataVector(); + VarCharVector keys = (VarCharVector) entries.getChild(MapVector.KEY_NAME); + IntVector vals = (IntVector) entries.getChild(MapVector.VALUE_NAME); + mv.startNewValue(0); + keys.setSafe(0, "k1".getBytes()); + vals.setSafe(0, 100); + entries.setIndexDefined(0); + keys.setSafe(1, "k2".getBytes()); + vals.setSafe(1, 200); + entries.setIndexDefined(1); + mv.endValue(0, 2); + keys.setValueCount(2); + vals.setValueCount(2); + entries.setValueCount(2); + mv.setValueCount(1); + }); + @SuppressWarnings("unchecked") + Map map = (Map) result; + assertEquals(map.get("k1"), 100); + assertEquals(map.get("k2"), 200); + } + + // === Null value (column with nullable type whose row 0 is null) === + + @Test + public void testNullIntValue() throws IOException { + Field field = field(new ArrowType.Int(32, true)); + assertNull(extract(field, v -> { + v.setValueCount(1); + })); + } + + @Test + public void testNullStringValue() throws 
IOException { + Field field = field(new ArrowType.Utf8()); + assertNull(extract(field, v -> { + v.setValueCount(1); + })); + } + + // === Dictionary-encoded === + + @Test + public void testDictionaryEncodedString() throws IOException { + DictionaryEncoding encoding = new DictionaryEncoding(1L, false, new ArrowType.Int(32, true)); + + // Dictionary values: ["Alpha", "Beta", "Gamma"]. + VarCharVector dictVec = new VarCharVector("dict", _allocator); + dictVec.allocateNew(); + dictVec.setSafe(0, "Alpha".getBytes()); + dictVec.setSafe(1, "Beta".getBytes()); + dictVec.setSafe(2, "Gamma".getBytes()); + dictVec.setValueCount(3); + Dictionary dictionary = new Dictionary(dictVec, encoding); + + // Encode "Beta" (index 1) for row 0. + VarCharVector unencoded = new VarCharVector(COLUMN, _allocator); + unencoded.allocateNew(); + unencoded.setSafe(0, "Beta".getBytes()); + unencoded.setValueCount(1); + + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary); + + try (FieldVector encodedVec = (FieldVector) DictionaryEncoder.encode(unencoded, dictionary)) { + unencoded.close(); + try (VectorSchemaRoot root = + new VectorSchemaRoot(List.of(encodedVec.getField()), List.of(encodedVec), 1)) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (WritableByteChannel ch = Channels.newChannel(baos); + ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, ch)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayInputStream(baos.toByteArray()), _allocator)) { + reader.loadNextBatch(); + ArrowRecordExtractor extractor = new ArrowRecordExtractor(); + extractor.init(null, null); + extractor.setReader(reader); + try (ArrowRecordExtractor.Record record = new ArrowRecordExtractor.Record()) { + extractor.prepareBatch(record); + record.setRowId(0); + GenericRow row = new GenericRow(); + extractor.extract(record, row); + 
assertEquals(row.getValue(COLUMN), "Beta"); + } + } + } + } finally { + dictVec.close(); + } + } + + // === Include-list filter === + + @Test + public void testIncludeListFiltersFields() throws IOException { + Field a = new Field("a", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field b = new Field("b", FieldType.nullable(new ArrowType.Utf8()), null); + Schema schema = new Schema(List.of(a, b)); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, _allocator)) { + ((IntVector) root.getVector("a")).setSafe(0, 7); + ((VarCharVector) root.getVector("b")).setSafe(0, "skipped".getBytes()); + root.getVector("a").setValueCount(1); + root.getVector("b").setValueCount(1); + root.setRowCount(1); + + GenericRow row = roundtripAndExtract(root, Set.of("a"), null); + assertEquals(row.getValue("a"), 7); + assertNull(row.getValue("b")); + } + } + + // === Helpers === + + private static Field field(ArrowType type) { + return new Field(COLUMN, FieldType.nullable(type), null); + } + + private Object extract(Field field, Consumer populator) throws IOException { + return extract(field, populator, null); + } + + private Object extractRaw(Field field, Consumer populator) throws IOException { + ArrowRecordExtractorConfig config = new ArrowRecordExtractorConfig(); + config.setExtractRawTimeValues(true); + return extract(field, populator, config); + } + + /// Build a single-column root with one populated row, roundtrip through Arrow IPC, then extract row 0. 
+ private Object extract(Field field, Consumer<FieldVector> populator, + @Nullable ArrowRecordExtractorConfig config) throws IOException { + Schema schema = new Schema(List.of(field)); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, _allocator)) { + FieldVector vector = root.getVector(COLUMN); + populator.accept(vector); + root.setRowCount(vector.getValueCount()); + return roundtripAndExtract(root, null, config).getValue(COLUMN); + } + } + + private GenericRow roundtripAndExtract(VectorSchemaRoot root, @Nullable Set<String> fieldsToRead, + @Nullable ArrowRecordExtractorConfig config) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (WritableByteChannel ch = Channels.newChannel(baos); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, ch)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + try (ArrowStreamReader reader = new ArrowStreamReader( + new ByteArrayInputStream(baos.toByteArray()), _allocator)) { + reader.loadNextBatch(); + ArrowRecordExtractor extractor = new ArrowRecordExtractor(); + extractor.init(fieldsToRead, config); + extractor.setReader(reader); + try (ArrowRecordExtractor.Record record = new ArrowRecordExtractor.Record()) { + extractor.prepareBatch(record); + record.setRowId(0); + GenericRow row = new GenericRow(); + extractor.extract(record, row); + return row; + } + } + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderTest.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderTest.java index c2f5af1c0e29..581fa9f8b31f 100644 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderTest.java +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowRecordReaderTest.java @@ -18,14 +18,12 @@ */ package
org.apache.pinot.plugin.inputformat.arrow; -import com.google.common.collect.Sets; import java.io.File; import java.io.FileOutputStream; +import java.io.IOException; import java.nio.channels.FileChannel; -import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.Set; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.Float4Vector; @@ -38,24 +36,21 @@ import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.readers.AbstractRecordReaderTest; -import org.apache.pinot.spi.data.readers.GenericRow; -import org.apache.pinot.spi.data.readers.PrimaryKey; import org.apache.pinot.spi.data.readers.RecordReader; -import org.testng.Assert; import org.testng.SkipException; import org.testng.annotations.Test; +import static org.apache.arrow.vector.types.pojo.FieldType.nullable; + public class ArrowRecordReaderTest extends AbstractRecordReaderTest { private static final int ROWS_PER_BATCH = 1000; @Override protected RecordReader createRecordReader(File file) - throws Exception { + throws IOException { ArrowRecordReader recordReader = new ArrowRecordReader(); recordReader.init(file, _sourceFields, null); return recordReader; @@ -63,42 +58,38 @@ protected RecordReader createRecordReader(File file) @Override protected void writeRecordsToFile(List> recordsToWrite) - throws Exception { + throws IOException { // Single-value fields - Field dimSvInt = new Field("dim_sv_int", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field dimSvLong = new Field("dim_sv_long", FieldType.nullable(new ArrowType.Int(64, true)), null); + Field dimSvInt = new Field("dim_sv_int", nullable(new ArrowType.Int(32, 
true)), null); + Field dimSvLong = new Field("dim_sv_long", nullable(new ArrowType.Int(64, true)), null); Field dimSvFloat = - new Field("dim_sv_float", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); + new Field("dim_sv_float", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); Field dimSvDouble = - new Field("dim_sv_double", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), - null); - Field dimSvString = new Field("dim_sv_string", FieldType.nullable(new ArrowType.Utf8()), null); + new Field("dim_sv_double", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); + Field dimSvString = new Field("dim_sv_string", nullable(new ArrowType.Utf8()), null); // Multi-value fields (List vectors) - Field dimMvInt = new Field("dim_mv_int", FieldType.nullable(new ArrowType.List()), - Arrays.asList(new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null))); - Field dimMvLong = new Field("dim_mv_long", FieldType.nullable(new ArrowType.List()), - Arrays.asList(new Field("$data$", FieldType.nullable(new ArrowType.Int(64, true)), null))); - Field dimMvFloat = new Field("dim_mv_float", FieldType.nullable(new ArrowType.List()), - Arrays.asList( - new Field("$data$", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null))); - Field dimMvDouble = new Field("dim_mv_double", FieldType.nullable(new ArrowType.List()), - Arrays.asList( - new Field("$data$", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null))); - Field dimMvString = new Field("dim_mv_string", FieldType.nullable(new ArrowType.List()), - Arrays.asList(new Field("$data$", FieldType.nullable(new ArrowType.Utf8()), null))); + Field dimMvInt = new Field("dim_mv_int", nullable(new ArrowType.List()), + List.of(new Field("$data$", nullable(new ArrowType.Int(32, true)), null))); + Field dimMvLong = new Field("dim_mv_long", nullable(new 
ArrowType.List()), + List.of(new Field("$data$", nullable(new ArrowType.Int(64, true)), null))); + Field dimMvFloat = new Field("dim_mv_float", nullable(new ArrowType.List()), + List.of(new Field("$data$", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null))); + Field dimMvDouble = new Field("dim_mv_double", nullable(new ArrowType.List()), + List.of(new Field("$data$", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null))); + Field dimMvString = new Field("dim_mv_string", nullable(new ArrowType.List()), + List.of(new Field("$data$", nullable(new ArrowType.Utf8()), null))); // Metric fields - Field metInt = new Field("met_int", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field metLong = new Field("met_long", FieldType.nullable(new ArrowType.Int(64, true)), null); - Field metFloat = - new Field("met_float", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); + Field metInt = new Field("met_int", nullable(new ArrowType.Int(32, true)), null); + Field metLong = new Field("met_long", nullable(new ArrowType.Int(64, true)), null); + Field metFloat = new Field("met_float", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null); Field metDouble = - new Field("met_double", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); + new Field("met_double", nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); Schema schema = new Schema( - Arrays.asList(dimSvInt, dimSvLong, dimSvFloat, dimSvDouble, dimSvString, dimMvInt, dimMvLong, dimMvFloat, - dimMvDouble, dimMvString, metInt, metLong, metFloat, metDouble)); + List.of(dimSvInt, dimSvLong, dimSvFloat, dimSvDouble, dimSvString, dimMvInt, dimMvLong, dimMvFloat, dimMvDouble, + dimMvString, metInt, metLong, metFloat, metDouble)); try (RootAllocator allocator = new RootAllocator(); VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); @@ -201,15 
+192,15 @@ protected void writeRecordsToFile(List> recordsToWrite) dimSvDoubleVec.setValueCount(batchSize); dimSvStringVec.setValueCount(batchSize); dimMvIntVec.setValueCount(batchSize); - ((IntVector) dimMvIntVec.getDataVector()).setValueCount(mvIntIdx); + dimMvIntVec.getDataVector().setValueCount(mvIntIdx); dimMvLongVec.setValueCount(batchSize); - ((BigIntVector) dimMvLongVec.getDataVector()).setValueCount(mvLongIdx); + dimMvLongVec.getDataVector().setValueCount(mvLongIdx); dimMvFloatVec.setValueCount(batchSize); - ((Float4Vector) dimMvFloatVec.getDataVector()).setValueCount(mvFloatIdx); + dimMvFloatVec.getDataVector().setValueCount(mvFloatIdx); dimMvDoubleVec.setValueCount(batchSize); - ((Float8Vector) dimMvDoubleVec.getDataVector()).setValueCount(mvDoubleIdx); + dimMvDoubleVec.getDataVector().setValueCount(mvDoubleIdx); dimMvStringVec.setValueCount(batchSize); - ((VarCharVector) dimMvStringVec.getDataVector()).setValueCount(mvStringIdx); + dimMvStringVec.getDataVector().setValueCount(mvStringIdx); metIntVec.setValueCount(batchSize); metLongVec.setValueCount(batchSize); metFloatVec.setValueCount(batchSize); @@ -228,71 +219,9 @@ protected String getDataFileName() { return "data.arrow"; } - @Override - protected void checkValue(RecordReader recordReader, List> expectedRecordsMap, - List expectedPrimaryKeys) - throws Exception { - for (int i = 0; i < expectedRecordsMap.size(); i++) { - Map expectedRecord = expectedRecordsMap.get(i); - GenericRow actualRecord = recordReader.next(); - for (FieldSpec fieldSpec : _pinotSchema.getAllFieldSpecs()) { - String fieldSpecName = fieldSpec.getName(); - if (fieldSpec.isSingleValueField()) { - assertSingleValueEquals(actualRecord.getValue(fieldSpecName), expectedRecord.get(fieldSpecName)); - } else { - // Arrow converter returns List instead of Object[] - List actualList = (List) actualRecord.getValue(fieldSpecName); - List expectedList = (List) expectedRecord.get(fieldSpecName); - Assert.assertEquals(actualList.size(), 
expectedList.size()); - for (int j = 0; j < actualList.size(); j++) { - assertSingleValueEquals(actualList.get(j), expectedList.get(j)); - } - } - } - PrimaryKey primaryKey = actualRecord.getPrimaryKey(getPrimaryKeyColumns()); - Assert.assertEquals(primaryKey.getValues(), expectedPrimaryKeys.get(i)); - } - Assert.assertFalse(recordReader.hasNext()); - } - @Test @Override public void testGzipRecordReader() { throw new SkipException("Arrow IPC file format requires seekable channels and does not support gzip compression"); } - - @Test - public void testFieldsToReadFiltering() - throws Exception { - Set fieldsToRead = Sets.newHashSet("dim_sv_int", "dim_sv_string"); - try (ArrowRecordReader reader = new ArrowRecordReader()) { - reader.init(_dataFile, fieldsToRead, null); - - Assert.assertTrue(reader.hasNext()); - GenericRow row = reader.next(); - - // Requested fields should be present - Assert.assertNotNull(row.getValue("dim_sv_int")); - Assert.assertNotNull(row.getValue("dim_sv_string")); - - // Non-requested fields should be absent - Assert.assertNull(row.getValue("dim_sv_long")); - Assert.assertNull(row.getValue("dim_sv_float")); - Assert.assertNull(row.getValue("dim_sv_double")); - Assert.assertNull(row.getValue("met_int")); - Assert.assertNull(row.getValue("dim_mv_int")); - } - } - - private void assertSingleValueEquals(Object actual, Object expected) { - if (expected instanceof Float) { - Assert.assertEquals(((Number) actual).floatValue(), (float) expected, 1e-6f); - } else if (expected instanceof Double) { - Assert.assertEquals(((Number) actual).doubleValue(), (double) expected, 1e-6d); - } else if (expected instanceof String) { - Assert.assertEquals(actual.toString(), expected); - } else { - Assert.assertEquals(actual, expected); - } - } } diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowTestDataUtils.java 
b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowTestDataUtils.java new file mode 100644 index 000000000000..7e5e2a76f921 --- /dev/null +++ b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/ArrowTestDataUtils.java @@ -0,0 +1,160 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.plugin.inputformat.arrow; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + + +/// Helpers that produce Arrow IPC stream-format byte payloads for the decoder tests in [ArrowMessageDecoderTest]. 
+/// Per-type extraction lives in [ArrowRecordExtractorTest], so this util only covers the shapes the decoder tests +/// need: a small two-column batch, a multi-row batch with a `batch_num` / `value` schema, and an empty zero-batch +/// stream. +public class ArrowTestDataUtils { + private ArrowTestDataUtils() { + } + + /// Two-column (`id` INT, `name` STRING) batch with `numRows` rows. Row `i` has `id = i + 1` and + /// `name = "name_" + (i + 1)`. + public static byte[] createValidArrowIpcData(int numRows) + throws IOException { + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); + Schema schema = new Schema(Arrays.asList(idField, nameField)); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + root.allocateNew(); + idVector.allocateNew(numRows); + nameVector.allocateNew(numRows * 10, numRows); + + for (int i = 0; i < numRows; i++) { + idVector.set(i, i + 1); + nameVector.set(i, ("name_" + (i + 1)).getBytes()); + } + + idVector.setValueCount(numRows); + nameVector.setValueCount(numRows); + root.setRowCount(numRows); + + return writeArrowDataToBytes(root); + } + } + } + + /// Three-column (`id` INT, `batch_num` INT, `value` STRING) data with `batchCount` batches of + /// `rowsPerBatch` rows each. `id` is a global running counter (1-based); `batch_num` is the batch + /// index; `value = "batch_<batch>_row_<row>"`.
+ public static byte[] createMultiBatchArrowIpcData(int batchCount, int rowsPerBatch) + throws IOException { + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field batchField = + new Field("batch_num", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field valueField = new Field("value", FieldType.nullable(new ArrowType.Utf8()), null); + Schema schema = new Schema(Arrays.asList(idField, batchField, valueField)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try (WritableByteChannel channel = Channels.newChannel(outputStream); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, channel)) { + + writer.start(); + + IntVector idVector = (IntVector) root.getVector("id"); + IntVector batchVector = (IntVector) root.getVector("batch_num"); + VarCharVector valueVector = (VarCharVector) root.getVector("value"); + + int totalRowId = 1; + for (int batch = 0; batch < batchCount; batch++) { + root.allocateNew(); + idVector.allocateNew(rowsPerBatch); + batchVector.allocateNew(rowsPerBatch); + valueVector.allocateNew(rowsPerBatch * 15, rowsPerBatch); + + for (int row = 0; row < rowsPerBatch; row++) { + idVector.set(row, totalRowId++); + batchVector.set(row, batch); + valueVector.set(row, ("batch_" + batch + "_row_" + row).getBytes()); + } + + idVector.setValueCount(rowsPerBatch); + batchVector.setValueCount(rowsPerBatch); + valueVector.setValueCount(rowsPerBatch); + root.setRowCount(rowsPerBatch); + + writer.writeBatch(); + } + + writer.end(); + return outputStream.toByteArray(); + } + } + } + + /// Stream with the schema header but no record batches. 
+ public static byte[] createEmptyArrowIpcData() + throws IOException { + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); + Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); + Schema schema = new Schema(Arrays.asList(idField, nameField)); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.setRowCount(0); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try (WritableByteChannel channel = Channels.newChannel(outputStream); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, channel)) { + writer.start(); + writer.end(); + } + + return outputStream.toByteArray(); + } + } + } + + private static byte[] writeArrowDataToBytes(VectorSchemaRoot root) + throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try (WritableByteChannel channel = Channels.newChannel(outputStream); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, channel)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + return outputStream.toByteArray(); + } +} diff --git a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/util/ArrowTestDataUtil.java b/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/util/ArrowTestDataUtil.java deleted file mode 100644 index 206dc7d85a56..000000000000 --- a/pinot-plugins/pinot-input-format/pinot-arrow/src/test/java/org/apache/pinot/plugin/inputformat/arrow/util/ArrowTestDataUtil.java +++ /dev/null @@ -1,607 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.plugin.inputformat.arrow.util; - -import java.io.ByteArrayOutputStream; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; -import java.util.Arrays; -import java.util.List; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.Float8Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.TimeStampMilliVector; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.MapVector; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.dictionary.Dictionary; -import org.apache.arrow.vector.dictionary.DictionaryEncoder; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.ipc.ArrowStreamWriter; -import org.apache.arrow.vector.types.FloatingPointPrecision; -import org.apache.arrow.vector.types.TimeUnit; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.DictionaryEncoding; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; - - -public class ArrowTestDataUtil { - - private 
ArrowTestDataUtil() { - } - - public static byte[] createValidArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); - Schema schema = new Schema(Arrays.asList(idField, nameField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - VarCharVector nameVector = (VarCharVector) root.getVector("name"); - - root.allocateNew(); - idVector.allocateNew(numRows); - nameVector.allocateNew(numRows * 10, numRows); - - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - nameVector.set(i, ("name_" + (i + 1)).getBytes()); - } - - idVector.setValueCount(numRows); - nameVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createMultiTypeArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); - Field priceField = - new Field( - "price", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), - null); - Field activeField = new Field("active", FieldType.nullable(new ArrowType.Bool()), null); - Field timestampField = - new Field( - "timestamp", - FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), - null); - - Schema schema = - new Schema(Arrays.asList(idField, nameField, priceField, activeField, timestampField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - VarCharVector nameVector = 
(VarCharVector) root.getVector("name"); - Float8Vector priceVector = (Float8Vector) root.getVector("price"); - BitVector activeVector = (BitVector) root.getVector("active"); - TimeStampMilliVector timestampVector = (TimeStampMilliVector) root.getVector("timestamp"); - - root.allocateNew(); - idVector.allocateNew(numRows); - nameVector.allocateNew(numRows * 20, numRows); - priceVector.allocateNew(numRows); - activeVector.allocateNew(numRows); - timestampVector.allocateNew(numRows); - - long baseTime = System.currentTimeMillis(); - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - nameVector.set(i, ("product_" + (i + 1)).getBytes()); - priceVector.set(i, 10.99 + (i * 5.0)); - activeVector.set(i, i % 2 == 0 ? 1 : 0); - timestampVector.set(i, baseTime + (i * 1000L)); - } - - idVector.setValueCount(numRows); - nameVector.setValueCount(numRows); - priceVector.setValueCount(numRows); - activeVector.setValueCount(numRows); - timestampVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createMultiBatchArrowIpcData(int batchCount, int rowsPerBatch) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field batchField = - new Field("batch_num", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field valueField = new Field("value", FieldType.nullable(new ArrowType.Utf8()), null); - Schema schema = new Schema(Arrays.asList(idField, batchField, valueField)); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - try (WritableByteChannel channel = Channels.newChannel(outputStream); - VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); - ArrowStreamWriter writer = new ArrowStreamWriter(root, null, channel)) { - - writer.start(); - - IntVector idVector = (IntVector) root.getVector("id"); - IntVector batchVector 
= (IntVector) root.getVector("batch_num"); - VarCharVector valueVector = (VarCharVector) root.getVector("value"); - - int totalRowId = 1; - for (int batch = 0; batch < batchCount; batch++) { - root.allocateNew(); - idVector.allocateNew(rowsPerBatch); - batchVector.allocateNew(rowsPerBatch); - valueVector.allocateNew(rowsPerBatch * 15, rowsPerBatch); - - for (int row = 0; row < rowsPerBatch; row++) { - idVector.set(row, totalRowId++); - batchVector.set(row, batch); - valueVector.set(row, ("batch_" + batch + "_row_" + row).getBytes()); - } - - idVector.setValueCount(rowsPerBatch); - batchVector.setValueCount(rowsPerBatch); - valueVector.setValueCount(rowsPerBatch); - root.setRowCount(rowsPerBatch); - - writer.writeBatch(); - } - - writer.end(); - return outputStream.toByteArray(); - } - } - } - - public static byte[] createEmptyArrowIpcData() - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field nameField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); - Schema schema = new Schema(Arrays.asList(idField, nameField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - root.setRowCount(0); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - try (WritableByteChannel channel = Channels.newChannel(outputStream); - ArrowStreamWriter writer = new ArrowStreamWriter(root, null, channel)) { - - writer.start(); - writer.end(); - } - - return outputStream.toByteArray(); - } - } - } - - public static byte[] createDictionaryEncodedArrowIpcData(int numRows) - throws Exception { - List dictionaryValues = Arrays.asList("Electronics", "Books", "Clothing", "Home"); - DictionaryEncoding dictionaryEncoding = - new DictionaryEncoding(1L, false, new ArrowType.Int(32, true)); - - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); - VarCharVector dictionaryVector = new 
VarCharVector("category_dict", allocator); - IntVector idVector = new IntVector("id", allocator); - Float8Vector priceVector = new Float8Vector("price", allocator); - VarCharVector categoryUnencoded = - new VarCharVector( - "category", - new FieldType(true, new ArrowType.Utf8(), dictionaryEncoding), - allocator)) { - - dictionaryVector.allocateNew(); - for (int i = 0; i < dictionaryValues.size(); i++) { - dictionaryVector.set(i, dictionaryValues.get(i).getBytes()); - } - dictionaryVector.setValueCount(dictionaryValues.size()); - - Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding); - DictionaryProvider.MapDictionaryProvider dictionaryProvider = - new DictionaryProvider.MapDictionaryProvider(); - dictionaryProvider.put(dictionary); - - idVector.allocateNew(numRows); - priceVector.allocateNew(numRows); - categoryUnencoded.allocateNew(numRows); - - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - categoryUnencoded.set(i, dictionaryValues.get(i % dictionaryValues.size()).getBytes()); - priceVector.set(i, 19.99 + (i * 10.0)); - } - idVector.setValueCount(numRows); - priceVector.setValueCount(numRows); - categoryUnencoded.setValueCount(numRows); - - try (org.apache.arrow.vector.FieldVector encodedCategoryVector = - (org.apache.arrow.vector.FieldVector) - DictionaryEncoder.encode(categoryUnencoded, dictionary)) { - List fields = - Arrays.asList( - idVector.getField(), encodedCategoryVector.getField(), priceVector.getField()); - List vectors = - Arrays.asList(idVector, encodedCategoryVector, priceVector); - try (VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors)) { - return writeArrowDataToBytes(root, dictionaryProvider); - } - } - } - } - - public static byte[] createListArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field numbersElementField = - new Field("$data$", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field numbersField = - new 
Field( - "numbers", - FieldType.nullable(new ArrowType.List()), - Arrays.asList(numbersElementField)); - - Field tagsElementField = new Field("$data$", FieldType.nullable(new ArrowType.Utf8()), null); - Field tagsField = - new Field( - "tags", FieldType.nullable(new ArrowType.List()), Arrays.asList(tagsElementField)); - - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Schema schema = new Schema(Arrays.asList(idField, numbersField, tagsField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - ListVector numbersVector = (ListVector) root.getVector("numbers"); - ListVector tagsVector = (ListVector) root.getVector("tags"); - IntVector numbersChild = (IntVector) numbersVector.getDataVector(); - VarCharVector tagsChild = (VarCharVector) tagsVector.getDataVector(); - - root.allocateNew(); - idVector.allocateNew(numRows); - numbersVector.allocateNew(); - tagsVector.allocateNew(); - - int numbersElemIndex = 0; - int tagsElemIndex = 0; - - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - - numbersVector.startNewValue(i); - for (int j = 0; j <= i; j++) { - numbersChild.setSafe(numbersElemIndex++, (i + 1) * 10 + j); - } - numbersVector.endValue(i, i + 1); - - tagsVector.startNewValue(i); - for (int j = 0; j < 2; j++) { - tagsChild.setSafe(tagsElemIndex++, ("tag_" + i + "_" + j).getBytes()); - } - tagsVector.endValue(i, 2); - } - - idVector.setValueCount(numRows); - numbersChild.setValueCount(numbersElemIndex); - numbersVector.setValueCount(numRows); - tagsChild.setValueCount(tagsElemIndex); - tagsVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createStructArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field nameField = new Field("name", FieldType.nullable(new 
ArrowType.Utf8()), null); - Field ageField = new Field("age", FieldType.nullable(new ArrowType.Int(32, true)), null); - Field streetField = new Field("street", FieldType.nullable(new ArrowType.Utf8()), null); - Field cityField = new Field("city", FieldType.nullable(new ArrowType.Utf8()), null); - Field addressField = - new Field( - "address", - FieldType.nullable(new ArrowType.Struct()), - Arrays.asList(streetField, cityField)); - - Field personField = - new Field( - "person", - FieldType.nullable(new ArrowType.Struct()), - Arrays.asList(nameField, ageField, addressField)); - - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Schema schema = new Schema(Arrays.asList(idField, personField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - StructVector personVector = (StructVector) root.getVector("person"); - - root.allocateNew(); - idVector.allocateNew(numRows); - personVector.allocateNew(); - - VarCharVector nameVector = (VarCharVector) personVector.getChild("name"); - IntVector ageVector = (IntVector) personVector.getChild("age"); - StructVector addressVector = (StructVector) personVector.getChild("address"); - VarCharVector streetVector = (VarCharVector) addressVector.getChild("street"); - VarCharVector cityVector = (VarCharVector) addressVector.getChild("city"); - - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - personVector.setIndexDefined(i); - addressVector.setIndexDefined(i); - nameVector.setSafe(i, ("Person_" + (i + 1)).getBytes()); - ageVector.setSafe(i, 25 + i); - streetVector.setSafe(i, ((i + 1) + " Main St").getBytes()); - cityVector.setSafe(i, ("City_" + (i + 1)).getBytes()); - } - - idVector.setValueCount(numRows); - personVector.setValueCount(numRows); - nameVector.setValueCount(numRows); - ageVector.setValueCount(numRows); - addressVector.setValueCount(numRows); - streetVector.setValueCount(numRows); - 
cityVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createMapArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field keyField = - new Field(MapVector.KEY_NAME, FieldType.notNullable(new ArrowType.Utf8()), null); - Field valField = - new Field(MapVector.VALUE_NAME, FieldType.nullable(new ArrowType.Int(32, true)), null); - Field entriesField = - new Field( - MapVector.DATA_VECTOR_NAME, - FieldType.notNullable(new ArrowType.Struct()), - Arrays.asList(keyField, valField)); - Field mapField = - new Field( - "metadata", - FieldType.nullable(new ArrowType.Map(false)), - Arrays.asList(entriesField)); - - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Schema schema = new Schema(Arrays.asList(idField, mapField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - MapVector mapVector = (MapVector) root.getVector("metadata"); - StructVector entries = (StructVector) mapVector.getDataVector(); - VarCharVector keyVector = (VarCharVector) entries.getChild(MapVector.KEY_NAME); - IntVector valueVector = (IntVector) entries.getChild(MapVector.VALUE_NAME); - - root.allocateNew(); - idVector.allocateNew(numRows); - mapVector.allocateNew(); - - int entryIndex = 0; - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - int entriesCount = 2 + (i % 2); - mapVector.startNewValue(i); - for (int j = 0; j < entriesCount; j++) { - keyVector.setSafe(entryIndex, ("key_" + i + "_" + j).getBytes()); - valueVector.setSafe(entryIndex, (i + 1) * 100 + j); - entries.setIndexDefined(entryIndex); - entryIndex++; - } - mapVector.endValue(i, entriesCount); - } - - idVector.setValueCount(numRows); - keyVector.setValueCount(entryIndex); - valueVector.setValueCount(entryIndex); - 
entries.setValueCount(entryIndex); - mapVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createNestedMapArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - // Define inner map (value of outer map) - Field innerKeyField = - new Field(MapVector.KEY_NAME, FieldType.notNullable(new ArrowType.Utf8()), null); - Field innerValField = - new Field(MapVector.VALUE_NAME, FieldType.nullable(new ArrowType.Int(32, true)), null); - Field innerEntriesField = - new Field( - MapVector.DATA_VECTOR_NAME, - FieldType.notNullable(new ArrowType.Struct()), - Arrays.asList(innerKeyField, innerValField)); - Field innerMapField = - new Field( - MapVector.VALUE_NAME, - FieldType.nullable(new ArrowType.Map(false)), - Arrays.asList(innerEntriesField)); - - // Define outer map with value as the inner map - Field outerKeyField = - new Field(MapVector.KEY_NAME, FieldType.notNullable(new ArrowType.Utf8()), null); - Field outerEntriesField = - new Field( - MapVector.DATA_VECTOR_NAME, - FieldType.notNullable(new ArrowType.Struct()), - Arrays.asList(outerKeyField, innerMapField)); - Field outerMapField = - new Field( - "metadata", - FieldType.nullable(new ArrowType.Map(false)), - Arrays.asList(outerEntriesField)); - - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Schema schema = new Schema(Arrays.asList(idField, outerMapField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - MapVector outerMapVector = (MapVector) root.getVector("metadata"); - StructVector outerEntries = (StructVector) outerMapVector.getDataVector(); - VarCharVector outerKeyVector = (VarCharVector) outerEntries.getChild(MapVector.KEY_NAME); - MapVector innerMapVector = (MapVector) outerEntries.getChild(MapVector.VALUE_NAME); - StructVector 
innerEntries = (StructVector) innerMapVector.getDataVector(); - VarCharVector innerKeyVector = (VarCharVector) innerEntries.getChild(MapVector.KEY_NAME); - IntVector innerValueVector = (IntVector) innerEntries.getChild(MapVector.VALUE_NAME); - - root.allocateNew(); - idVector.allocateNew(numRows); - outerMapVector.allocateNew(); - - int outerEntryIndex = 0; - int innerEntryIndex = 0; - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - - int outerEntriesCount = 2 + (i % 2); // 2 or 3 outer entries - outerMapVector.startNewValue(i); - for (int j = 0; j < outerEntriesCount; j++) { - // Set outer key - outerKeyVector.setSafe(outerEntryIndex, ("outer_key_" + i + "_" + j).getBytes()); - - // Populate inner map for this outer entry at aligned index - innerMapVector.startNewValue(outerEntryIndex); - int innerEntriesCount = 2 + (j % 2); // 2 or 3 inner entries - for (int k = 0; k < innerEntriesCount; k++) { - innerKeyVector.setSafe( - innerEntryIndex, ("inner_key_" + i + "_" + j + "_" + k).getBytes()); - innerValueVector.setSafe(innerEntryIndex, (i + 1) * 1000 + j * 10 + k); - innerEntries.setIndexDefined(innerEntryIndex); - innerEntryIndex++; - } - innerMapVector.endValue(outerEntryIndex, innerEntriesCount); - - outerEntries.setIndexDefined(outerEntryIndex); - outerEntryIndex++; - } - outerMapVector.endValue(i, outerEntriesCount); - } - - idVector.setValueCount(numRows); - outerKeyVector.setValueCount(outerEntryIndex); - innerKeyVector.setValueCount(innerEntryIndex); - innerValueVector.setValueCount(innerEntryIndex); - innerEntries.setValueCount(innerEntryIndex); - innerMapVector.setValueCount(outerEntryIndex); - outerEntries.setValueCount(outerEntryIndex); - outerMapVector.setValueCount(numRows); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - public static byte[] createNestedListStructArrowIpcData(int numRows) - throws Exception { - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field 
itemNameField = new Field("item_name", FieldType.nullable(new ArrowType.Utf8()), null); - Field itemPriceField = - new Field( - "item_price", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), - null); - Field itemStructField = - new Field( - "$data$", - FieldType.nullable(new ArrowType.Struct()), - Arrays.asList(itemNameField, itemPriceField)); - - Field itemsField = - new Field( - "items", FieldType.nullable(new ArrowType.List()), Arrays.asList(itemStructField)); - - Field idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); - Schema schema = new Schema(Arrays.asList(idField, itemsField)); - - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - IntVector idVector = (IntVector) root.getVector("id"); - ListVector itemsVector = (ListVector) root.getVector("items"); - StructVector itemStructVector = (StructVector) itemsVector.getDataVector(); - VarCharVector itemNameVector = (VarCharVector) itemStructVector.getChild("item_name"); - Float8Vector itemPriceVector = (Float8Vector) itemStructVector.getChild("item_price"); - - root.allocateNew(); - idVector.allocateNew(numRows); - itemsVector.allocateNew(); - - int structIndex = 0; - for (int i = 0; i < numRows; i++) { - idVector.set(i, i + 1); - int itemsCount = 1 + (i % 3); - itemsVector.startNewValue(i); - for (int j = 0; j < itemsCount; j++) { - itemNameVector.setSafe(structIndex, ("item_" + i + "_" + j).getBytes()); - itemPriceVector.setSafe(structIndex, 10.99 + (i * 5.0) + j); - itemStructVector.setIndexDefined(structIndex); - structIndex++; - } - itemsVector.endValue(i, itemsCount); - } - - idVector.setValueCount(numRows); - itemsVector.setValueCount(numRows); - itemNameVector.setValueCount(structIndex); - itemPriceVector.setValueCount(structIndex); - root.setRowCount(numRows); - - return writeArrowDataToBytes(root, null); - } - } - } - - private static byte[] writeArrowDataToBytes( - VectorSchemaRoot root, DictionaryProvider 
dictionaryProvider) - throws Exception { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - try (WritableByteChannel channel = Channels.newChannel(outputStream); - ArrowStreamWriter writer = new ArrowStreamWriter(root, dictionaryProvider, channel)) { - writer.start(); - writer.writeBatch(); - writer.end(); - } - return outputStream.toByteArray(); - } -}