Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ harness = false
name = "replace"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "overlay"

[[bench]]
harness = false
name = "random"
Expand Down
68 changes: 68 additions & 0 deletions datafusion/functions/benches/overlay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod helper;

use arrow::datatypes::{DataType, Field};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use helper::gen_string_array;
use std::hint::black_box;
use std::sync::Arc;

/// Registers a Criterion benchmark that repeatedly invokes the `overlay`
/// scalar function on one generated string column plus three scalar
/// arguments (replacement text, start position, replace length).
fn criterion_benchmark(c: &mut Criterion) {
    const N_ROWS: usize = 8192;
    const STR_LEN: usize = 128;

    let overlay_udf = datafusion_functions::core::overlay();
    let config_options = Arc::new(ConfigOptions::default());

    // First argument(s) come from the shared bench helper; the remaining
    // three are constant scalars so only the string input varies per row.
    // NOTE(review): the `0.1, 0.5, false` knobs are defined by
    // `helper::gen_string_array` (not visible here) — confirm their meaning
    // there before tuning them.
    let mut args = gen_string_array(N_ROWS, STR_LEN, 0.1, 0.5, false);
    args.extend([
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("DataFusion".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Int64(Some(32))),
        ColumnarValue::Scalar(ScalarValue::Int64(Some(8))),
    ]);

    // One nullable field descriptor per argument, named `arg_0`, `arg_1`, ...
    let mut arg_fields = Vec::with_capacity(args.len());
    for (position, argument) in args.iter().enumerate() {
        let field = Field::new(format!("arg_{position}"), argument.data_type(), true);
        arg_fields.push(field.into());
    }
    let return_field = Arc::new(Field::new("f", DataType::Utf8, true));

    c.bench_function("overlay_StringArray_utf8_scalar_args", |bencher| {
        bencher.iter(|| {
            // `ScalarFunctionArgs` takes ownership, so the inputs are cloned
            // on every iteration; `black_box` keeps the result observable.
            let call_args = ScalarFunctionArgs {
                args: args.clone(),
                arg_fields: arg_fields.clone(),
                number_rows: N_ROWS,
                return_field: Arc::clone(&return_field),
                config_options: Arc::clone(&config_options),
            };
            black_box(overlay_udf.invoke_with_args(call_args).unwrap())
        })
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
252 changes: 142 additions & 110 deletions datafusion/functions/src/core/overlay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,140 +112,170 @@ impl ScalarUDFImpl for OverlayFunc {
}
}

/// Maps a 0-based character index in `string` to the byte offset at which
/// that character starts, so the result can be used to slice the `&str`.
///
/// An index at or past the end of the string yields `string.len()`.
/// When `is_ascii` is true the character index is used directly as the
/// byte offset (clamped to the string length); otherwise the string is
/// walked character by character.
fn byte_index_for_char(string: &str, char_idx: usize, is_ascii: bool) -> usize {
    let total_bytes = string.len();

    // ASCII fast path: one byte per character, no scan required.
    if is_ascii {
        return usize::min(char_idx, total_bytes);
    }

    match string.char_indices().nth(char_idx) {
        Some((offset, _)) => offset,
        // Ran off the end of the string: clamp to its byte length.
        None => total_bytes,
    }
}

/// Builds the OVERLAY result for a single (non-null) row.
///
/// `start_pos` is a 1-based character position; `replace_len` is the number
/// of characters of `string` to replace with `characters`.
fn overlay_one(
string: &str,
characters: &str,
start_pos: i64,
replace_len: i64,
) -> String {
debug_assert!(start_pos >= 1);

let is_ascii = string.is_ascii();
let string_char_len = if is_ascii {
string.len() as i64
} else {
string.chars().count() as i64
};

// Convert SQL's 1-based character position into 0-based character indexes.
// `start_char_idx` is the first replaced character; `end_char_idx` is the
// first character after the replaced span.
//
// No upper-bound check on `start_char_idx`: when it exceeds `string_char_len`
// we want the whole string as the prefix (PostgreSQL-compatible "insert past
// end" semantics).
let start_char_idx = start_pos - 1;
let end_char_idx = start_char_idx.saturating_add(replace_len);

let prefix_char_idx = usize::try_from(start_char_idx).unwrap_or(usize::MAX);
let prefix_end_byte = byte_index_for_char(string, prefix_char_idx, is_ascii);

let mut res = String::with_capacity(string.len() + characters.len());
res.push_str(&string[..prefix_end_byte]);
res.push_str(characters);

if end_char_idx < string_char_len {
let suffix_char_idx = usize::try_from(end_char_idx.max(0)).unwrap_or(usize::MAX);
let suffix_start_byte = byte_index_for_char(string, suffix_char_idx, is_ascii);
res.push_str(&string[suffix_start_byte..]);
}
res
}

macro_rules! process_overlay {
// For the three-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr) => {{
// Three argument case
($string_array:expr, $characters_array:expr, $pos_array:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = characters_len as i64;
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
.iter()
.zip($characters_array.iter())
.zip($pos_array.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
if start_pos < 1 {
return exec_err!("negative substring length not allowed");
}
let replace_len = characters.chars().count() as i64;
Ok(Some(overlay_one(
string,
characters,
start_pos,
replace_len,
)))
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
_ => Ok(None),
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
})
.collect::<Result<GenericStringArray<T>>>()
}};

// For the four-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{
// Four argument case
($string_array:expr, $characters_array:expr, $pos_array:expr, $len_array:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.zip($len_num.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = len.min(string_len as i64);
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
.iter()
.zip($characters_array.iter())
.zip($pos_array.iter())
.zip($len_array.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
if start_pos < 1 {
return exec_err!("negative substring length not allowed");
}
let string_char_len = string.chars().count() as i64;
Comment thread
neilconway marked this conversation as resolved.
Outdated
let replace_len = len.min(string_char_len);
Ok(Some(overlay_one(
string,
characters,
start_pos,
replace_len,
)))
}
Ok(Some(res))
_ => Ok(None),
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
})
.collect::<Result<GenericStringArray<T>>>()
}};
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
/// Replaces a substring of string1 with string2 starting at the integer bit
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
/// `OVERLAY(string PLACING substring FROM start [FOR count])`
///
/// Replaces a region of `string` with `substring`, starting at the 1-based
/// character position `start`. If `count` is supplied, that many characters
/// of `string` are replaced; otherwise `count` defaults to the character
/// length of `substring`.
///
/// ```text
/// overlay('Txxxxas' placing 'hom' from 2 for 4) → 'Thomas'
/// overlay('Txxxxas' placing 'hom' from 2) → 'Thomxas'
/// ```
fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let use_string_view = args[0].data_type() == &DataType::Utf8View;
if use_string_view {
if !matches!(args.len(), 3 | 4) {
return exec_err!(
"overlay was called with {} arguments. It requires 3 or 4.",
args.len()
);
}
if args[0].data_type() == &DataType::Utf8View {
string_view_overlay::<T>(args)
} else {
string_overlay::<T>(args)
}
}

fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_array = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
let result = if args.len() == 4 {
let len_array = as_int64_array(&args[3])?;
process_overlay!(string_array, characters_array, pos_array, len_array)?
} else {
process_overlay!(string_array, characters_array, pos_array)?
};
Ok(Arc::new(result) as ArrayRef)
}

fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_array = as_int64_array(&args[2])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
let result = if args.len() == 4 {
let len_array = as_int64_array(&args[3])?;
process_overlay!(string_array, characters_array, pos_array, len_array)?
} else {
process_overlay!(string_array, characters_array, pos_array)?
};
Ok(Arc::new(result) as ArrayRef)
}

#[cfg(test)]
Expand All @@ -265,7 +295,9 @@ mod tests {

let res = overlay::<i32>(&[string, replace_string, start, end]).unwrap();
let result = as_generic_string_array::<i32>(&res).unwrap();
let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
// First row: start=4 is past the end of "123" (len 3). PostgreSQL
// takes the whole string as prefix and appends the replacement.
let expected = StringArray::from(vec!["123abc", "qwertyasdfg", "ijkz", "Thomas"]);
assert_eq!(&expected, result);

Ok(())
Expand Down
Loading
Loading