Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 279 additions & 9 deletions native/spark-expr/src/agg_funcs/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::{and, filter, is_not_null};

use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array};
use arrow::compute::{and, is_not_null};
use arrow::datatypes::{DataType, Field, FieldRef, Float64Type};
use std::{any::Any, sync::Arc};

use crate::agg_funcs::covariance::CovarianceAccumulator;
use crate::agg_funcs::covariance::{CovarianceAccumulator, CovarianceGroupsAccumulator};
use crate::agg_funcs::stddev::StddevAccumulator;
use arrow::datatypes::FieldRef;
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
};
use crate::agg_funcs::variance::VarianceGroupsAccumulator;
use arrow::array::AsArray;
use arrow::compute::filter;
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::type_coercion::aggregates::NUMERICS;
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion::physical_expr::expressions::format_state_name;
use datafusion::physical_expr::expressions::StatsType;

Expand Down Expand Up @@ -118,6 +119,19 @@ impl AggregateUDFImpl for Correlation {
)),
])
}

/// Report that this aggregate provides a [`GroupsAccumulator`], letting
/// DataFusion use the vectorized grouped path instead of creating one
/// `Accumulator` per group.
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
    true
}

/// Build the grouped accumulator, forwarding the `null_on_divide_by_zero`
/// setting so divide-by-zero results are handled the same way as in the
/// per-row `CorrelationAccumulator`.
fn create_groups_accumulator(
    &self,
    _args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
    Ok(Box::new(CorrelationGroupsAccumulator::new(
        self.null_on_divide_by_zero,
    )))
}
}

/// An accumulator to compute correlation
Expand Down Expand Up @@ -248,3 +262,259 @@ impl Accumulator for CorrelationAccumulator {
+ self.stddev2.size()
}
}

/// Grouped correlation accumulator. Mirrors the per-row `CorrelationAccumulator`
/// composition: one population covariance + two population variance sub-
/// accumulators for the per-column m2 values. A combined filter handles
/// "skip rows where either input is null" without filtering `group_indices`.
#[derive(Debug)]
struct CorrelationGroupsAccumulator {
    // Per-group population covariance of (col1, col2); its counts are also
    // the shared per-group row counts read back in `evaluate`.
    covar: CovarianceGroupsAccumulator,
    // Per-group population variance of the first input column.
    var1: VarianceGroupsAccumulator,
    // Per-group population variance of the second input column.
    var2: VarianceGroupsAccumulator,
    // Spark semantics flag: when true, divide-by-zero produces NULL
    // instead of NaN.
    null_on_divide_by_zero: bool,
}

impl CorrelationGroupsAccumulator {
    /// Create a grouped correlation accumulator. All three children use
    /// population statistics, matching the per-row `CorrelationAccumulator`;
    /// the flag is kept locally as well because `evaluate` applies its own
    /// divide-by-zero branches on top of the children's results.
    fn new(null_on_divide_by_zero: bool) -> Self {
        Self {
            covar: CovarianceGroupsAccumulator::new(StatsType::Population, null_on_divide_by_zero),
            var1: VarianceGroupsAccumulator::new(StatsType::Population, null_on_divide_by_zero),
            var2: VarianceGroupsAccumulator::new(StatsType::Population, null_on_divide_by_zero),
            null_on_divide_by_zero,
        }
    }
}

impl GroupsAccumulator for CorrelationGroupsAccumulator {
    /// Accumulate raw input rows into per-group state.
    ///
    /// `values[0]` / `values[1]` are the two correlation arguments. Rows where
    /// either column is null must not contribute, so a "neither column null"
    /// mask is folded into the caller's optional filter and the combined mask
    /// is passed to all three children — this avoids filtering the value
    /// arrays, which would desynchronize them from `group_indices`.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 2, "two arguments to update_batch");
        let null_mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
        let combined: BooleanArray = match opt_filter {
            Some(f) => and(f, &null_mask)?,
            None => null_mask,
        };

        self.covar
            .update_batch(values, group_indices, Some(&combined), total_num_groups)?;
        self.var1.update_batch(
            &values[0..1],
            group_indices,
            Some(&combined),
            total_num_groups,
        )?;
        self.var2.update_batch(
            &values[1..2],
            group_indices,
            Some(&combined),
            total_num_groups,
        )?;
        Ok(())
    }

    /// Merge partial states produced by `state()`.
    ///
    /// State column order is `[count, mean1, mean2, algo_const, m2_1, m2_2]`
    /// (see `Correlation::state_fields`); the shared count/mean columns are
    /// fanned out to each child in the layout that child's own `state()`
    /// produces.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 6, "six state columns to merge_batch");
        // covar expects [count, mean1, mean2, algo_const]
        let covar_state = [
            Arc::clone(&values[0]),
            Arc::clone(&values[1]),
            Arc::clone(&values[2]),
            Arc::clone(&values[3]),
        ];
        // var1 expects [count, mean1, m2_1]
        let var1_state = [
            Arc::clone(&values[0]),
            Arc::clone(&values[1]),
            Arc::clone(&values[4]),
        ];
        // var2 expects [count, mean2, m2_2]
        let var2_state = [
            Arc::clone(&values[0]),
            Arc::clone(&values[2]),
            Arc::clone(&values[5]),
        ];

        self.covar
            .merge_batch(&covar_state, group_indices, opt_filter, total_num_groups)?;
        self.var1
            .merge_batch(&var1_state, group_indices, opt_filter, total_num_groups)?;
        self.var2
            .merge_batch(&var2_state, group_indices, opt_filter, total_num_groups)?;
        Ok(())
    }

    /// Emit the final correlation value for each group:
    /// `covar / (sqrt(var1) * sqrt(var2))`, with Spark's special cases for
    /// empty groups and divide-by-zero.
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        // Snapshot per-group counts BEFORE the children's evaluate() consumes
        // their state. This lets us apply the count==0 / count==1 branches the
        // way the per-row CorrelationAccumulator does.
        let counts: Vec<f64> = match emit_to {
            EmitTo::All => self.covar.counts().to_vec(),
            EmitTo::First(n) => self.covar.counts()[..n].to_vec(),
        };

        let covar = self.covar.evaluate(emit_to)?;
        let var1 = self.var1.evaluate(emit_to)?;
        let var2 = self.var2.evaluate(emit_to)?;
        let covar = covar.as_primitive::<Float64Type>();
        let var1 = var1.as_primitive::<Float64Type>();
        let var2 = var2.as_primitive::<Float64Type>();

        let n = covar.len();
        let mut values = Vec::with_capacity(n);
        let mut validity = Vec::with_capacity(n);
        for (i, &count) in counts.iter().enumerate().take(n) {
            // No (non-null) rows contributed to this group: NULL.
            if count == 0.0 {
                values.push(0.0);
                validity.push(false);
                continue;
            }
            // Exactly one row: population stddevs are 0 and the division is
            // 0/0. Legacy mode reports NaN; null_on_divide_by_zero maps it
            // to NULL.
            if count == 1.0 {
                if self.null_on_divide_by_zero {
                    values.push(0.0);
                    validity.push(false);
                } else {
                    values.push(f64::NAN);
                    validity.push(true);
                }
                continue;
            }
            // A child accumulator already decided this group is NULL.
            if covar.is_null(i) || var1.is_null(i) || var2.is_null(i) {
                values.push(0.0);
                validity.push(false);
                continue;
            }
            let c = covar.value(i);
            let s1 = var1.value(i).sqrt();
            let s2 = var2.value(i).sqrt();
            if s1 == 0.0 || s2 == 0.0 {
                // Constant column(s): divisor is zero. Mirror the per-row
                // accumulator's `covar / (s1 * s2)` semantics — in legacy
                // mode (null_on_divide_by_zero == false) that division is
                // 0/0 = NaN, not NULL; NULL only when the flag is set. This
                // keeps the two `count == 1` branches above and this branch
                // consistent.
                if self.null_on_divide_by_zero {
                    values.push(0.0);
                    validity.push(false);
                } else {
                    values.push(f64::NAN);
                    validity.push(true);
                }
                continue;
            }
            values.push(c / (s1 * s2));
            validity.push(true);
        }

        Ok(Arc::new(Float64Array::new(
            values.into(),
            Some(arrow::buffer::NullBuffer::from(validity)),
        )))
    }

    /// Export intermediate state for partial aggregation.
    ///
    /// covar.state -> [count, mean1, mean2, algo_const]
    /// var1.state  -> [count, mean1, m2_1]
    /// var2.state  -> [count, mean2, m2_2]
    /// Combined state (matches Correlation::state_fields):
    /// [count, mean1, mean2, algo_const, m2_1, m2_2]
    /// The duplicated count/mean columns from the variance children are
    /// dropped; `merge_batch` reconstructs each child's layout from the
    /// shared columns.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let covar_state = self.covar.state(emit_to)?;
        let var1_state = self.var1.state(emit_to)?;
        let var2_state = self.var2.state(emit_to)?;
        Ok(vec![
            Arc::clone(&covar_state[0]),
            Arc::clone(&covar_state[1]),
            Arc::clone(&covar_state[2]),
            Arc::clone(&covar_state[3]),
            Arc::clone(&var1_state[2]),
            Arc::clone(&var2_state[2]),
        ])
    }

    /// Total heap footprint is the sum of the three children's.
    fn size(&self) -> usize {
        self.covar.size() + self.var1.size() + self.var2.size()
    }
}

#[cfg(test)]
mod groups_tests {
    use super::*;
    use arrow::array::AsArray;

    /// `legacy == true` selects Spark legacy semantics, i.e.
    /// `null_on_divide_by_zero == false`.
    fn acc(legacy: bool) -> CorrelationGroupsAccumulator {
        CorrelationGroupsAccumulator::new(!legacy)
    }

    /// Drain every group and return one `Option<f64>` per group.
    fn evaluate(a: &mut CorrelationGroupsAccumulator) -> Vec<Option<f64>> {
        let emitted = a.evaluate(EmitTo::All).unwrap();
        emitted.as_primitive::<Float64Type>().iter().collect()
    }

    /// Shorthand for a non-nullable Float64 input column.
    fn col(vals: Vec<f64>) -> ArrayRef {
        Arc::new(Float64Array::from(vals))
    }

    #[test]
    fn perfectly_correlated_single_group() {
        let mut acc = acc(true);
        let xs = col(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let ys = col(vec![2.0, 4.0, 6.0, 8.0, 10.0]);
        acc.update_batch(&[xs, ys], &[0; 5], None, 1).unwrap();
        let got = evaluate(&mut acc)[0].unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn either_column_null_dropped() {
        let mut acc = acc(true);
        let xs: ArrayRef = Arc::new(Float64Array::from(vec![
            Some(1.0),
            None,
            Some(3.0),
            Some(5.0),
        ]));
        let ys: ArrayRef = Arc::new(Float64Array::from(vec![
            Some(2.0),
            Some(99.0),
            None,
            Some(10.0),
        ]));
        acc.update_batch(&[xs, ys], &[0; 4], None, 1).unwrap();
        // Only the pairs (1,2) and (5,10) survive; both lie on y = 2x, so
        // the correlation is exactly 1.
        let got = evaluate(&mut acc)[0].unwrap();
        assert!((got - 1.0).abs() < 1e-12);
    }

    #[test]
    fn empty_group_yields_null() {
        let mut acc = acc(true);
        acc.update_batch(
            &[col(vec![1.0, 2.0]), col(vec![3.0, 6.0])],
            &[0, 0],
            None,
            2,
        )
        .unwrap();
        // Group 1 never received a row, so its result is NULL.
        assert_eq!(evaluate(&mut acc)[1], None);
    }

    #[test]
    fn single_row_legacy_mode_yields_nan() {
        // Correlation always uses Population stats internally. With one row
        // the per-row CorrelationAccumulator returns NaN in legacy mode
        // (null_on_divide_by_zero == false); the grouped path must match.
        let mut acc = acc(true);
        acc.update_batch(&[col(vec![42.0]), col(vec![7.0])], &[0], None, 1)
            .unwrap();
        assert!(evaluate(&mut acc)[0].unwrap().is_nan());
    }

    #[test]
    fn single_row_ansi_mode_yields_null() {
        // Same single-row input, but with null_on_divide_by_zero == true the
        // divide-by-zero maps to NULL instead of NaN.
        let mut acc = acc(false);
        acc.update_batch(&[col(vec![42.0]), col(vec![7.0])], &[0], None, 1)
            .unwrap();
        assert_eq!(evaluate(&mut acc)[0], None);
    }
}
Loading
Loading