Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions c++/src/ColumnPrinter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,17 @@ namespace orc {
if (hasNulls && !notNull[rowId]) {
writeString(buffer, "null");
} else {
const size_t tag = static_cast<size_t>(tags_[rowId]);
if (tag >= fieldPrinter_.size()) {
throw ParseError("Invalid union tag " + to_string(static_cast<int64_t>(tag)) +
" for union with " +
to_string(static_cast<int64_t>(fieldPrinter_.size())) + " children");
}
writeString(buffer, "{\"tag\": ");
const auto numBuffer = std::to_string(static_cast<int64_t>(tags_[rowId]));
const auto numBuffer = std::to_string(static_cast<int64_t>(tag));
writeString(buffer, numBuffer.c_str());
writeString(buffer, ", \"value\": ");
fieldPrinter_[tags_[rowId]]->printRow(offsets_[rowId]);
fieldPrinter_[tag]->printRow(offsets_[rowId]);
writeChar(buffer, '}');
}
}
Expand Down
22 changes: 17 additions & 5 deletions c++/src/ColumnReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,16 @@ namespace orc {
}
}

/**
 * Validate a raw union tag against the number of children of the union.
 *
 * @param tag         tag byte as read for one union row
 * @param numChildren number of children the union has
 * @return the tag widened to size_t when it is smaller than numChildren
 * @throws ParseError when the tag is greater than or equal to numChildren
 */
size_t getCheckedUnionTag(unsigned char tag, uint64_t numChildren) {
  const size_t childIndex = static_cast<size_t>(tag);
  // Fast path: the tag names an existing child, hand it back unchanged.
  if (childIndex < numChildren) {
    return childIndex;
  }
  throw ParseError("Invalid union tag " + to_string(static_cast<int64_t>(childIndex)) +
                   " for union with " + to_string(static_cast<int64_t>(numChildren)) +
                   " children");
}

class UnionColumnReader : public ColumnReader {
private:
std::unique_ptr<ByteRleDecoder> rle_;
Expand Down Expand Up @@ -1181,15 +1191,15 @@ namespace orc {
uint64_t UnionColumnReader::skip(uint64_t numValues) {
numValues = ColumnReader::skip(numValues);
const uint64_t BUFFER_SIZE = 1024;
char buffer[BUFFER_SIZE];
unsigned char buffer[BUFFER_SIZE];
uint64_t lengthsRead = 0;
int64_t* counts = childrenCounts_.data();
memset(counts, 0, sizeof(int64_t) * numChildren_);
while (lengthsRead < numValues) {
uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
rle_->next(buffer, chunk, nullptr);
rle_->next(reinterpret_cast<char*>(buffer), chunk, nullptr);
for (size_t i = 0; i < chunk; ++i) {
counts[static_cast<size_t>(buffer[i])] += 1;
counts[getCheckedUnionTag(buffer[i], numChildren_)] += 1;
}
lengthsRead += chunk;
}
Expand Down Expand Up @@ -1225,12 +1235,14 @@ namespace orc {
if (notNull) {
for (size_t i = 0; i < numValues; ++i) {
if (notNull[i]) {
offsets[i] = static_cast<uint64_t>(counts[static_cast<size_t>(tags[i])]++);
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
offsets[i] = static_cast<uint64_t>(counts[tag]++);
}
}
} else {
for (size_t i = 0; i < numValues; ++i) {
offsets[i] = static_cast<uint64_t>(counts[static_cast<size_t>(tags[i])]++);
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
offsets[i] = static_cast<uint64_t>(counts[tag]++);
}
}
// read the right number of each child column
Expand Down
19 changes: 19 additions & 0 deletions c++/test/TestColumnPrinter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,25 @@ namespace orc {
}
}

// Printing a union row whose tag (200) does not name one of the two children
// is expected to throw ParseError.
TEST(TestColumnPrinter, UnionColumnPrinterRejectsInvalidTag) {
  std::string output;

  // Union schema with exactly two children: LONG and INT.
  std::unique_ptr<Type> unionType = createUnionType();
  unionType->addUnionChild(createPrimitiveType(LONG));
  unionType->addUnionChild(createPrimitiveType(INT));
  std::unique_ptr<ColumnPrinter> unionPrinter = createColumnPrinter(output, unionType.get());

  // One-row union batch with one child batch per union child.
  UnionVectorBatch unionBatch(1, *getDefaultPool());
  unionBatch.children.push_back(new LongVectorBatch(1, *getDefaultPool()));
  unionBatch.children.push_back(new LongVectorBatch(1, *getDefaultPool()));
  unionBatch.hasNulls = false;
  unionBatch.numElements = 1;

  // The single row carries an out-of-range tag.
  unionBatch.offsets[0] = 0;
  unionBatch.tags[0] = 200;

  unionPrinter->reset(unionBatch);
  EXPECT_THROW(unionPrinter->printRow(0), ParseError);
}

TEST(TestColumnPrinter, StructColumnPrinter) {
std::string line;
std::unique_ptr<Type> type = createStructType();
Expand Down
70 changes: 70 additions & 0 deletions c++/test/TestColumnReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3717,6 +3717,76 @@ namespace orc {
batch.toString());
}

// Raw bytes served as the union column's DATA stream by buildInvalidUnionTagReader.
// NOTE(review): presumably byte-RLE encoded - 0xff read as a literal-run header
// for a single byte, followed by the tag value 0xc8 (200), which is out of range
// for a two-child union. Confirm against the ORC byte-RLE encoding.
const unsigned char INVALID_UNION_TAG[] = {0xff, 0xc8};
// Raw bytes served as column 1's PRESENT stream when hasNulls is requested.
// NOTE(review): presumably decodes so that the first row is marked present - confirm.
const unsigned char ONE_PRESENT_VALUE[] = {0x00, 0x80};

// Builds the row schema used by the invalid-tag reader tests: a struct with a
// single field "col0" that is a union of LONG and INT.
std::unique_ptr<Type> createTwoChildUnionRowType() {
  std::unique_ptr<Type> twoChildUnion = createUnionType();
  twoChildUnion->addUnionChild(createPrimitiveType(LONG));
  twoChildUnion->addUnionChild(createPrimitiveType(INT));

  // Ownership of the union moves into the enclosing struct.
  std::unique_ptr<Type> structRow = createStructType();
  structRow->addStructField("col0", std::move(twoChildUnion));
  return structRow;
}

// Configures the mocked stripe streams and builds a reader over the row type
// from createTwoChildUnionRowType(), with column 1's DATA stream backed by
// INVALID_UNION_TAG.
//
// streams:  mock that receives all expectations set up below.
// hasNulls: when true, column 1 additionally gets a PRESENT stream backed by
//           ONE_PRESENT_VALUE; otherwise every PRESENT lookup yields nullptr.
// Returns the reader produced by buildReader().
std::unique_ptr<ColumnReader> buildInvalidUnionTagReader(MockStripeStreams& streams,
bool hasNulls = false) {
// Four column slots; only columns 0 and 1 are selected.
std::vector<bool> selectedColumns(4, false);
selectedColumns[0] = true;
selectedColumns[1] = true;
EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns));
EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr));

// Every column reports DIRECT encoding.
proto::ColumnEncoding directEncoding;
directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT);
EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding));

// Catch-all: no PRESENT stream for any column.
EXPECT_CALL(streams, getStreamProxy(testing::_, proto::Stream_Kind_PRESENT, true))
.WillRepeatedly(testing::Return(nullptr));

// Column-1 override; it must be declared after the catch-all above so that it
// takes precedence (gmock tries the most recently declared expectation first).
if (hasNulls) {
EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_PRESENT, true))
.WillRepeatedly(testing::Return(
new SeekableArrayInputStream(ONE_PRESENT_VALUE, ARRAY_SIZE(ONE_PRESENT_VALUE))));
}

// Tag stream for the union column: the bytes of INVALID_UNION_TAG.
EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true))
.WillRepeatedly(testing::Return(
new SeekableArrayInputStream(INVALID_UNION_TAG, ARRAY_SIZE(INVALID_UNION_TAG))));

std::unique_ptr<Type> rowType = createTwoChildUnionRowType();
return buildReader(*rowType, streams);
}

// Appends a freshly allocated one-row UnionVectorBatch to batch.fields.
void addSingleUnionBatch(StructVectorBatch& batch) {
  auto* unionField = new UnionVectorBatch(1, *getDefaultPool());
  batch.fields.push_back(unionField);
}

// Reading one row from a union column whose tag stream holds an out-of-range
// tag, with no PRESENT stream, is expected to throw ParseError.
TEST(TestColumnReader, testUnionRejectsInvalidTag) {
  MockStripeStreams mockStreams;
  std::unique_ptr<ColumnReader> unionReader = buildInvalidUnionTagReader(mockStreams);

  StructVectorBatch rowBatch(1, *getDefaultPool());
  addSingleUnionBatch(rowBatch);

  EXPECT_THROW(unionReader->next(rowBatch, 1, 0), ParseError);
}

// Same as the no-nulls case, but the reader is built with a PRESENT stream for
// the union column; reading one row is still expected to throw ParseError.
TEST(TestColumnReader, testUnionRejectsInvalidTagWithNulls) {
  MockStripeStreams mockStreams;
  std::unique_ptr<ColumnReader> unionReader = buildInvalidUnionTagReader(mockStreams, true);

  StructVectorBatch rowBatch(1, *getDefaultPool());
  addSingleUnionBatch(rowBatch);

  EXPECT_THROW(unionReader->next(rowBatch, 1, 0), ParseError);
}

// Skipping one row over a tag stream that holds an out-of-range tag is
// expected to throw ParseError.
TEST(TestColumnReader, testUnionSkipRejectsInvalidTag) {
  MockStripeStreams mockStreams;
  std::unique_ptr<ColumnReader> unionReader = buildInvalidUnionTagReader(mockStreams);

  EXPECT_THROW(unionReader->skip(1), ParseError);
}

TEST(TestColumnReader, testUnionWithNulls) {
MockStripeStreams streams;

Expand Down
Loading