diff --git a/c++/src/ColumnPrinter.cc b/c++/src/ColumnPrinter.cc index 6535c612ce..6d969a3dab 100644 --- a/c++/src/ColumnPrinter.cc +++ b/c++/src/ColumnPrinter.cc @@ -540,11 +540,17 @@ namespace orc { if (hasNulls && !notNull[rowId]) { writeString(buffer, "null"); } else { + const size_t tag = static_cast(tags_[rowId]); + if (tag >= fieldPrinter_.size()) { + throw ParseError("Invalid union tag " + to_string(static_cast(tag)) + + " for union with " + + to_string(static_cast(fieldPrinter_.size())) + " children"); + } writeString(buffer, "{\"tag\": "); - const auto numBuffer = std::to_string(static_cast(tags_[rowId])); + const auto numBuffer = std::to_string(static_cast(tag)); writeString(buffer, numBuffer.c_str()); writeString(buffer, ", \"value\": "); - fieldPrinter_[tags_[rowId]]->printRow(offsets_[rowId]); + fieldPrinter_[tag]->printRow(offsets_[rowId]); writeChar(buffer, '}'); } } diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc index b4032af734..ab1fe24b1c 100644 --- a/c++/src/ColumnReader.cc +++ b/c++/src/ColumnReader.cc @@ -1131,6 +1131,16 @@ namespace orc { } } + size_t getCheckedUnionTag(unsigned char tag, uint64_t numChildren) { + size_t child = static_cast(tag); + if (child >= numChildren) { + throw ParseError("Invalid union tag " + to_string(static_cast(child)) + + " for union with " + to_string(static_cast(numChildren)) + + " children"); + } + return child; + } + class UnionColumnReader : public ColumnReader { private: std::unique_ptr rle_; @@ -1181,15 +1191,15 @@ namespace orc { uint64_t UnionColumnReader::skip(uint64_t numValues) { numValues = ColumnReader::skip(numValues); const uint64_t BUFFER_SIZE = 1024; - char buffer[BUFFER_SIZE]; + unsigned char buffer[BUFFER_SIZE]; uint64_t lengthsRead = 0; int64_t* counts = childrenCounts_.data(); memset(counts, 0, sizeof(int64_t) * numChildren_); while (lengthsRead < numValues) { uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE); - rle_->next(buffer, chunk, nullptr); + rle_->next(reinterpret_cast(buffer), chunk, nullptr); for (size_t i = 0; i < chunk; ++i) { - counts[static_cast(buffer[i])] += 1; + counts[getCheckedUnionTag(buffer[i], numChildren_)] += 1; } lengthsRead += chunk; } @@ -1225,12 +1235,14 @@ namespace orc { if (notNull) { for (size_t i = 0; i < numValues; ++i) { if (notNull[i]) { - offsets[i] = static_cast(counts[static_cast(tags[i])]++); + size_t tag = getCheckedUnionTag(tags[i], numChildren_); + offsets[i] = static_cast(counts[tag]++); } } } else { for (size_t i = 0; i < numValues; ++i) { - offsets[i] = static_cast(counts[static_cast(tags[i])]++); + size_t tag = getCheckedUnionTag(tags[i], numChildren_); + offsets[i] = static_cast(counts[tag]++); } } // read the right number of each child column diff --git a/c++/test/TestColumnPrinter.cc b/c++/test/TestColumnPrinter.cc index 8fa5d01cd4..2553888437 100644 --- a/c++/test/TestColumnPrinter.cc +++ b/c++/test/TestColumnPrinter.cc @@ -506,6 +506,25 @@ namespace orc { } } + TEST(TestColumnPrinter, UnionColumnPrinterRejectsInvalidTag) { + std::string line; + std::unique_ptr type = createUnionType(); + type->addUnionChild(createPrimitiveType(LONG)); + type->addUnionChild(createPrimitiveType(INT)); + std::unique_ptr printer = createColumnPrinter(line, type.get()); + + UnionVectorBatch batch(1, *getDefaultPool()); + batch.children.push_back(new LongVectorBatch(1, *getDefaultPool())); + batch.children.push_back(new LongVectorBatch(1, *getDefaultPool())); + batch.numElements = 1; + batch.hasNulls = false; + batch.tags[0] = 200; + batch.offsets[0] = 0; + + printer->reset(batch); + EXPECT_THROW(printer->printRow(0), ParseError); + } + TEST(TestColumnPrinter, StructColumnPrinter) { std::string line; std::unique_ptr type = createStructType(); diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc index fcbf007630..d2aa38cb66 100644 --- a/c++/test/TestColumnReader.cc +++ b/c++/test/TestColumnReader.cc @@ -3717,6 +3717,76 @@ namespace orc { batch.toString()); } + const unsigned char INVALID_UNION_TAG[] = {0xff, 0xc8}; + const unsigned char ONE_PRESENT_VALUE[] = {0x00, 0x80}; + + std::unique_ptr createTwoChildUnionRowType() { + std::unique_ptr unionType = createUnionType(); + unionType->addUnionChild(createPrimitiveType(LONG)); + unionType->addUnionChild(createPrimitiveType(INT)); + std::unique_ptr rowType = createStructType(); + rowType->addStructField("col0", std::move(unionType)); + return rowType; + } + + std::unique_ptr buildInvalidUnionTagReader(MockStripeStreams& streams, + bool hasNulls = false) { + std::vector selectedColumns(4, false); + selectedColumns[0] = true; + selectedColumns[1] = true; + EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns)); + EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr)); + + proto::ColumnEncoding directEncoding; + directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT); + EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding)); + + EXPECT_CALL(streams, getStreamProxy(testing::_, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return(nullptr)); + + if (hasNulls) { + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_PRESENT, true)) + .WillRepeatedly(testing::Return( + new SeekableArrayInputStream(ONE_PRESENT_VALUE, ARRAY_SIZE(ONE_PRESENT_VALUE)))); + } + + EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true)) + .WillRepeatedly(testing::Return( + new SeekableArrayInputStream(INVALID_UNION_TAG, ARRAY_SIZE(INVALID_UNION_TAG)))); + + std::unique_ptr rowType = createTwoChildUnionRowType(); + return buildReader(*rowType, streams); + } + + void addSingleUnionBatch(StructVectorBatch& batch) { + batch.fields.push_back(new UnionVectorBatch(1, *getDefaultPool())); + } + + TEST(TestColumnReader, testUnionRejectsInvalidTag) { + MockStripeStreams streams; + std::unique_ptr reader = buildInvalidUnionTagReader(streams); + + StructVectorBatch batch(1, *getDefaultPool()); + addSingleUnionBatch(batch); + EXPECT_THROW(reader->next(batch, 1, 0), ParseError); + } + + TEST(TestColumnReader, testUnionRejectsInvalidTagWithNulls) { + MockStripeStreams streams; + std::unique_ptr reader = buildInvalidUnionTagReader(streams, true); + + StructVectorBatch batch(1, *getDefaultPool()); + addSingleUnionBatch(batch); + EXPECT_THROW(reader->next(batch, 1, 0), ParseError); + } + + TEST(TestColumnReader, testUnionSkipRejectsInvalidTag) { + MockStripeStreams streams; + std::unique_ptr reader = buildInvalidUnionTagReader(streams); + + EXPECT_THROW(reader->skip(1), ParseError); + } + TEST(TestColumnReader, testUnionWithNulls) { MockStripeStreams streams;