Skip to content

Commit 1bed7b6

Browse files
authored
fix: correctness issue in NaN predicate pushdown (#2351)
## Which issue does this PR close?

Not tied to a specific issue; found during an audit of pushdown filter gaps.

## What changes are included in this PR?

`PredicateConverter::is_nan` and `not_nan` were never actually implemented: `is_nan` returned `always_true` (matches every row) and `not_nan` returned `always_false` (matches no rows). Every other predicate in `PredicateConverter` projects the column from the batch and runs an arrow compute kernel, but these two just returned constants.

This adds a `compute_is_nan` helper that downcasts to `Float32Array`/`Float64Array` and checks each value with `f.is_nan()`, preserving nulls. Non-float types return all false. `is_nan` and `not_nan` now use it the same way `is_null`/`not_null` use `arrow::is_null`/`is_not_null`.

## Are these changes tested?

Yes, a test was added.
1 parent ad44fc3 commit 1bed7b6

1 file changed

Lines changed: 121 additions & 5 deletions

File tree

crates/iceberg/src/arrow/reader.rs

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ use std::str::FromStr;
2323
use std::sync::Arc;
2424

2525
use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
26+
use arrow_array::cast::AsArray;
27+
use arrow_array::types::{Float32Type, Float64Type};
2628
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
29+
use arrow_buffer::BooleanBuffer;
2730
use arrow_cast::cast::cast;
2831
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
2932
use arrow_schema::{
@@ -1509,6 +1512,35 @@ fn project_column(
15091512
}
15101513
}
15111514

1515+
fn compute_is_nan(array: &ArrayRef) -> std::result::Result<BooleanArray, ArrowError> {
1516+
// Compute NaN over the contiguous values slice, then fold the null bitmap
1517+
// in with a single bitwise AND so that null slots become false.
1518+
let (is_nan, nulls) = match array.data_type() {
1519+
DataType::Float32 => {
1520+
let arr = array.as_primitive::<Float32Type>();
1521+
(
1522+
BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())),
1523+
arr.nulls(),
1524+
)
1525+
}
1526+
DataType::Float64 => {
1527+
let arr = array.as_primitive::<Float64Type>();
1528+
(
1529+
BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())),
1530+
arr.nulls(),
1531+
)
1532+
}
1533+
_ => unreachable!("is_nan is only valid for float types"),
1534+
};
1535+
1536+
let values = match nulls {
1537+
Some(nulls) => &is_nan & nulls.inner(),
1538+
None => is_nan,
1539+
};
1540+
1541+
Ok(BooleanArray::new(values, None))
1542+
}
1543+
15121544
type PredicateResult =
15131545
dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + Send + 'static;
15141546

@@ -1591,8 +1623,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
15911623
reference: &BoundReference,
15921624
_predicate: &BoundPredicate,
15931625
) -> Result<Box<PredicateResult>> {
1594-
if self.bound_reference(reference)?.is_some() {
1595-
self.build_always_true()
1626+
if let Some(idx) = self.bound_reference(reference)? {
1627+
Ok(Box::new(move |batch| {
1628+
let column = project_column(&batch, idx)?;
1629+
compute_is_nan(&column)
1630+
}))
15961631
} else {
15971632
// A missing column, treating it as null.
15981633
self.build_always_false()
@@ -1604,8 +1639,12 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
16041639
reference: &BoundReference,
16051640
_predicate: &BoundPredicate,
16061641
) -> Result<Box<PredicateResult>> {
1607-
if self.bound_reference(reference)?.is_some() {
1608-
self.build_always_false()
1642+
if let Some(idx) = self.bound_reference(reference)? {
1643+
Ok(Box::new(move |batch| {
1644+
let column = project_column(&batch, idx)?;
1645+
let is_nan = compute_is_nan(&column)?;
1646+
not(&is_nan)
1647+
}))
16091648
} else {
16101649
// A missing column, treating it as null.
16111650
self.build_always_true()
@@ -2002,7 +2041,7 @@ mod tests {
20022041
use std::sync::Arc;
20032042

20042043
use arrow_array::cast::AsArray;
2005-
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
2044+
use arrow_array::{Array, ArrayRef, BooleanArray, LargeStringArray, RecordBatch, StringArray};
20062045
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
20072046
use futures::TryStreamExt;
20082047
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};
@@ -5464,4 +5503,81 @@ message schema {
54645503
ts_array.value(0)
54655504
);
54665505
}
5506+
5507+
fn apply_predicate_to_batch(
5508+
predicate: Predicate,
5509+
schema: SchemaRef,
5510+
batch: RecordBatch,
5511+
) -> BooleanArray {
5512+
use super::PredicateConverter;
5513+
5514+
let bound = predicate.bind(schema, true).unwrap();
5515+
5516+
// Build a trivial Parquet schema with one float column at field id 4
5517+
let message_type = "
5518+
message schema {
5519+
optional float qux = 4;
5520+
}
5521+
";
5522+
let parquet_type = parse_message_type(message_type).expect("parse schema");
5523+
let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_type));
5524+
5525+
let column_map = HashMap::from([(4i32, 0usize)]);
5526+
let column_indices = vec![0usize];
5527+
5528+
let mut converter = PredicateConverter {
5529+
parquet_schema: &parquet_schema,
5530+
column_map: &column_map,
5531+
column_indices: &column_indices,
5532+
};
5533+
5534+
let mut predicate_fn = visit(&mut converter, &bound).unwrap();
5535+
predicate_fn(batch).unwrap()
5536+
}
5537+
5538+
/// Checks that `is_nan` / `is_not_nan` are evaluated per row and never yield a
/// null result: a null input row is `false` for `is_nan` and `true` for
/// `is_not_nan`.
#[test]
fn test_predicate_converter_nan() {
    use arrow_array::Float32Array;

    let schema = table_schema_simple();
    let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "qux",
        DataType::Float32,
        true,
    )]));
    let values = vec![Some(1.0f32), Some(f32::NAN), None, Some(0.0f32)];

    // Evaluating a predicate consumes the batch, so build a fresh one per case.
    let make_batch = || {
        RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(Float32Array::from(
            values.clone(),
        ))])
        .unwrap()
    };
    // Read the four result slots out as plain booleans.
    let row_flags = |mask: &BooleanArray| -> Vec<bool> {
        (0..4).map(|row| mask.value(row)).collect()
    };

    // is_nan: non-null-propagating per Java's implementation - NULL → false
    let nan_mask = apply_predicate_to_batch(
        Reference::new("qux").is_nan(),
        schema.clone(),
        make_batch(),
    );
    assert_eq!(row_flags(&nan_mask), [false, true, false, false]);
    assert!(!nan_mask.is_null(2));

    // not_nan: non-null-propagating per Java's implementation - NULL → true
    let not_nan_mask =
        apply_predicate_to_batch(Reference::new("qux").is_not_nan(), schema, make_batch());
    assert_eq!(row_flags(&not_nan_mask), [true, false, true, true]);
    assert!(!not_nan_mask.is_null(2));
}
54675583
}

0 commit comments

Comments
 (0)