From 50c825682c48ef637f1bb65f01450a0b7d32d24f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 10:32:18 -0500 Subject: [PATCH 01/16] rename + add the rest of the arrow abi --- .../src/{extension.rs => extension_ffi.rs} | 0 c/sedona-extension/src/lib.rs | 2 +- c/sedona-extension/src/scalar_kernel.rs | 2 +- c/sedona-extension/src/sedona_extension.h | 355 ++++++++++++++++++ c/sedona-s2geography/src/kernels.rs | 2 +- 5 files changed, 358 insertions(+), 3 deletions(-) rename c/sedona-extension/src/{extension.rs => extension_ffi.rs} (100%) diff --git a/c/sedona-extension/src/extension.rs b/c/sedona-extension/src/extension_ffi.rs similarity index 100% rename from c/sedona-extension/src/extension.rs rename to c/sedona-extension/src/extension_ffi.rs diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index 073ed939bc..d5ded22efe 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub mod extension; +pub mod extension_ffi; pub mod scalar_kernel; diff --git a/c/sedona-extension/src/scalar_kernel.rs b/c/sedona-extension/src/scalar_kernel.rs index 78b5070b69..ae06243eaf 100644 --- a/c/sedona-extension/src/scalar_kernel.rs +++ b/c/sedona-extension/src/scalar_kernel.rs @@ -34,7 +34,7 @@ use std::{ str::FromStr, }; -use crate::extension::{ffi_arrow_schema_is_valid, SedonaCScalarKernel, SedonaCScalarKernelImpl}; +use crate::extension_ffi::{ffi_arrow_schema_is_valid, SedonaCScalarKernel, SedonaCScalarKernelImpl}; /// Wrapper around a [SedonaCScalarKernel] that implements [SedonaScalarKernel] /// diff --git a/c/sedona-extension/src/sedona_extension.h b/c/sedona-extension/src/sedona_extension.h index 7191af21f6..8274b9ee85 100644 --- a/c/sedona-extension/src/sedona_extension.h +++ b/c/sedona-extension/src/sedona_extension.h @@ -111,6 +111,361 @@ struct ArrowArrayStream { #endif // ARROW_C_STREAM_INTERFACE #endif // ARROW_FLAG_DICTIONARY_ORDERED +#ifndef ARROW_C_DEVICE_DATA_INTERFACE +#define ARROW_C_DEVICE_DATA_INTERFACE + +// Spec and Documentation: https://arrow.apache.org/docs/format/CDeviceDataInterface.html + +// DeviceType for the allocated memory +typedef int32_t ArrowDeviceType; + +// CPU device, same as using ArrowArray directly +#define ARROW_DEVICE_CPU 1 +// CUDA GPU Device +#define ARROW_DEVICE_CUDA 2 +// Pinned CUDA CPU memory by cudaMallocHost +#define ARROW_DEVICE_CUDA_HOST 3 +// OpenCL Device +#define ARROW_DEVICE_OPENCL 4 +// Vulkan buffer for next-gen graphics +#define ARROW_DEVICE_VULKAN 7 +// Metal for Apple GPU +#define ARROW_DEVICE_METAL 8 +// Verilog simulator buffer +#define ARROW_DEVICE_VPI 9 +// ROCm GPUs for AMD GPUs +#define ARROW_DEVICE_ROCM 10 +// Pinned ROCm CPU memory allocated by hipMallocHost +#define ARROW_DEVICE_ROCM_HOST 11 +// Reserved for extension +#define ARROW_DEVICE_EXT_DEV 12 +// CUDA managed/unified memory allocated by cudaMallocManaged +#define ARROW_DEVICE_CUDA_MANAGED 13 +// unified shared memory allocated on a oneAPI non-partitioned device. +#define ARROW_DEVICE_ONEAPI 14 +// GPU support for next-gen WebGPU standard +#define ARROW_DEVICE_WEBGPU 15 +// Qualcomm Hexagon DSP +#define ARROW_DEVICE_HEXAGON 16 + +struct ArrowDeviceArray { + // the Allocated Array + // + // the buffers in the array (along with the buffers of any + // children) are what is allocated on the device. + struct ArrowArray array; + // The device id to identify a specific device + int64_t device_id; + // The type of device which can access this memory. + ArrowDeviceType device_type; + // An event-like object to synchronize on if needed. + void* sync_event; + // Reserved bytes for future expansion. + int64_t reserved[3]; +}; + +#endif // ARROW_C_DEVICE_DATA_INTERFACE + +#ifndef ARROW_C_STREAM_INTERFACE +#define ARROW_C_STREAM_INTERFACE + +struct ArrowArrayStream { + // Callback to get the stream type + // (will be the same for all arrays in the stream). + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowArray must be released independently from the stream. + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowArrayStream*); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_STREAM_INTERFACE + +#ifndef ARROW_C_DEVICE_STREAM_INTERFACE +#define ARROW_C_DEVICE_STREAM_INTERFACE + +// Equivalent to ArrowArrayStream, but for ArrowDeviceArrays. +// +// This stream is intended to provide a stream of data on a single +// device, if a producer wants data to be produced on multiple devices +// then multiple streams should be provided. One per device. +struct ArrowDeviceArrayStream { + // The device that this stream produces data on. + ArrowDeviceType device_type; + + // Callback to get the stream schema + // (will be the same for all arrays in the stream). + // + // Return value 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + // The schema should be accessible via CPU memory. + int (*get_schema)(struct ArrowDeviceArrayStream* self, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowDeviceArray must be released independently from the stream. + int (*get_next)(struct ArrowDeviceArrayStream* self, struct ArrowDeviceArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowDeviceArrayStream* self); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowDeviceArrayStream* self); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DEVICE_STREAM_INTERFACE + +#ifndef ARROW_C_ASYNC_STREAM_INTERFACE +#define ARROW_C_ASYNC_STREAM_INTERFACE + +// EXPERIMENTAL: ArrowAsyncTask represents available data from a producer that was passed +// to an invocation of `on_next_task` on the ArrowAsyncDeviceStreamHandler. +// +// The reason for this Task approach instead of the Async interface returning +// the Array directly is to allow for more complex thread handling and reducing +// context switching and data transfers between CPU cores (e.g. from one L1/L2 +// cache to another) if desired. +// +// For example, the `on_next_task` callback can be called when data is ready, while +// the producer puts potential "decoding" logic in the `ArrowAsyncTask` object. This +// allows for the producer to manage the I/O on one thread which calls `on_next_task` +// and the consumer can determine when the decoding (producer logic in the `extract_data` +// callback of the task) occurs and on which thread, to avoid a CPU core transfer +// (data staying in the L2 cache). +struct ArrowAsyncTask { + // This callback should populate the ArrowDeviceArray associated with this task. + // The order of ArrowAsyncTasks provided by the producer enables a consumer to + // ensure the order of data to process. + // + // This function is expected to be synchronous, but should not perform any blocking + // I/O. Ideally it should be as cheap as possible so as to not tie up the consumer + // thread unnecessarily. + // + // Returns: 0 if successful, errno-compatible error otherwise. + // + // If a non-0 value is returned then it should be followed by a call to `on_error` + // on the appropriate ArrowAsyncDeviceStreamHandler. This is because it's highly + // likely that whatever is calling this function may be entirely disconnected from + // the current control flow. Indicating an error here with a non-zero return allows + // the current flow to be aware of the error occurring, while still allowing any + // logging or error handling to still be centralized in the `on_error` callback of + // the original Async handler. + // + // Rather than a release callback, any required cleanup should be performed as part + // of the invocation of `extract_data`. Ownership of the Array is passed to the consumer + // calling this, and so it must be released separately. + // + // It is only valid to call this method exactly once. + int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* out); + + // opaque task-specific data + void* private_data; +}; + +// EXPERIMENTAL: ArrowAsyncProducer represents a 1-to-1 relationship between an async +// producer and consumer. This object allows the consumer to perform backpressure and flow +// control on the asynchronous stream processing. This object must be owned by the +// producer who creates it, and thus is responsible for cleaning it up. +struct ArrowAsyncProducer { + // The device type that this stream produces data on. + ArrowDeviceType device_type; + + // A consumer must call this function to start receiving on_next_task calls. + // + // It *must* be valid to call this synchronously from within `on_next_task` or + // `on_schema`, but this function *must not* immediately call `on_next_task` so as + // to avoid recursion and reentrant callbacks. + // + // After cancel has been called, additional calls to this function must be NOPs, + // but allowed. While not cancelled, calling this function must register the + // given number of additional arrays/batches to be produced with the producer. + // The producer should only call `on_next_task` at most the registered number + // of arrays before propagating backpressure. + // + // Any error encountered by calling request must be propagated by calling the `on_error` + // callback of the ArrowAsyncDeviceStreamHandler. + // + // While not cancelled, any subsequent calls to `on_next_task`, `on_error` or + // `release` should be scheduled by the producer to be called later. + // + // It is invalid for a consumer to call this with a value of n <= 0, producers should + // error if given such a value. + void (*request)(struct ArrowAsyncProducer* self, int64_t n); + + // This cancel callback signals a producer that it must eventually stop making calls + // to on_next_task. It must be idempotent and thread-safe. After calling cancel once, + // subsequent calls must be NOPs. This must not call any consumer-side handlers other + // than `on_error`. + // + // It is not required that calling cancel affect the producer immediately, only that it + // must eventually stop calling on_next_task and subsequently call release on the + // async handler. As such, a consumer must be prepared to receive one or more calls to + // `on_next_task` even after calling cancel if there are still requested arrays pending. + // + // Successful cancellation should *not* result in the producer calling `on_error`, it + // should finish out any remaining tasks and eventually call `release`. + // + // Any error encountered during handling a call to cancel must be reported via the + // on_error callback on the async stream handler. + void (*cancel)(struct ArrowAsyncProducer* self); + + // Any additional metadata tied to a specific stream of data. This must either be NULL + // or a valid pointer to metadata which is encoded in the same way schema metadata + // would be. Non-null metadata must be valid for the lifetime of this object. As an + // example a producer could use this to provide the total number of rows and/or batches + // in the stream if known. + const char* additional_metadata; + + // producer-specific opaque data. + void* private_data; +}; + +// EXPERIMENTAL: Similar to ArrowDeviceArrayStream, except designed for an asynchronous +// style of interaction. While ArrowDeviceArrayStream provides producer +// defined callbacks, this is intended to be created by the consumer instead. +// The consumer passes this handler to the producer, which in turn uses the +// callbacks to inform the consumer of events in the stream. +struct ArrowAsyncDeviceStreamHandler { + // Handler for receiving a schema. The passed in stream_schema must be + // released or moved by the handler (producer is giving ownership of the schema to + // the handler, but not ownership of the top level object itself). + // + // With the exception of an error occurring (on_error), this must be the first + // callback function which is called by a producer and must only be called exactly + // once. As such, the producer should provide a valid ArrowAsyncProducer instance + // so the consumer can control the flow. See the documentation on ArrowAsyncProducer + // for how it works. The ArrowAsyncProducer is owned by the producer who calls this + // function and thus the producer is responsible for cleaning it up when calling + // the release callback of this handler. + // + // If there is any additional metadata tied to this stream, it will be provided as + // a non-null value for the `additional_metadata` field of the ArrowAsyncProducer + // which will be valid at least until the release callback is called. + // + // Return value: 0 if successful, `errno`-compatible error otherwise + // + // A producer that receives a non-zero return here should stop producing and eventually + // call release instead. + int (*on_schema)(struct ArrowAsyncDeviceStreamHandler* self, + struct ArrowSchema* stream_schema); + + // Handler for receiving data. This is called when data is available providing an + // ArrowAsyncTask struct to signify it. The producer indicates the end of the stream + // by passing NULL as the value for the task rather than a valid pointer to a task. + // The task object is only valid for the lifetime of this function call, if a consumer + // wants to utilize it after this function returns, it must copy or move the contents + // of it to a new ArrowAsyncTask object. + // + // The `request` callback of a provided ArrowAsyncProducer must be called in order + // to start receiving calls to this handler. + // + // The metadata argument can be null or can be used by a producer + // to pass arbitrary extra information to the consumer (such as total number + // of rows, context info, or otherwise). The data should be passed using the same + // encoding as the metadata within the ArrowSchema struct itself (defined in + // the spec at + // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata) + // + // If metadata is non-null then it only needs to exist for the lifetime of this call, + // a consumer who wants it to live after that must copy it to ensure lifetime. + // + // A producer *must not* call this concurrently from multiple different threads. + // + // A consumer must be prepared to receive one or more calls to this callback even + // after calling cancel on the corresponding ArrowAsyncProducer, as cancel does not + // guarantee it happens immediately. + // + // Return value: 0 if successful, `errno`-compatible error otherwise. + // + // If the consumer returns a non-zero return from this method, that indicates to the + // producer that it should stop propagating data as an error occurred. After receiving + // such a return, the only interaction with this object is for the producer to call + // the `release` callback. + int (*on_next_task)(struct ArrowAsyncDeviceStreamHandler* self, + struct ArrowAsyncTask* task, const char* metadata); + + // Handler for encountering an error. The producer should call release after + // this returns to clean up any resources. The `code` passed in can be any error + // code that a producer wants, but should be errno-compatible for consistency. + // + // If the message or metadata are non-null, they will only last as long as this + // function call. The consumer would need to perform a copy of the data if it is + // necessary for them to live past the lifetime of this call. + // + // Error metadata should be encoded as with metadata in ArrowSchema, defined in + // the spec at + // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata + // + // It is valid for this to be called by a producer with or without a preceding call + // to ArrowAsyncProducer.request. + // + // This callback must not call any methods of an ArrowAsyncProducer object. + void (*on_error)(struct ArrowAsyncDeviceStreamHandler* self, int code, + const char* message, const char* metadata); + + // Release callback to release any resources for the handler. Should always be + // called by a producer when it is done utilizing a handler. No callbacks should + // be called after this is called. + // + // It is valid for the release callback to be called by a producer with or without + // a preceding call to ArrowAsyncProducer.request. + // + // The release callback must not call any methods of an ArrowAsyncProducer object. + void (*release)(struct ArrowAsyncDeviceStreamHandler* self); + + // MUST be populated by the producer BEFORE calling any callbacks other than release. + // This provides the connection between a handler and its producer, and must exist until + // the release callback is called. + struct ArrowAsyncProducer* producer; + + // Opaque handler-specific data + void* private_data; +}; + +#endif // ARROW_C_ASYNC_STREAM_INTERFACE + /// \brief Simple ABI-stable scalar function implementation /// /// This object is not thread safe: callers must take care to serialize diff --git a/c/sedona-s2geography/src/kernels.rs b/c/sedona-s2geography/src/kernels.rs index 06e95b95cc..6ccda4089a 100644 --- a/c/sedona-s2geography/src/kernels.rs +++ b/c/sedona-s2geography/src/kernels.rs @@ -22,7 +22,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; use sedona_common::{sedona_internal_datafusion_err, sedona_internal_err}; use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel}; -use sedona_extension::{extension::SedonaCScalarKernel, scalar_kernel::ImportedScalarKernel}; +use sedona_extension::{extension_ffi::SedonaCScalarKernel, scalar_kernel::ImportedScalarKernel}; use sedona_functions::executor::WkbBytesExecutor; use sedona_schema::{ datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY}, From 3217d764ff8253e0ce2f633ed9ebf24fb2a057b8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 11:29:52 -0500 Subject: [PATCH 02/16] start --- .../src/device_stream_reader.rs | 137 ++++++++++++++++++ c/sedona-extension/src/extension_ffi.rs | 136 +++++++++++++++++ c/sedona-extension/src/lib.rs | 1 + c/sedona-extension/src/sedona_extension.h | 67 --------- 4 files changed, 274 insertions(+), 67 deletions(-) create mode 100644 c/sedona-extension/src/device_stream_reader.rs diff --git a/c/sedona-extension/src/device_stream_reader.rs b/c/sedona-extension/src/device_stream_reader.rs new file mode 100644 index 0000000000..afc8f24f09 --- /dev/null +++ b/c/sedona-extension/src/device_stream_reader.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::CStr; +use std::sync::Arc; + +use arrow_array::{ + ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}, + RecordBatch, RecordBatchReader, StructArray, +}; +use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; + +use crate::extension_ffi::{FFI_ArrowDeviceArray, FFI_ArrowDeviceArrayStream}; + +pub struct DeviceStreamReader { + inner: FFI_ArrowDeviceArrayStream, + schema: SchemaRef, + schema_struct_type: DataType, +} + +impl DeviceStreamReader { + pub fn try_new(mut inner: FFI_ArrowDeviceArrayStream) -> Result { + let get_schema = inner + .get_schema + .ok_or_else(|| ArrowError::CDataInterface("get_schema callback is null".to_string()))?; + + let mut ffi_schema = FFI_ArrowSchema::empty(); + let ret = unsafe { get_schema(&mut inner, &mut ffi_schema) }; + if ret != 0 { + let error_msg = Self::get_last_error_static(&mut inner); + return Err(ArrowError::CDataInterface(error_msg)); + } + + let schema = Schema::try_from(&ffi_schema)?; + let schema_struct_type = DataType::Struct(schema.fields().iter().cloned().collect()); + + Ok(Self { + inner, + schema: Arc::new(schema), + schema_struct_type, + }) + } + + fn get_last_error_static(inner: &mut FFI_ArrowDeviceArrayStream) -> String { + if let Some(get_last_error) = inner.get_last_error { + let err_ptr = unsafe { get_last_error(inner) }; + if !err_ptr.is_null() { + let c_str = unsafe { CStr::from_ptr(err_ptr) }; + return c_str.to_string_lossy().into_owned(); + } + } + "Unknown error".to_string() + } + + fn get_last_error(&mut self) -> String { + Self::get_last_error_static(&mut self.inner) + } +} + +impl Iterator for DeviceStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + let Some(get_next) = self.inner.get_next else { + return Some(Err(ArrowError::CDataInterface( + "get_next() is null".to_string(), + ))); + }; + + let mut device_array = FFI_ArrowDeviceArray { + array: FFI_ArrowArray::empty(), + device_id: 0, + device_type: 0, + sync_event: std::ptr::null_mut(), + }; + + let ret = unsafe { get_next(&mut self.inner, &mut device_array) }; + if ret != 0 { + return Some(Err(ArrowError::CDataInterface(self.get_last_error()))); + } + + // Check if the stream is exhausted (release is null means empty/end of stream) + if device_array.array.is_released() { + return None; + } + + // Convert device array to regular array (only supports CPU for now) + let ffi_array: FFI_ArrowArray = match device_array.try_into() { + Ok(arr) => arr, + Err(e) => return Some(Err(e)), + }; + + // Import the array data + let array_data = + match unsafe { from_ffi_and_data_type(ffi_array, self.schema_struct_type.clone()) } { + Ok(array_data) => array_data, + Err(e) => return Some(Err(e)), + }; + + // Create RecordBatch from StructArray + let struct_array: StructArray = array_data.into(); + Some(Ok(struct_array.into())) + } +} + +impl RecordBatchReader for DeviceStreamReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +struct ExportedDeviceReaderPrivate { + inner: Box +} + +fn exported_reader_get_schema( + self_: *mut FFI_ArrowDeviceArrayStream, + out: *mut FFI_ArrowSchema, + ) -> std::ffi::c_int { + + } + + diff --git a/c/sedona-extension/src/extension_ffi.rs b/c/sedona-extension/src/extension_ffi.rs index 7d957a57c2..2604793170 100644 --- a/c/sedona-extension/src/extension_ffi.rs +++ b/c/sedona-extension/src/extension_ffi.rs @@ -22,6 +22,7 @@ use std::{ }; use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; +use arrow_schema::ArrowError; /// Raw FFI representation of the SedonaCScalarKernel /// @@ -120,3 +121,138 @@ struct ArrowSchemaInternal { release: Option, private_data: *mut c_void, } + +/// Constant for the CPU device +pub const ARROW_DEVICE_CPU: i32 = 1; + +/// FFI representation of the ArrowDeviceArray from the Arrow C Device Data Interface +/// +/// Defined here because it is not yet defined in arrow-rs and is needed for the async +/// array stream. +/// +/// Note that implementing drop() isn't needed here because Rust will drop() the array. +/// +/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html +#[repr(C)] +pub struct FFI_ArrowDeviceArray { + pub array: FFI_ArrowArray, + pub device_id: i64, + pub device_type: i32, + pub sync_event: *mut c_void, +} + +impl TryFrom for FFI_ArrowArray { + type Error = ArrowError; + + fn try_from(value: FFI_ArrowDeviceArray) -> Result { + if value.device_id != 1 { + return Err(ArrowError::CDataInterface( + "Can't create FFI_ArrowArray from non-CPU FFI_ArrowDeviceArray".to_string(), + )); + } + + if !value.sync_event.is_null() { + return Err(ArrowError::CDataInterface( + "Can't create FFI_ArrowArray from FFI_ArrowDeviceArray with non-null sync event" + .to_string(), + )); + } + + Ok(value.array) + } +} + +impl From for FFI_ArrowDeviceArray { + fn from(value: FFI_ArrowArray) -> Self { + FFI_ArrowDeviceArray { + array: value, + device_id: -1, + device_type: 1, + sync_event: std::ptr::null_mut(), + } + } +} + +/// FFI representation of the ArrowDeviceArrayStream from the Arrow C Device Data Interface +/// +/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html +#[repr(C)] +pub struct FFI_ArrowDeviceArrayStream { + pub device_type: i32, + pub get_schema: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowDeviceArrayStream, + out: *mut FFI_ArrowSchema, + ) -> c_int, + >, + pub get_next: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowDeviceArrayStream, + out: *mut FFI_ArrowDeviceArray, + ) -> c_int, + >, + pub get_last_error: + Option *const c_char>, + pub release: Option, + pub private_data: *mut c_void, +} + +impl Drop for FFI_ArrowDeviceArrayStream { + fn drop(&mut self) { + if let Some(releaser) = self.release { + unsafe { releaser(self) }; + } + } +} + +/// FFI representation of the ArrowAsyncDeviceStreamHandler from the Arrow C Device Data Interface +/// +/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html +#[repr(C)] +pub struct FFI_ArrowAsyncDeviceStreamHandler { + pub on_schema: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + schema: *mut FFI_ArrowSchema, + ) -> c_int, + >, + pub on_next_task: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + task: *mut FFI_ArrowAsyncTask, + metadata: *const c_char, + ) -> c_int, + >, + pub on_error: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + code: c_int, + message: *const c_char, + metadata: *const c_char, + ), + >, + pub release: Option, + pub private_data: *mut c_void, +} + +impl Drop for FFI_ArrowAsyncDeviceStreamHandler { + fn drop(&mut self) { + if let Some(releaser) = self.release { + unsafe { releaser(self) }; + } + } +} + +/// FFI representation of the ArrowAsyncTask from the Arrow C Device Data Interface +/// +/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html +#[repr(C)] +pub struct FFI_ArrowAsyncTask { + pub extract_data: Option< + unsafe extern "C" fn( + self_: *mut FFI_ArrowAsyncTask, + out: *mut FFI_ArrowDeviceArray, + ) -> c_int, + >, + pub private_data: *mut c_void, +} diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index d5ded22efe..0ac89cdc2e 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -17,3 +17,4 @@ pub mod extension_ffi; pub mod scalar_kernel; +pub mod device_stream_reader; diff --git a/c/sedona-extension/src/sedona_extension.h b/c/sedona-extension/src/sedona_extension.h index 8274b9ee85..4e3c587094 100644 --- a/c/sedona-extension/src/sedona_extension.h +++ b/c/sedona-extension/src/sedona_extension.h @@ -121,32 +121,6 @@ typedef int32_t ArrowDeviceType; // CPU device, same as using ArrowArray directly #define ARROW_DEVICE_CPU 1 -// CUDA GPU Device -#define ARROW_DEVICE_CUDA 2 -// Pinned CUDA CPU memory by cudaMallocHost -#define ARROW_DEVICE_CUDA_HOST 3 -// OpenCL Device -#define ARROW_DEVICE_OPENCL 4 -// Vulkan buffer for next-gen graphics -#define ARROW_DEVICE_VULKAN 7 -// Metal for Apple GPU -#define ARROW_DEVICE_METAL 8 -// Verilog simulator buffer -#define ARROW_DEVICE_VPI 9 -// ROCm GPUs for AMD GPUs -#define ARROW_DEVICE_ROCM 10 -// Pinned ROCm CPU memory allocated by hipMallocHost -#define ARROW_DEVICE_ROCM_HOST 11 -// Reserved for extension -#define ARROW_DEVICE_EXT_DEV 12 -// CUDA managed/unified memory allocated by cudaMallocManaged -#define ARROW_DEVICE_CUDA_MANAGED 13 -// unified shared memory allocated on a oneAPI non-partitioned device. -#define ARROW_DEVICE_ONEAPI 14 -// GPU support for next-gen WebGPU standard -#define ARROW_DEVICE_WEBGPU 15 -// Qualcomm Hexagon DSP -#define ARROW_DEVICE_HEXAGON 16 struct ArrowDeviceArray { // the Allocated Array @@ -166,47 +140,6 @@ struct ArrowDeviceArray { #endif // ARROW_C_DEVICE_DATA_INTERFACE -#ifndef ARROW_C_STREAM_INTERFACE -#define ARROW_C_STREAM_INTERFACE - -struct ArrowArrayStream { - // Callback to get the stream type - // (will be the same for all arrays in the stream). - // - // Return value: 0 if successful, an `errno`-compatible error code otherwise. - // - // If successful, the ArrowSchema must be released independently from the stream. - int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); - - // Callback to get the next array - // (if no error and the array is released, the stream has ended) - // - // Return value: 0 if successful, an `errno`-compatible error code otherwise. - // - // If successful, the ArrowArray must be released independently from the stream. - int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); - - // Callback to get optional detailed error information. - // This must only be called if the last stream operation failed - // with a non-0 return code. - // - // Return value: pointer to a null-terminated character array describing - // the last error, or NULL if no description is available. - // - // The returned pointer is only valid until the next operation on this stream - // (including release). - const char* (*get_last_error)(struct ArrowArrayStream*); - - // Release callback: release the stream's own resources. - // Note that arrays returned by `get_next` must be individually released. - void (*release)(struct ArrowArrayStream*); - - // Opaque producer-specific data - void* private_data; -}; - -#endif // ARROW_C_STREAM_INTERFACE - #ifndef ARROW_C_DEVICE_STREAM_INTERFACE #define ARROW_C_DEVICE_STREAM_INTERFACE From 95dcfc96da003e7c9036e9ef49650f2e21b96f43 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 11:49:32 -0500 Subject: [PATCH 03/16] roundtrip test --- .../src/device_stream_reader.rs | 254 +++++++++++++++++- c/sedona-extension/src/extension_ffi.rs | 2 +- 2 files changed, 250 insertions(+), 6 deletions(-) diff --git a/c/sedona-extension/src/device_stream_reader.rs b/c/sedona-extension/src/device_stream_reader.rs index afc8f24f09..538a68346e 100644 --- a/c/sedona-extension/src/device_stream_reader.rs +++ b/c/sedona-extension/src/device_stream_reader.rs @@ -20,11 +20,12 @@ use std::sync::Arc; use arrow_array::{ ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}, + ffi_stream::FFI_ArrowArrayStream, RecordBatch, RecordBatchReader, StructArray, }; use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; -use crate::extension_ffi::{FFI_ArrowDeviceArray, FFI_ArrowDeviceArrayStream}; +use crate::extension_ffi::{FFI_ArrowDeviceArray, FFI_ArrowDeviceArrayStream, ARROW_DEVICE_CPU}; pub struct DeviceStreamReader { inner: FFI_ArrowDeviceArrayStream, @@ -123,15 +124,258 @@ impl RecordBatchReader for DeviceStreamReader { } } +impl From for FFI_ArrowDeviceArrayStream { + fn from(value: FFI_ArrowArrayStream) -> Self { + let private_data = Box::new(ExportedDeviceReaderPrivate { inner: value }); + + FFI_ArrowDeviceArrayStream { + device_type: ARROW_DEVICE_CPU, + get_schema: Some(exported_reader_get_schema), + get_next: Some(exported_reader_get_next), + get_last_error: Some(exported_reader_get_last_error), + release: Some(exported_reader_release), + private_data: Box::into_raw(private_data) as *mut std::ffi::c_void, + } + } +} + struct ExportedDeviceReaderPrivate { - inner: Box + inner: FFI_ArrowArrayStream, } -fn exported_reader_get_schema( - self_: *mut FFI_ArrowDeviceArrayStream, +unsafe extern "C" fn exported_reader_get_schema( + self_: *mut FFI_ArrowDeviceArrayStream, + out: *mut FFI_ArrowSchema, +) -> std::ffi::c_int { + if self_.is_null() || out.is_null() { + return 1; + } + + let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); + let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; + + if let Some(get_schema) = (*inner).get_schema { + get_schema(inner, out) + } else { + 1 + } +} + +unsafe extern "C" fn exported_reader_get_next( + self_: *mut FFI_ArrowDeviceArrayStream, + out: *mut FFI_ArrowDeviceArray, +) -> std::ffi::c_int { + if self_.is_null() || out.is_null() { + return 1; + } + + let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); + let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; + + let Some(get_next) = (*inner).get_next else { + return 1; + }; + + let mut ffi_array = FFI_ArrowArray::empty(); + let ret = get_next(inner, &mut ffi_array); + if ret != 0 { + return ret; + } + + // Wrap the array in a device array + let device_array = FFI_ArrowDeviceArray { + array: ffi_array, + device_id: -1, + device_type: ARROW_DEVICE_CPU, + sync_event: std::ptr::null_mut(), + }; + std::ptr::write(out, device_array); + 0 +} + +unsafe extern "C" fn exported_reader_get_last_error( + self_: *mut FFI_ArrowDeviceArrayStream, +) -> *const std::ffi::c_char { + if self_.is_null() { + return std::ptr::null(); + } + + let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); + let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; + + if let Some(get_last_error) = (*inner).get_last_error { + get_last_error(inner) + } else { + std::ptr::null() + } +} + +unsafe extern "C" fn exported_reader_release(self_: *mut FFI_ArrowDeviceArrayStream) { + if self_.is_null() { + return; + } + + let stream = &mut *self_; + if stream.private_data.is_null() { + return; + } + + // Drop the private data (which will drop the inner FFI_ArrowArrayStream) + let _ = Box::from_raw(stream.private_data as *mut ExportedDeviceReaderPrivate); + stream.private_data = std::ptr::null_mut(); + stream.release = None; +} + +/// Internal representation of FFI_ArrowArrayStream with public fields. +/// Duplicated here because arrow-rs doesn't expose the fields publicly. +#[repr(C)] +struct ArrowArrayStreamInternal { + get_schema: Option< + unsafe extern "C" fn( + self_: *mut ArrowArrayStreamInternal, out: *mut FFI_ArrowSchema, - ) -> std::ffi::c_int { + ) -> std::ffi::c_int, + >, + get_next: Option< + unsafe extern "C" fn( + self_: *mut ArrowArrayStreamInternal, + out: *mut FFI_ArrowArray, + ) -> std::ffi::c_int, + >, + get_last_error: Option< + unsafe extern "C" fn(self_: *mut ArrowArrayStreamInternal) -> *const std::ffi::c_char, + >, + release: Option, + private_data: *mut std::ffi::c_void, +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Int32Array, RecordBatchIterator, StringArray}; + use arrow_schema::Field; + fn make_test_batches() -> (SchemaRef, Vec) { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])), + ], + ) + .unwrap(); + + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![4, 5])), + Arc::new(StringArray::from(vec![Some("d"), Some("e")])), + ], + ) + .unwrap(); + + (schema, vec![batch1, batch2]) + } + + #[test] + fn test_roundtrip_two_batches() { + let (schema, batches) = make_test_batches(); + let original_batches = batches.clone(); + + // Create a RecordBatchReader + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + + // Export to FFI_ArrowArrayStream, then to FFI_ArrowDeviceArrayStream + let array_stream = FFI_ArrowArrayStream::new(Box::new(reader)); + let device_stream: FFI_ArrowDeviceArrayStream = array_stream.into(); + + // Import back via DeviceStreamReader + let imported_reader = DeviceStreamReader::try_new(device_stream).unwrap(); + + // Verify schema + assert_eq!(imported_reader.schema(), schema); + + // Collect and verify batches + let imported_batches: Vec = + imported_reader.into_iter().map(|r| r.unwrap()).collect(); + + assert_eq!(imported_batches.len(), original_batches.len()); + for (imported, original) in imported_batches.iter().zip(original_batches.iter()) { + assert_eq!(imported, original); + } + } + + /// A RecordBatchReader that yields one batch then errors + struct ErroringReader { + schema: SchemaRef, + yielded_first: bool, + first_batch: Option, + } + + impl ErroringReader { + fn new(schema: SchemaRef, first_batch: RecordBatch) -> Self { + Self { + schema, + yielded_first: false, + first_batch: Some(first_batch), + } + } + } + + impl Iterator for ErroringReader { + type Item = Result; + + fn next(&mut self) -> Option { + if !self.yielded_first { + self.yielded_first = true; + Some(Ok(self.first_batch.take().unwrap())) + } else { + Some(Err(ArrowError::ComputeError( + "intentional test error".to_string(), + ))) + } + } + } + + impl RecordBatchReader for ErroringReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() } + } + + #[test] + fn test_roundtrip_with_error() { + let (schema, batches) = make_test_batches(); + let first_batch = batches[0].clone(); + + // Create an erroring reader + let reader = ErroringReader::new(schema.clone(), first_batch.clone()); + // Export to FFI_ArrowArrayStream, then to FFI_ArrowDeviceArrayStream + let array_stream = FFI_ArrowArrayStream::new(Box::new(reader)); + let device_stream: FFI_ArrowDeviceArrayStream = array_stream.into(); + // Import back via DeviceStreamReader + let mut imported_reader = DeviceStreamReader::try_new(device_stream).unwrap(); + + // First batch should succeed + let result1 = imported_reader.next(); + assert!(result1.is_some()); + let batch1 = result1.unwrap(); + assert!(batch1.is_ok()); + assert_eq!(batch1.unwrap(), first_batch); + + // Second call should return an error + let result2 = imported_reader.next(); + assert!(result2.is_some()); + let batch2 = result2.unwrap(); + assert!(batch2.is_err()); + let err = batch2.unwrap_err(); + assert!(err.to_string().contains("intentional test error")); + } +} diff --git a/c/sedona-extension/src/extension_ffi.rs b/c/sedona-extension/src/extension_ffi.rs index 2604793170..63dd8a9e51 100644 --- a/c/sedona-extension/src/extension_ffi.rs +++ b/c/sedona-extension/src/extension_ffi.rs @@ -145,7 +145,7 @@ impl TryFrom for FFI_ArrowArray { type Error = ArrowError; fn try_from(value: FFI_ArrowDeviceArray) -> Result { - if value.device_id != 1 { + if value.device_type != ARROW_DEVICE_CPU { return Err(ArrowError::CDataInterface( "Can't create FFI_ArrowArray from non-CPU FFI_ArrowDeviceArray".to_string(), )); From 6150559f5b1aaefeadbaaceda37c80b75c6f0d92 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 12:22:04 -0500 Subject: [PATCH 04/16] sendable stream one direction --- Cargo.lock | 2 + c/sedona-extension/Cargo.toml | 6 + c/sedona-extension/src/lib.rs | 5 +- .../src/sendable_record_batch_stream.rs | 203 ++++++++++++++++++ 4 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 c/sedona-extension/src/sendable_record_batch_stream.rs diff --git a/Cargo.lock b/Cargo.lock index 95b4343453..1cd934ecd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5782,7 +5782,9 @@ dependencies = [ "arrow-array", "arrow-schema", "datafusion-common", + "datafusion-execution", "datafusion-expr", + "futures", "libc", "sedona-common", "sedona-expr", diff --git a/c/sedona-extension/Cargo.toml b/c/sedona-extension/Cargo.toml index e8f8c17c56..3d27389288 100644 --- a/c/sedona-extension/Cargo.toml +++ b/c/sedona-extension/Cargo.toml @@ -27,11 +27,17 @@ readme.workspace = true edition.workspace = true rust-version.workspace = true +[features] +default = [] +async = ["dep:datafusion-execution", "dep:futures"] + [dependencies] arrow-array = { workspace = true, features = ["ffi"]} arrow-schema = { workspace = true, features = ["ffi"]} datafusion-common = { workspace = true } +datafusion-execution = { workspace = true, optional = true } datafusion-expr = { workspace = true } +futures = { workspace = true, optional = true } libc = "0.2.178" sedona-common = { workspace = true } sedona-expr = { workspace = true } diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index 0ac89cdc2e..37b86c847b 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +pub mod device_stream_reader; pub mod extension_ffi; pub mod scalar_kernel; -pub mod device_stream_reader; + +#[cfg(feature = "async")] +pub mod sendable_record_batch_stream; diff --git a/c/sedona-extension/src/sendable_record_batch_stream.rs b/c/sedona-extension/src/sendable_record_batch_stream.rs new file mode 100644 index 0000000000..fd604b0cf3 --- /dev/null +++ b/c/sedona-extension/src/sendable_record_batch_stream.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! FFI wrapper for [`SendableRecordBatchStream`] using the Arrow C Device Data Interface. +//! +//! This module provides interoperability between DataFusion's async record batch streams +//! and C code via the [`FFI_ArrowAsyncDeviceStreamHandler`] interface. + +use std::ffi::{c_int, c_void, CString}; +use std::ptr::null_mut; + +use arrow_array::ffi::FFI_ArrowSchema; +use arrow_array::RecordBatch; +use futures::StreamExt; + +use datafusion_execution::SendableRecordBatchStream; + +use crate::extension_ffi::{ + FFI_ArrowAsyncDeviceStreamHandler, FFI_ArrowAsyncTask, FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, +}; + +/// Drives a [`SendableRecordBatchStream`] and pushes results to an [`FFI_ArrowAsyncDeviceStreamHandler`]. +/// +/// This function consumes the stream and calls the handler's callbacks: +/// - `on_schema` is called first with the stream's schema +/// - `on_next_task` is called for each record batch +/// - `on_error` is called if an error occurs +/// +/// # Safety +/// +/// The `handler` pointer must be valid and point to a properly initialized +/// [`FFI_ArrowAsyncDeviceStreamHandler`]. +/// +/// # Example +/// +/// ```ignore +/// use sedona_extension::sendable_record_batch_stream::drive_stream_to_handler; +/// +/// // Assuming you have a SendableRecordBatchStream and a handler +/// drive_stream_to_handler(stream, handler).await; +/// ``` +pub async fn drive_stream_to_handler( + mut stream: SendableRecordBatchStream, + handler: *mut FFI_ArrowAsyncDeviceStreamHandler, +) { + if handler.is_null() { + return; + } + + let handler_ref = unsafe { &mut *handler }; + + // First, send the schema + let schema = stream.schema(); + let ffi_schema = match FFI_ArrowSchema::try_from(schema.as_ref()) { + Ok(s) => s, + Err(e) => { + call_on_error(handler_ref, 1, &e.to_string()); + return; + } + }; + + if let Some(on_schema) = handler_ref.on_schema { + // We need to box the schema so it lives long enough + let mut boxed_schema = Box::new(ffi_schema); + let ret = unsafe { on_schema(handler, boxed_schema.as_mut()) }; + if ret != 0 { + call_on_error(handler_ref, ret, "on_schema callback failed"); + return; + } + // The handler now owns the schema, so we forget it + std::mem::forget(boxed_schema); + } + + // Stream batches + while let Some(result) = stream.next().await { + match result { + Ok(batch) => { + if let Err(e) = send_batch_to_handler(handler_ref, handler, batch) { + call_on_error(handler_ref, 1, &e); + return; + } + } + Err(e) => { + call_on_error(handler_ref, 1, &e.to_string()); + return; + } + } + } +} + +fn send_batch_to_handler( + handler_ref: &mut FFI_ArrowAsyncDeviceStreamHandler, + handler: *mut FFI_ArrowAsyncDeviceStreamHandler, + batch: RecordBatch, +) -> Result<(), String> { + let Some(on_next_task) = handler_ref.on_next_task else { + return Err("on_next_task callback is null".to_string()); + }; + + // Create a task that wraps the batch + let task = create_async_task(batch)?; + let mut boxed_task = Box::new(task); + + let ret = unsafe { on_next_task(handler, boxed_task.as_mut(), std::ptr::null()) }; + if ret != 0 { + return Err("on_next_task callback failed".to_string()); + } + + // The handler now owns the task + std::mem::forget(boxed_task); + Ok(()) +} + +fn call_on_error(handler: &mut FFI_ArrowAsyncDeviceStreamHandler, code: c_int, message: &str) { + if let Some(on_error) = handler.on_error { + let c_message = CString::new(message).unwrap_or_else(|_| CString::new("error").unwrap()); + unsafe { + on_error( + handler, + code, + c_message.as_ptr(), + std::ptr::null(), // no metadata + ); + } + } +} + +/// Private data for an async task wrapping a RecordBatch +struct AsyncTaskPrivate { + batch: Option, +} + +fn create_async_task(batch: RecordBatch) -> Result { + let private_data = Box::new(AsyncTaskPrivate { batch: Some(batch) }); + + Ok(FFI_ArrowAsyncTask { + extract_data: Some(async_task_extract_data), + private_data: Box::into_raw(private_data) as *mut c_void, + }) +} + +unsafe extern "C" fn async_task_extract_data( + self_: *mut FFI_ArrowAsyncTask, + out: *mut FFI_ArrowDeviceArray, +) -> c_int { + if self_.is_null() || out.is_null() { + return 1; + } + + let task = &mut *self_; + if task.private_data.is_null() { + return 1; + } + + let private = &mut *(task.private_data as *mut AsyncTaskPrivate); + + // Take the batch (can only extract once) + let Some(batch) = private.batch.take() else { + return 1; + }; + + // Convert to struct array and then to FFI + let struct_array: arrow_array::StructArray = batch.into(); + let (ffi_array, _ffi_schema) = match arrow_array::ffi::to_ffi(&struct_array.into()) { + Ok(result) => result, + Err(_) => return 1, + }; + + // Create device array (CPU) + let device_array = FFI_ArrowDeviceArray { + array: ffi_array, + device_id: -1, + device_type: ARROW_DEVICE_CPU, + sync_event: null_mut(), + }; + + std::ptr::write(out, device_array); + 0 +} + +impl Drop for FFI_ArrowAsyncTask { + fn drop(&mut self) { + if !self.private_data.is_null() { + // Drop the private data + let _ = unsafe { Box::from_raw(self.private_data as *mut AsyncTaskPrivate) }; + self.private_data = null_mut(); + } + } +} From b6c32da703297cb13213f50ace93fa4fcdfe17f0 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 12:46:15 -0500 Subject: [PATCH 05/16] first pass at "exporting" a stream --- .../src/device_stream_reader.rs | 381 ------------------ c/sedona-extension/src/extension_ffi.rs | 41 ++ c/sedona-extension/src/lib.rs | 1 - .../src/sendable_record_batch_stream.rs | 281 +++++++++++-- 4 files changed, 293 insertions(+), 411 deletions(-) delete mode 100644 c/sedona-extension/src/device_stream_reader.rs diff --git a/c/sedona-extension/src/device_stream_reader.rs b/c/sedona-extension/src/device_stream_reader.rs deleted file mode 100644 index 538a68346e..0000000000 --- a/c/sedona-extension/src/device_stream_reader.rs +++ /dev/null @@ -1,381 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::ffi::CStr; -use std::sync::Arc; - -use arrow_array::{ - ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}, - ffi_stream::FFI_ArrowArrayStream, - RecordBatch, RecordBatchReader, StructArray, -}; -use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; - -use crate::extension_ffi::{FFI_ArrowDeviceArray, FFI_ArrowDeviceArrayStream, ARROW_DEVICE_CPU}; - -pub struct DeviceStreamReader { - inner: FFI_ArrowDeviceArrayStream, - schema: SchemaRef, - schema_struct_type: DataType, -} - -impl DeviceStreamReader { - pub fn try_new(mut inner: FFI_ArrowDeviceArrayStream) -> Result { - let get_schema = inner - .get_schema - .ok_or_else(|| ArrowError::CDataInterface("get_schema callback is null".to_string()))?; - - let mut ffi_schema = FFI_ArrowSchema::empty(); - let ret = unsafe { get_schema(&mut inner, &mut ffi_schema) }; - if ret != 0 { - let error_msg = Self::get_last_error_static(&mut inner); - return Err(ArrowError::CDataInterface(error_msg)); - } - - let schema = Schema::try_from(&ffi_schema)?; - let schema_struct_type = DataType::Struct(schema.fields().iter().cloned().collect()); - - Ok(Self { - inner, - schema: Arc::new(schema), - schema_struct_type, - }) - } - - fn get_last_error_static(inner: &mut FFI_ArrowDeviceArrayStream) -> String { - if let Some(get_last_error) = inner.get_last_error { - let err_ptr = unsafe { get_last_error(inner) }; - if !err_ptr.is_null() { - let c_str = unsafe { CStr::from_ptr(err_ptr) }; - return c_str.to_string_lossy().into_owned(); - } - } - "Unknown error".to_string() - } - - fn get_last_error(&mut self) -> String { - Self::get_last_error_static(&mut self.inner) - } -} - -impl Iterator for DeviceStreamReader { - type Item = Result; - - fn next(&mut self) -> Option { - let Some(get_next) = self.inner.get_next else { - return Some(Err(ArrowError::CDataInterface( - "get_next() is null".to_string(), - ))); - }; - - let mut device_array = FFI_ArrowDeviceArray { - array: FFI_ArrowArray::empty(), - device_id: 0, - device_type: 0, - sync_event: std::ptr::null_mut(), - }; - - let ret = unsafe { get_next(&mut self.inner, &mut device_array) }; - if ret != 0 { - return Some(Err(ArrowError::CDataInterface(self.get_last_error()))); - } - - // Check if the stream is exhausted (release is null means empty/end of stream) - if device_array.array.is_released() { - return None; - } - - // Convert device array to regular array (only supports CPU for now) - let ffi_array: FFI_ArrowArray = match device_array.try_into() { - Ok(arr) => arr, - Err(e) => return Some(Err(e)), - }; - - // Import the array data - let array_data = - match unsafe { from_ffi_and_data_type(ffi_array, self.schema_struct_type.clone()) } { - Ok(array_data) => array_data, - Err(e) => return Some(Err(e)), - }; - - // Create RecordBatch from StructArray - let struct_array: StructArray = array_data.into(); - Some(Ok(struct_array.into())) - } -} - -impl RecordBatchReader for DeviceStreamReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl From for FFI_ArrowDeviceArrayStream { - fn from(value: FFI_ArrowArrayStream) -> Self { - let private_data = Box::new(ExportedDeviceReaderPrivate { inner: value }); - - FFI_ArrowDeviceArrayStream { - device_type: ARROW_DEVICE_CPU, - get_schema: Some(exported_reader_get_schema), - get_next: Some(exported_reader_get_next), - get_last_error: Some(exported_reader_get_last_error), - release: Some(exported_reader_release), - private_data: Box::into_raw(private_data) as *mut std::ffi::c_void, - } - } -} - -struct ExportedDeviceReaderPrivate { - inner: FFI_ArrowArrayStream, -} - -unsafe extern "C" fn exported_reader_get_schema( - self_: *mut FFI_ArrowDeviceArrayStream, - out: *mut FFI_ArrowSchema, -) -> std::ffi::c_int { - if self_.is_null() || out.is_null() { - return 1; - } - - let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); - let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; - - if let Some(get_schema) = (*inner).get_schema { - get_schema(inner, out) - } else { - 1 - } -} - -unsafe extern "C" fn exported_reader_get_next( - self_: *mut FFI_ArrowDeviceArrayStream, - out: *mut FFI_ArrowDeviceArray, -) -> std::ffi::c_int { - if self_.is_null() || out.is_null() { - return 1; - } - - let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); - let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; - - let Some(get_next) = (*inner).get_next else { - return 1; - }; - - let mut ffi_array = FFI_ArrowArray::empty(); - let ret = get_next(inner, &mut ffi_array); - if ret != 0 { - return ret; - } - - // Wrap the array in a device array - let device_array = FFI_ArrowDeviceArray { - array: ffi_array, - device_id: -1, - device_type: ARROW_DEVICE_CPU, - sync_event: std::ptr::null_mut(), - }; - std::ptr::write(out, device_array); - 0 -} - -unsafe extern "C" fn exported_reader_get_last_error( - self_: *mut FFI_ArrowDeviceArrayStream, -) -> *const std::ffi::c_char { - if self_.is_null() { - return std::ptr::null(); - } - - let private = &mut *((*self_).private_data as *mut ExportedDeviceReaderPrivate); - let inner = &mut private.inner as *mut FFI_ArrowArrayStream as *mut ArrowArrayStreamInternal; - - if let Some(get_last_error) = (*inner).get_last_error { - get_last_error(inner) - } else { - std::ptr::null() - } -} - -unsafe extern "C" fn exported_reader_release(self_: *mut FFI_ArrowDeviceArrayStream) { - if self_.is_null() { - return; - } - - let stream = &mut *self_; - if stream.private_data.is_null() { - return; - } - - // Drop the private data (which will drop the inner FFI_ArrowArrayStream) - let _ = Box::from_raw(stream.private_data as *mut ExportedDeviceReaderPrivate); - stream.private_data = std::ptr::null_mut(); - stream.release = None; -} - -/// Internal representation of FFI_ArrowArrayStream with public fields. -/// Duplicated here because arrow-rs doesn't expose the fields publicly. -#[repr(C)] -struct ArrowArrayStreamInternal { - get_schema: Option< - unsafe extern "C" fn( - self_: *mut ArrowArrayStreamInternal, - out: *mut FFI_ArrowSchema, - ) -> std::ffi::c_int, - >, - get_next: Option< - unsafe extern "C" fn( - self_: *mut ArrowArrayStreamInternal, - out: *mut FFI_ArrowArray, - ) -> std::ffi::c_int, - >, - get_last_error: Option< - unsafe extern "C" fn(self_: *mut ArrowArrayStreamInternal) -> *const std::ffi::c_char, - >, - release: Option, - private_data: *mut std::ffi::c_void, -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::{Int32Array, RecordBatchIterator, StringArray}; - use arrow_schema::Field; - - fn make_test_batches() -> (SchemaRef, Vec) { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])), - ], - ) - .unwrap(); - - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![4, 5])), - Arc::new(StringArray::from(vec![Some("d"), Some("e")])), - ], - ) - .unwrap(); - - (schema, vec![batch1, batch2]) - } - - #[test] - fn test_roundtrip_two_batches() { - let (schema, batches) = make_test_batches(); - let original_batches = batches.clone(); - - // Create a RecordBatchReader - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - - // Export to FFI_ArrowArrayStream, then to FFI_ArrowDeviceArrayStream - let array_stream = FFI_ArrowArrayStream::new(Box::new(reader)); - let device_stream: FFI_ArrowDeviceArrayStream = array_stream.into(); - - // Import back via DeviceStreamReader - let imported_reader = DeviceStreamReader::try_new(device_stream).unwrap(); - - // Verify schema - assert_eq!(imported_reader.schema(), schema); - - // Collect and verify batches - let imported_batches: Vec = - imported_reader.into_iter().map(|r| r.unwrap()).collect(); - - assert_eq!(imported_batches.len(), original_batches.len()); - for (imported, original) in imported_batches.iter().zip(original_batches.iter()) { - assert_eq!(imported, original); - } - } - - /// A RecordBatchReader that yields one batch then errors - struct ErroringReader { - schema: SchemaRef, - yielded_first: bool, - first_batch: Option, - } - - impl ErroringReader { - fn new(schema: SchemaRef, first_batch: RecordBatch) -> Self { - Self { - schema, - yielded_first: false, - first_batch: Some(first_batch), - } - } - } - - impl Iterator for ErroringReader { - type Item = Result; - - fn next(&mut self) -> Option { - if !self.yielded_first { - self.yielded_first = true; - Some(Ok(self.first_batch.take().unwrap())) - } else { - Some(Err(ArrowError::ComputeError( - "intentional test error".to_string(), - ))) - } - } - } - - impl RecordBatchReader for ErroringReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - } - - #[test] - fn test_roundtrip_with_error() { - let (schema, batches) = make_test_batches(); - let first_batch = batches[0].clone(); - - // Create an erroring reader - let reader = ErroringReader::new(schema.clone(), first_batch.clone()); - - // Export to FFI_ArrowArrayStream, then to FFI_ArrowDeviceArrayStream - let array_stream = FFI_ArrowArrayStream::new(Box::new(reader)); - let device_stream: FFI_ArrowDeviceArrayStream = array_stream.into(); - - // Import back via DeviceStreamReader - let mut imported_reader = DeviceStreamReader::try_new(device_stream).unwrap(); - - // First batch should succeed - let result1 = imported_reader.next(); - assert!(result1.is_some()); - let batch1 = result1.unwrap(); - assert!(batch1.is_ok()); - assert_eq!(batch1.unwrap(), first_batch); - - // Second call should return an error - let result2 = imported_reader.next(); - assert!(result2.is_some()); - let batch2 = result2.unwrap(); - assert!(batch2.is_err()); - let err = batch2.unwrap_err(); - assert!(err.to_string().contains("intentional test error")); - } -} diff --git a/c/sedona-extension/src/extension_ffi.rs b/c/sedona-extension/src/extension_ffi.rs index 63dd8a9e51..c381ac34f8 100644 --- a/c/sedona-extension/src/extension_ffi.rs +++ b/c/sedona-extension/src/extension_ffi.rs @@ -205,6 +205,42 @@ impl Drop for FFI_ArrowDeviceArrayStream { } } +/// FFI representation of the ArrowAsyncProducer from the Arrow C Device Data Interface. +/// +/// This producer-managed object allows consumers to control flow via back-pressure +/// (`request`) and cancellation (`cancel`). +/// +/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html +#[repr(C)] +pub struct FFI_ArrowAsyncProducer { + /// The device type that this producer will provide data on. + pub device_type: i32, + + /// Request `n` additional arrays/batches. The producer should only call + /// `on_next_task` up to the total sum of requested batches. + pub request: Option, + + /// Signal that the producer should stop producing. Idempotent and thread-safe. + pub cancel: Option, + + /// Release callback for cleanup. + pub release: Option, + + /// Optional metadata string (same encoding as ArrowSchema.metadata). + pub additional_metadata: *const c_char, + + /// Opaque producer-specific data. + pub private_data: *mut c_void, +} + +impl Drop for FFI_ArrowAsyncProducer { + fn drop(&mut self) { + if let Some(releaser) = self.release { + unsafe { releaser(self) }; + } + } +} + /// FFI representation of the ArrowAsyncDeviceStreamHandler from the Arrow C Device Data Interface /// /// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html @@ -232,6 +268,11 @@ pub struct FFI_ArrowAsyncDeviceStreamHandler { ), >, pub release: Option, + + /// The producer object that the consumer uses to request data or cancel. + /// Must be populated by the producer before calling `on_schema`. + pub producer: *mut FFI_ArrowAsyncProducer, + pub private_data: *mut c_void, } diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index 37b86c847b..4b103fe622 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub mod device_stream_reader; pub mod extension_ffi; pub mod scalar_kernel; diff --git a/c/sedona-extension/src/sendable_record_batch_stream.rs b/c/sedona-extension/src/sendable_record_batch_stream.rs index fd604b0cf3..19dc9b1f62 100644 --- a/c/sedona-extension/src/sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/sendable_record_batch_stream.rs @@ -19,26 +19,102 @@ //! //! This module provides interoperability between DataFusion's async record batch streams //! and C code via the [`FFI_ArrowAsyncDeviceStreamHandler`] interface. +//! +//! The implementation follows the Arrow Async Device Stream specification, including: +//! - Back-pressure via `request()` callback +//! - Consumer-initiated cancellation via `cancel()` callback +//! - Proper end-of-stream signaling use std::ffi::{c_int, c_void, CString}; use std::ptr::null_mut; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::task::Poll; -use arrow_array::ffi::FFI_ArrowSchema; +use arrow_array::ffi::{to_ffi, FFI_ArrowSchema}; use arrow_array::RecordBatch; use futures::StreamExt; -use datafusion_execution::SendableRecordBatchStream; +pub use datafusion_execution::SendableRecordBatchStream; use crate::extension_ffi::{ - FFI_ArrowAsyncDeviceStreamHandler, FFI_ArrowAsyncTask, FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, + FFI_ArrowAsyncDeviceStreamHandler, FFI_ArrowAsyncProducer, FFI_ArrowAsyncTask, + FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, }; +/// Yields control back to the async executor once, then resumes. +async fn yield_once() { + let mut yielded = false; + futures::future::poll_fn(|cx| { + if yielded { + Poll::Ready(()) + } else { + yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + .await +} + +/// Shared state between the async driver and the FFI producer callbacks. +struct ProducerState { + /// Number of batches requested by consumer (back-pressure). + requested: AtomicU64, + /// Set to true when consumer calls cancel(). + cancelled: AtomicBool, +} + +impl ProducerState { + fn new() -> Self { + Self { + requested: AtomicU64::new(0), + cancelled: AtomicBool::new(false), + } + } + + fn request(&self, n: u64) { + self.requested.fetch_add(n, Ordering::SeqCst); + } + + fn cancel(&self) { + self.cancelled.store(true, Ordering::SeqCst); + } + + fn is_cancelled(&self) -> bool { + self.cancelled.load(Ordering::SeqCst) + } + + /// Try to consume one request slot. Returns true if successful. + fn try_consume_request(&self) -> bool { + loop { + let current = self.requested.load(Ordering::SeqCst); + if current == 0 { + return false; + } + if self + .requested + .compare_exchange(current, current - 1, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + return true; + } + } + } + + fn has_requests(&self) -> bool { + self.requested.load(Ordering::SeqCst) > 0 + } +} + /// Drives a [`SendableRecordBatchStream`] and pushes results to an [`FFI_ArrowAsyncDeviceStreamHandler`]. /// -/// This function consumes the stream and calls the handler's callbacks: -/// - `on_schema` is called first with the stream's schema -/// - `on_next_task` is called for each record batch -/// - `on_error` is called if an error occurs +/// This function implements the full Arrow Async Device Stream producer protocol: +/// - Creates an `FFI_ArrowAsyncProducer` and sets it on the handler before calling `on_schema` +/// - Respects back-pressure: waits for consumer to call `producer.request(n)` before sending batches +/// - Handles consumer cancellation via `producer.cancel()` +/// - Signals end-of-stream by calling `on_next_task` with `NULL` +/// - Calls `on_error` if an error occurs (including cancellation) /// /// # Safety /// @@ -50,58 +126,211 @@ use crate::extension_ffi::{ /// ```ignore /// use sedona_extension::sendable_record_batch_stream::drive_stream_to_handler; /// -/// // Assuming you have a SendableRecordBatchStream and a handler +/// // The consumer should call handler.producer.request(n) to receive batches /// drive_stream_to_handler(stream, handler).await; /// ``` pub async fn drive_stream_to_handler( - mut stream: SendableRecordBatchStream, + stream: SendableRecordBatchStream, handler: *mut FFI_ArrowAsyncDeviceStreamHandler, ) { if handler.is_null() { return; } + // Create shared state for producer callbacks + let state = Arc::new(ProducerState::new()); + + // Create the producer + let mut producer = create_producer(Arc::clone(&state)); + + // Use a guard to handle unexpected drops + let mut guard = StreamDriverGuard::new(handler); + + let result = drive_stream_inner(stream, handler, &mut producer, state).await; + + if result.is_ok() { + guard.disarm(); + } +} + +/// Inner implementation that returns a Result for cleaner control flow. +async fn drive_stream_inner( + mut stream: SendableRecordBatchStream, + handler: *mut FFI_ArrowAsyncDeviceStreamHandler, + producer: &mut FFI_ArrowAsyncProducer, + state: Arc, +) -> Result<(), ()> { let handler_ref = unsafe { &mut *handler }; - // First, send the schema + // Set the producer on the handler BEFORE calling on_schema (per Arrow spec) + handler_ref.producer = producer; + + // Send the schema let schema = stream.schema(); let ffi_schema = match FFI_ArrowSchema::try_from(schema.as_ref()) { Ok(s) => s, Err(e) => { call_on_error(handler_ref, 1, &e.to_string()); - return; + return Err(()); } }; if let Some(on_schema) = handler_ref.on_schema { - // We need to box the schema so it lives long enough let mut boxed_schema = Box::new(ffi_schema); let ret = unsafe { on_schema(handler, boxed_schema.as_mut()) }; if ret != 0 { call_on_error(handler_ref, ret, "on_schema callback failed"); - return; + return Err(()); } - // The handler now owns the schema, so we forget it + // The handler now owns the schema std::mem::forget(boxed_schema); } - // Stream batches - while let Some(result) = stream.next().await { - match result { - Ok(batch) => { + // Stream batches with back-pressure + loop { + // Check for cancellation + if state.is_cancelled() { + // Per Arrow spec: successful cancel should NOT call on_error, + // just signal end of stream and release + signal_end_of_stream(handler_ref, handler); + return Ok(()); + } + + // Wait for consumer to request batches (back-pressure) + // Yield to allow other tasks to run while waiting + while !state.has_requests() && !state.is_cancelled() { + yield_once().await; + } + + // Re-check cancellation after waiting + if state.is_cancelled() { + signal_end_of_stream(handler_ref, handler); + return Ok(()); + } + + // Consume one request slot + if !state.try_consume_request() { + continue; + } + + // Get next batch + match stream.next().await { + Some(Ok(batch)) => { if let Err(e) = send_batch_to_handler(handler_ref, handler, batch) { call_on_error(handler_ref, 1, &e); - return; + return Err(()); } } - Err(e) => { + Some(Err(e)) => { call_on_error(handler_ref, 1, &e.to_string()); - return; + return Err(()); + } + None => { + // End of stream + signal_end_of_stream(handler_ref, handler); + return Ok(()); } } } } +/// Create an FFI producer with request/cancel callbacks. +fn create_producer(state: Arc) -> FFI_ArrowAsyncProducer { + // Convert Arc to raw pointer - this increments the refcount + let private_data = Arc::into_raw(state) as *mut c_void; + + FFI_ArrowAsyncProducer { + device_type: ARROW_DEVICE_CPU, + request: Some(producer_request), + cancel: Some(producer_cancel), + release: Some(producer_release), + additional_metadata: std::ptr::null(), + private_data, + } +} + +unsafe extern "C" fn producer_request(self_: *mut FFI_ArrowAsyncProducer, n: u64) { + if self_.is_null() { + return; + } + let producer = &*self_; + if producer.private_data.is_null() { + return; + } + let state = &*(producer.private_data as *const ProducerState); + state.request(n); +} + +unsafe extern "C" fn producer_cancel(self_: *mut FFI_ArrowAsyncProducer) { + if self_.is_null() { + return; + } + let producer = &*self_; + if producer.private_data.is_null() { + return; + } + let state = &*(producer.private_data as *const ProducerState); + state.cancel(); +} + +unsafe extern "C" fn producer_release(self_: *mut FFI_ArrowAsyncProducer) { + if self_.is_null() { + return; + } + let producer = &mut *self_; + if producer.private_data.is_null() { + return; + } + // Drop the Arc + let _ = Arc::from_raw(producer.private_data as *const ProducerState); + producer.private_data = null_mut(); +} + +/// Guard that calls `on_error` if dropped without being disarmed. +struct StreamDriverGuard { + handler: *mut FFI_ArrowAsyncDeviceStreamHandler, + armed: bool, +} + +impl StreamDriverGuard { + fn new(handler: *mut FFI_ArrowAsyncDeviceStreamHandler) -> Self { + Self { + handler, + armed: true, + } + } + + fn disarm(&mut self) { + self.armed = false; + } +} + +impl Drop for StreamDriverGuard { + fn drop(&mut self) { + if self.armed && !self.handler.is_null() { + let handler_ref = unsafe { &mut *self.handler }; + call_on_error( + handler_ref, + libc::ECANCELED, + "Stream driver dropped unexpectedly", + ); + } + } +} + +/// Signal end of stream by calling on_next_task with NULL. +fn signal_end_of_stream( + handler_ref: &mut FFI_ArrowAsyncDeviceStreamHandler, + handler: *mut FFI_ArrowAsyncDeviceStreamHandler, +) { + if let Some(on_next_task) = handler_ref.on_next_task { + // Per Arrow spec: pass NULL task pointer to signal end of stream + unsafe { + on_next_task(handler, std::ptr::null_mut(), std::ptr::null()); + } + } +} + fn send_batch_to_handler( handler_ref: &mut FFI_ArrowAsyncDeviceStreamHandler, handler: *mut FFI_ArrowAsyncDeviceStreamHandler, @@ -175,19 +404,13 @@ unsafe extern "C" fn async_task_extract_data( // Convert to struct array and then to FFI let struct_array: arrow_array::StructArray = batch.into(); - let (ffi_array, _ffi_schema) = match arrow_array::ffi::to_ffi(&struct_array.into()) { + let (ffi_array, _ffi_schema) = match to_ffi(&struct_array.into()) { Ok(result) => result, Err(_) => return 1, }; // Create device array (CPU) - let device_array = FFI_ArrowDeviceArray { - array: ffi_array, - device_id: -1, - device_type: ARROW_DEVICE_CPU, - sync_event: null_mut(), - }; - + let device_array = FFI_ArrowDeviceArray::from(ffi_array); std::ptr::write(out, device_array); 0 } From fd7abf42dcdc78de296a34a2041ff7de53d806f8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 12:46:52 -0500 Subject: [PATCH 06/16] rename --- ...d_batch_stream.rs => export_sendable_record_batch_stream.rs} | 0 c/sedona-extension/src/lib.rs | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename c/sedona-extension/src/{sendable_record_batch_stream.rs => export_sendable_record_batch_stream.rs} (100%) diff --git a/c/sedona-extension/src/sendable_record_batch_stream.rs b/c/sedona-extension/src/export_sendable_record_batch_stream.rs similarity index 100% rename from c/sedona-extension/src/sendable_record_batch_stream.rs rename to c/sedona-extension/src/export_sendable_record_batch_stream.rs diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index 4b103fe622..811d3f2d57 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -19,4 +19,4 @@ pub mod extension_ffi; pub mod scalar_kernel; #[cfg(feature = "async")] -pub mod sendable_record_batch_stream; +pub mod export_sendable_record_batch_stream; From 86105a9dcd8fd6c69a9c9f28ceeaafb04345df1e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 13:05:18 -0500 Subject: [PATCH 07/16] generate importable --- .../import_sendable_record_batch_stream.rs | 444 ++++++++++++++++++ c/sedona-extension/src/lib.rs | 3 + 2 files changed, 447 insertions(+) create mode 100644 c/sedona-extension/src/import_sendable_record_batch_stream.rs diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs new file mode 100644 index 0000000000..9080a181d4 --- /dev/null +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -0,0 +1,444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Import an async device stream from FFI into a [`SendableRecordBatchStream`]. +//! +//! This module allows Rust code to consume async record batch streams provided by +//! C/FFI producers via the Arrow C Device Data Interface. +//! +//! # Usage +//! +//! 1. Create an [`ImportedAsyncDeviceStream`] which provides an [`FFI_ArrowAsyncDeviceStreamHandler`] +//! 2. Pass the handler pointer to the FFI producer +//! 3. Use the stream as a normal [`SendableRecordBatchStream`] +//! +//! ```ignore +//! use sedona_extension::import_sendable_record_batch_stream::ImportedAsyncDeviceStream; +//! +//! let (stream, handler_ptr) = ImportedAsyncDeviceStream::new(16); +//! +//! // Pass handler_ptr to FFI producer... +//! // ffi_producer_start(handler_ptr); +//! +//! // Consume the stream +//! while let Some(batch) = stream.next().await { +//! let batch = batch?; +//! // process batch +//! } +//! ``` + +use std::ffi::{c_int, c_void, CStr}; +use std::pin::Pin; +use std::ptr::null_mut; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; + +use arrow_array::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow_array::{RecordBatch, StructArray}; +use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; +use datafusion_common::Result; +use datafusion_execution::RecordBatchStream; +use futures::Stream; + +use crate::extension_ffi::{ + FFI_ArrowAsyncDeviceStreamHandler, FFI_ArrowAsyncProducer, FFI_ArrowAsyncTask, + FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, +}; + +/// Messages sent from FFI callbacks to the stream. +enum StreamMessage { + /// Schema received from producer. + Schema(SchemaRef), + /// A record batch task received. + Task(FFI_ArrowAsyncTask), + /// End of stream (NULL task received). + EndOfStream, + /// Error from producer. + Error(ArrowError), +} + +/// Shared state between the handler callbacks and the stream. +struct ImportedStreamState { + /// The schema, set by on_schema callback. + schema: Option, + /// Pending messages from callbacks. + messages: Vec, + /// Waker to wake the stream when new data arrives. + waker: Option, + /// Whether the stream has ended. + ended: bool, + /// Producer pointer for requesting more data. + producer: *mut FFI_ArrowAsyncProducer, + /// Number of batches to request at a time (for back-pressure). + prefetch_count: u64, + /// Number of outstanding requests. + pending_requests: u64, +} + +// Safety: We ensure proper synchronization via Mutex +unsafe impl Send for ImportedStreamState {} + +impl ImportedStreamState { + fn new(prefetch_count: u64) -> Self { + Self { + schema: None, + messages: Vec::new(), + waker: None, + ended: false, + producer: null_mut(), + prefetch_count, + pending_requests: 0, + } + } + + fn wake(&self) { + if let Some(ref waker) = self.waker { + waker.wake_by_ref(); + } + } + + /// Request more data from the producer if needed. + fn maybe_request_more(&mut self) { + if self.ended || self.producer.is_null() { + return; + } + + // Request more when we're running low + if self.pending_requests < self.prefetch_count / 2 { + let to_request = self.prefetch_count - self.pending_requests; + if let Some(request_fn) = unsafe { (*self.producer).request } { + unsafe { request_fn(self.producer, to_request) }; + self.pending_requests += to_request; + } + } + } +} + +/// A [`RecordBatchStream`] that consumes data from an FFI async producer. +/// +/// Create with [`ImportedAsyncDeviceStream::new`], then pass the handler pointer +/// to the FFI producer. The stream will yield batches as the producer sends them. +pub struct ImportedAsyncDeviceStream { + state: Arc>, + /// The handler - kept alive for the lifetime of the stream. + /// Box is used to get a stable address. + _handler: Box, + /// Cached schema for the RecordBatchStream trait. + schema: Option, + /// DataType for converting FFI arrays. + schema_struct_type: Option, +} + +// Safety: The handler contains raw pointers but they are only accessed from +// the thread that polls the stream. The Arc> provides synchronization +// for the shared state accessed by callbacks. +unsafe impl Send for ImportedAsyncDeviceStream {} + +impl ImportedAsyncDeviceStream { + /// Create a new imported stream. + /// + /// Returns the stream and a pointer to the handler that should be passed to the FFI producer. + /// + /// # Arguments + /// + /// * `prefetch_count` - Number of batches to request ahead for back-pressure. + /// A larger value reduces latency but uses more memory. + /// + /// # Safety + /// + /// The returned handler pointer is valid for the lifetime of the stream. + /// The FFI producer must not use the handler after the stream is dropped. + pub fn new(prefetch_count: u64) -> (Self, *mut FFI_ArrowAsyncDeviceStreamHandler) { + let state = Arc::new(Mutex::new(ImportedStreamState::new(prefetch_count))); + + let handler = Box::new(FFI_ArrowAsyncDeviceStreamHandler { + on_schema: Some(handler_on_schema), + on_next_task: Some(handler_on_next_task), + on_error: Some(handler_on_error), + release: Some(handler_release), + producer: null_mut(), + private_data: Arc::into_raw(Arc::clone(&state)) as *mut c_void, + }); + + let handler_ptr = handler.as_ref() as *const _ as *mut FFI_ArrowAsyncDeviceStreamHandler; + + let stream = Self { + state, + _handler: handler, + schema: None, + schema_struct_type: None, + }; + + (stream, handler_ptr) + } + + /// Convert this stream into a [`SendableRecordBatchStream`]. + /// + /// This is a convenience method equivalent to `Box::pin(stream)`. + pub fn into_sendable(self) -> datafusion_execution::SendableRecordBatchStream { + Box::pin(self) + } + + /// Cancel the stream, signaling to the producer to stop sending data. + pub fn cancel(&self) { + let state = self.state.lock().unwrap(); + if !state.producer.is_null() { + if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { + unsafe { cancel_fn(state.producer) }; + } + } + } + + /// Poll for the next message, handling schema initialization. + fn poll_next_inner(&mut self, cx: &mut Context<'_>) -> Poll>> { + let mut state = self.state.lock().unwrap(); + + // Register waker for when new data arrives + state.waker = Some(cx.waker().clone()); + + // Process any pending messages + if let Some(msg) = state.messages.pop() { + match msg { + StreamMessage::Schema(schema) => { + self.schema_struct_type = + Some(DataType::Struct(schema.fields().iter().cloned().collect())); + self.schema = Some(schema.clone()); + state.schema = Some(schema); + + // Request initial batch of data + state.maybe_request_more(); + + // Continue polling for actual data + drop(state); + return self.poll_next_inner(cx); + } + StreamMessage::Task(mut task) => { + state.pending_requests = state.pending_requests.saturating_sub(1); + state.maybe_request_more(); + drop(state); + + // Extract data from the task + let batch = self.extract_batch_from_task(&mut task); + return Poll::Ready(Some(batch)); + } + StreamMessage::EndOfStream => { + state.ended = true; + return Poll::Ready(None); + } + StreamMessage::Error(e) => { + state.ended = true; + return Poll::Ready(Some(Err(e.into()))); + } + } + } + + // Check if stream has ended + if state.ended { + return Poll::Ready(None); + } + + // No messages available, wait for callback + Poll::Pending + } + + /// Extract a RecordBatch from an FFI task. + fn extract_batch_from_task(&self, task: &mut FFI_ArrowAsyncTask) -> Result { + let Some(extract_data) = task.extract_data else { + return Err(ArrowError::CDataInterface("extract_data is null".to_string()).into()); + }; + + let mut device_array = FFI_ArrowDeviceArray { + array: FFI_ArrowArray::empty(), + device_id: 0, + device_type: 0, + sync_event: null_mut(), + }; + + let ret = unsafe { extract_data(task, &mut device_array) }; + if ret != 0 { + return Err(ArrowError::CDataInterface("extract_data failed".to_string()).into()); + } + + // Only CPU device supported + if device_array.device_type != ARROW_DEVICE_CPU { + return Err(ArrowError::CDataInterface(format!( + "Unsupported device type: {}", + device_array.device_type + )) + .into()); + } + + // Convert to RecordBatch + let Some(ref struct_type) = self.schema_struct_type else { + return Err(ArrowError::CDataInterface("Schema not yet received".to_string()).into()); + }; + + let array_data = + unsafe { from_ffi_and_data_type(device_array.array, struct_type.clone())? }; + let struct_array: StructArray = array_data.into(); + Ok(struct_array.into()) + } +} + +impl Stream for ImportedAsyncDeviceStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_inner(cx) + } +} + +impl RecordBatchStream for ImportedAsyncDeviceStream { + fn schema(&self) -> SchemaRef { + self.schema + .clone() + .unwrap_or_else(|| Arc::new(Schema::empty())) + } +} + +// FFI callback implementations + +unsafe extern "C" fn handler_on_schema( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + schema: *mut FFI_ArrowSchema, +) -> c_int { + if self_.is_null() || schema.is_null() { + return 1; + } + + let handler = &mut *self_; + if handler.private_data.is_null() { + return 1; + } + + // Store the producer pointer for later use + let state_ptr = handler.private_data as *const Mutex; + let state_arc = Arc::from_raw(state_ptr); + + let result = { + let mut state = state_arc.lock().unwrap(); + state.producer = handler.producer; + + // Import the schema + let ffi_schema = std::ptr::read(schema); + match Schema::try_from(&ffi_schema) { + Ok(s) => { + state.messages.push(StreamMessage::Schema(Arc::new(s))); + state.wake(); + 0 + } + Err(e) => { + state.messages.push(StreamMessage::Error(e)); + state.wake(); + 1 + } + } + }; + + // Don't drop the Arc, just forget about this reference + let _ = Arc::into_raw(state_arc); + result +} + +unsafe extern "C" fn handler_on_next_task( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + task: *mut FFI_ArrowAsyncTask, + _metadata: *const std::ffi::c_char, +) -> c_int { + if self_.is_null() { + return 1; + } + + let handler = &*self_; + if handler.private_data.is_null() { + return 1; + } + + let state_ptr = handler.private_data as *const Mutex; + let state_arc = Arc::from_raw(state_ptr); + + { + let mut state = state_arc.lock().unwrap(); + + if task.is_null() { + // NULL task signals end of stream + state.messages.push(StreamMessage::EndOfStream); + } else { + // Take ownership of the task by copying it + let task_copy = std::ptr::read(task); + state.messages.push(StreamMessage::Task(task_copy)); + } + state.wake(); + } + + let _ = Arc::into_raw(state_arc); + 0 +} + +unsafe extern "C" fn handler_on_error( + self_: *mut FFI_ArrowAsyncDeviceStreamHandler, + code: c_int, + message: *const std::ffi::c_char, + _metadata: *const std::ffi::c_char, +) { + if self_.is_null() { + return; + } + + let handler = &*self_; + if handler.private_data.is_null() { + return; + } + + let state_ptr = handler.private_data as *const Mutex; + let state_arc = Arc::from_raw(state_ptr); + + { + let mut state = state_arc.lock().unwrap(); + + let error_msg = if message.is_null() { + format!("FFI error code {}", code) + } else { + let c_str = CStr::from_ptr(message); + c_str.to_string_lossy().into_owned() + }; + + state + .messages + .push(StreamMessage::Error(ArrowError::CDataInterface(error_msg))); + state.ended = true; + state.wake(); + } + + let _ = Arc::into_raw(state_arc); +} + +unsafe extern "C" fn handler_release(self_: *mut FFI_ArrowAsyncDeviceStreamHandler) { + if self_.is_null() { + return; + } + + let handler = &mut *self_; + if handler.private_data.is_null() { + return; + } + + // Drop our Arc reference + let state_ptr = handler.private_data as *const Mutex; + let _ = Arc::from_raw(state_ptr); + handler.private_data = null_mut(); +} diff --git a/c/sedona-extension/src/lib.rs b/c/sedona-extension/src/lib.rs index 811d3f2d57..1e7f324df8 100644 --- a/c/sedona-extension/src/lib.rs +++ b/c/sedona-extension/src/lib.rs @@ -20,3 +20,6 @@ pub mod scalar_kernel; #[cfg(feature = "async")] pub mod export_sendable_record_batch_stream; + +#[cfg(feature = "async")] +pub mod import_sendable_record_batch_stream; From 9c872cca201e4b36fa216ecf7d5c45bfd0d3dd5f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 13:24:37 -0500 Subject: [PATCH 08/16] first round of tests --- Cargo.lock | 1 + c/sedona-extension/Cargo.toml | 6 +- .../export_sendable_record_batch_stream.rs | 46 +-- .../import_sendable_record_batch_stream.rs | 267 ++++++++++++++++++ 4 files changed, 299 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1cd934ecd0..b7579e5faa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5790,6 +5790,7 @@ dependencies = [ "sedona-expr", "sedona-schema", "sedona-testing", + "tokio", ] [[package]] diff --git a/c/sedona-extension/Cargo.toml b/c/sedona-extension/Cargo.toml index 3d27389288..5125e895b5 100644 --- a/c/sedona-extension/Cargo.toml +++ b/c/sedona-extension/Cargo.toml @@ -29,7 +29,7 @@ rust-version.workspace = true [features] default = [] -async = ["dep:datafusion-execution", "dep:futures"] +async = ["dep:datafusion-execution", "dep:futures", "dep:tokio"] [dependencies] arrow-array = { workspace = true, features = ["ffi"]} @@ -43,3 +43,7 @@ sedona-common = { workspace = true } sedona-expr = { workspace = true } sedona-schema = { workspace = true } sedona-testing = { path = "../../rust/sedona-testing" } +tokio = { workspace = true, optional = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/c/sedona-extension/src/export_sendable_record_batch_stream.rs b/c/sedona-extension/src/export_sendable_record_batch_stream.rs index 19dc9b1f62..a0e7cf2eb9 100644 --- a/c/sedona-extension/src/export_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/export_sendable_record_batch_stream.rs @@ -29,7 +29,7 @@ use std::ffi::{c_int, c_void, CString}; use std::ptr::null_mut; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; -use std::task::Poll; +use std::task::{Poll, Waker}; use arrow_array::ffi::{to_ffi, FFI_ArrowSchema}; use arrow_array::RecordBatch; @@ -42,27 +42,14 @@ use crate::extension_ffi::{ FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, }; -/// Yields control back to the async executor once, then resumes. -async fn yield_once() { - let mut yielded = false; - futures::future::poll_fn(|cx| { - if yielded { - Poll::Ready(()) - } else { - yielded = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - }) - .await -} - /// Shared state between the async driver and the FFI producer callbacks. struct ProducerState { /// Number of batches requested by consumer (back-pressure). requested: AtomicU64, /// Set to true when consumer calls cancel(). cancelled: AtomicBool, + /// Waker to wake when requests are available or cancelled. + waker: std::sync::Mutex>, } impl ProducerState { @@ -70,15 +57,18 @@ impl ProducerState { Self { requested: AtomicU64::new(0), cancelled: AtomicBool::new(false), + waker: std::sync::Mutex::new(None), } } fn request(&self, n: u64) { self.requested.fetch_add(n, Ordering::SeqCst); + self.wake(); } fn cancel(&self) { self.cancelled.store(true, Ordering::SeqCst); + self.wake(); } fn is_cancelled(&self) -> bool { @@ -105,6 +95,16 @@ impl ProducerState { fn has_requests(&self) -> bool { self.requested.load(Ordering::SeqCst) > 0 } + + fn register_waker(&self, waker: Waker) { + *self.waker.lock().unwrap() = Some(waker); + } + + fn wake(&self) { + if let Some(waker) = self.waker.lock().unwrap().take() { + waker.wake(); + } + } } /// Drives a [`SendableRecordBatchStream`] and pushes results to an [`FFI_ArrowAsyncDeviceStreamHandler`]. @@ -197,10 +197,16 @@ async fn drive_stream_inner( } // Wait for consumer to request batches (back-pressure) - // Yield to allow other tasks to run while waiting - while !state.has_requests() && !state.is_cancelled() { - yield_once().await; - } + // Use poll_fn to properly register our waker with ProducerState + futures::future::poll_fn(|cx| { + if state.has_requests() || state.is_cancelled() { + Poll::Ready(()) + } else { + state.register_waker(cx.waker().clone()); + Poll::Pending + } + }) + .await; // Re-check cancellation after waiting if state.is_cancelled() { diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index 9080a181d4..ac0729ecd0 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -442,3 +442,270 @@ unsafe extern "C" fn handler_release(self_: *mut FFI_ArrowAsyncDeviceStreamHandl let _ = Arc::from_raw(state_ptr); handler.private_data = null_mut(); } + +#[cfg(test)] +mod tests { + use super::*; + use crate::export_sendable_record_batch_stream::drive_stream_to_handler; + use arrow_array::{Int32Array, StringArray}; + use arrow_schema::Field; + use datafusion_execution::RecordBatchStream; + use futures::StreamExt; + use std::pin::Pin; + use std::sync::Arc; + + /// Create a test schema with id and name columns. + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])) + } + + /// Create a test batch with the given row count. + fn make_batch(schema: &SchemaRef, start_id: i32, count: i32) -> RecordBatch { + let ids: Vec = (start_id..start_id + count).collect(); + let names: Vec> = ids + .iter() + .map(|i| Some(if i % 2 == 0 { "even" } else { "odd" })) + .collect(); + + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(ids)), + Arc::new(StringArray::from(names)), + ], + ) + .unwrap() + } + + /// A simple SendableRecordBatchStream that yields once before each batch. + /// This ensures proper async interleaving between producer and consumer. + struct TestStream { + schema: SchemaRef, + batches: std::collections::VecDeque>, + yielded: bool, + } + + impl TestStream { + fn new(schema: SchemaRef, batches: Vec) -> Pin> { + Box::pin(Self { + schema, + batches: batches.into_iter().map(Ok).collect(), + yielded: false, // Start not yielded so first poll yields + }) + } + + fn with_error( + schema: SchemaRef, + batches: Vec, + error_after: usize, + ) -> Pin> { + let mut results: std::collections::VecDeque> = + batches.into_iter().map(Ok).collect(); + if error_after < results.len() { + results.insert( + error_after, + Err(datafusion_common::DataFusionError::External( + "Test error".into(), + )), + ); + } else { + results.push_back(Err(datafusion_common::DataFusionError::External( + "Test error".into(), + ))); + } + Box::pin(Self { + schema, + batches: results, + yielded: false, // Start not yielded so first poll yields + }) + } + } + + impl Stream for TestStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.as_mut().get_mut(); + if this.yielded { + // After yielding, return data on next poll + this.yielded = false; + if let Some(batch) = this.batches.pop_front() { + Poll::Ready(Some(batch)) + } else { + Poll::Ready(None) + } + } else { + // Haven't yielded yet - yield once to give other tasks a chance + this.yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + impl RecordBatchStream for TestStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + } + + #[tokio::test] + async fn test_empty_stream() { + let schema = test_schema(); + let source_stream = TestStream::new(schema.clone(), vec![]); + + let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + + let consumer_future = async { + let mut received = vec![]; + while let Some(result) = consumer.next().await { + received.push(result); + } + (received, consumer.schema()) + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + let ((received, result_schema), _) = futures::join!(consumer_future, producer_future); + + assert!(received.is_empty()); + assert_eq!(result_schema.fields().len(), 2); + } + + #[tokio::test] + async fn test_single_batch() { + let schema = test_schema(); + let batch = make_batch(&schema, 1, 5); + let source_stream = TestStream::new(schema.clone(), vec![batch.clone()]); + + let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + + let consumer_future = async { + let mut received = vec![]; + while let Some(result) = consumer.next().await { + received.push(result); + } + received + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + let (received, _) = futures::join!(consumer_future, producer_future); + + assert_eq!(received.len(), 1); + let received_batch = received.into_iter().next().unwrap().unwrap(); + assert_eq!(received_batch.num_rows(), 5); + assert_eq!(received_batch.schema(), batch.schema()); + + // Verify data + let ids = received_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(ids.values(), &[1, 2, 3, 4, 5]); + } + + #[tokio::test] + async fn test_multiple_batches_with_backpressure() { + let schema = test_schema(); + + // Create more batches than the prefetch count + let prefetch_count = 2u64; + let num_batches = 10; + let batches: Vec = (0..num_batches) + .map(|i| make_batch(&schema, i * 3, 3)) + .collect(); + let expected_total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + let source_stream = TestStream::new(schema.clone(), batches); + + let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(prefetch_count); + + let consumer_future = async { + let mut received = vec![]; + while let Some(result) = consumer.next().await { + received.push(result); + } + received + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + let (received, _) = futures::join!(consumer_future, producer_future); + + assert_eq!(received.len(), num_batches as usize); + let total_rows: usize = received.into_iter().map(|r| r.unwrap().num_rows()).sum(); + assert_eq!(total_rows, expected_total_rows); + } + + #[tokio::test] + async fn test_stream_error() { + let schema = test_schema(); + let batches = vec![make_batch(&schema, 1, 3), make_batch(&schema, 4, 3)]; + + // Error after first batch + let source_stream = TestStream::with_error(schema.clone(), batches, 1); + + let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + + let consumer_future = async { + let mut received = vec![]; + while let Some(result) = consumer.next().await { + received.push(result); + } + received + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + let (received, _) = futures::join!(consumer_future, producer_future); + + // First batch should succeed + assert!(!received.is_empty()); + assert!(received[0].is_ok()); + + // Should have an error somewhere + let has_error = received.iter().any(|r| r.is_err()); + assert!(has_error, "Expected an error in the stream"); + } + + #[tokio::test] + async fn test_cancellation() { + let schema = test_schema(); + + // Create a large number of batches + let batches: Vec = (0..100).map(|i| make_batch(&schema, i * 3, 3)).collect(); + + let source_stream = TestStream::new(schema.clone(), batches); + + let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(2); + + let consumer_future = async { + let mut count = 0; + while let Some(result) = consumer.next().await { + result.expect("should not error"); + count += 1; + if count >= 3 { + // Cancel after receiving 3 batches + consumer.cancel(); + break; + } + } + count + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + let (count, _) = futures::join!(consumer_future, producer_future); + + // Should have received at least 3 batches before cancellation + assert!(count >= 3); + } +} From f8a9040f4b9ecc9e2cc46abaa0cc13e2178a7dff Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 13:32:26 -0500 Subject: [PATCH 09/16] simplify testing --- .../export_sendable_record_batch_stream.rs | 19 +++++++++++++++ .../import_sendable_record_batch_stream.rs | 23 +++---------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/c/sedona-extension/src/export_sendable_record_batch_stream.rs b/c/sedona-extension/src/export_sendable_record_batch_stream.rs index a0e7cf2eb9..0dfe79121f 100644 --- a/c/sedona-extension/src/export_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/export_sendable_record_batch_stream.rs @@ -42,6 +42,22 @@ use crate::extension_ffi::{ FFI_ArrowDeviceArray, ARROW_DEVICE_CPU, }; +/// Yields control back to the async executor once, then resumes. +/// This allows other tasks (e.g., the consumer) to make progress. +async fn yield_once() { + let mut yielded = false; + futures::future::poll_fn(|cx| { + if yielded { + Poll::Ready(()) + } else { + yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + .await +} + /// Shared state between the async driver and the FFI producer callbacks. struct ProducerState { /// Number of batches requested by consumer (back-pressure). @@ -226,6 +242,9 @@ async fn drive_stream_inner( call_on_error(handler_ref, 1, &e); return Err(()); } + // Yield to allow consumer to process the batch. + // This enables single-task usage (producer and consumer in same task). + yield_once().await; } Some(Err(e)) => { call_on_error(handler_ref, 1, &e.to_string()); diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index ac0729ecd0..2667a1916a 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -480,12 +480,10 @@ mod tests { .unwrap() } - /// A simple SendableRecordBatchStream that yields once before each batch. - /// This ensures proper async interleaving between producer and consumer. + /// A simple SendableRecordBatchStream from a list of batches. struct TestStream { schema: SchemaRef, batches: std::collections::VecDeque>, - yielded: bool, } impl TestStream { @@ -493,7 +491,6 @@ mod tests { Box::pin(Self { schema, batches: batches.into_iter().map(Ok).collect(), - yielded: false, // Start not yielded so first poll yields }) } @@ -519,7 +516,6 @@ mod tests { Box::pin(Self { schema, batches: results, - yielded: false, // Start not yielded so first poll yields }) } } @@ -529,23 +525,10 @@ mod tests { fn poll_next( mut self: Pin<&mut Self>, - cx: &mut Context<'_>, + _cx: &mut Context<'_>, ) -> Poll> { let this = self.as_mut().get_mut(); - if this.yielded { - // After yielding, return data on next poll - this.yielded = false; - if let Some(batch) = this.batches.pop_front() { - Poll::Ready(Some(batch)) - } else { - Poll::Ready(None) - } - } else { - // Haven't yielded yet - yield once to give other tasks a chance - this.yielded = true; - cx.waker().wake_by_ref(); - Poll::Pending - } + Poll::Ready(this.batches.pop_front()) } } From 5eeb7c137532b643c549313acd1226d18d5b47ae Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 13:44:58 -0500 Subject: [PATCH 10/16] fix dropped stream --- .../export_sendable_record_batch_stream.rs | 23 +++-- .../import_sendable_record_batch_stream.rs | 87 ++++++++++++++++--- 2 files changed, 93 insertions(+), 17 deletions(-) diff --git a/c/sedona-extension/src/export_sendable_record_batch_stream.rs b/c/sedona-extension/src/export_sendable_record_batch_stream.rs index 0dfe79121f..b0cf7678c4 100644 --- a/c/sedona-extension/src/export_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/export_sendable_record_batch_stream.rs @@ -159,14 +159,14 @@ pub async fn drive_stream_to_handler( // Create the producer let mut producer = create_producer(Arc::clone(&state)); - // Use a guard to handle unexpected drops + // Use a guard to handle unexpected drops (panics) let mut guard = StreamDriverGuard::new(handler); - let result = drive_stream_inner(stream, handler, &mut producer, state).await; + let _result = drive_stream_inner(stream, handler, &mut producer, state).await; - if result.is_ok() { - guard.disarm(); - } + // Normal completion (success or error) - disarm the guard and release handler + guard.disarm(); + release_handler(handler); } /// Inner implementation that returns a Result for cleaner control flow. @@ -339,6 +339,8 @@ impl Drop for StreamDriverGuard { libc::ECANCELED, "Stream driver dropped unexpectedly", ); + // Still need to release the handler + release_handler(self.handler); } } } @@ -356,6 +358,17 @@ fn signal_end_of_stream( } } +/// Release the handler (call its release callback). +fn release_handler(handler: *mut FFI_ArrowAsyncDeviceStreamHandler) { + if handler.is_null() { + return; + } + let handler_ref = unsafe { &*handler }; + if let Some(release) = handler_ref.release { + unsafe { release(handler) }; + } +} + fn send_batch_to_handler( handler_ref: &mut FFI_ArrowAsyncDeviceStreamHandler, handler: *mut FFI_ArrowAsyncDeviceStreamHandler, diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index 2667a1916a..16f404ea7f 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -87,6 +87,9 @@ struct ImportedStreamState { prefetch_count: u64, /// Number of outstanding requests. pending_requests: u64, + /// Set to true when the stream is dropped but producer hasn't called release yet. + /// Callbacks should return errors when this is true. + abandoned: bool, } // Safety: We ensure proper synchronization via Mutex @@ -102,6 +105,7 @@ impl ImportedStreamState { producer: null_mut(), prefetch_count, pending_requests: 0, + abandoned: false, } } @@ -132,11 +136,14 @@ impl ImportedStreamState { /// /// Create with [`ImportedAsyncDeviceStream::new`], then pass the handler pointer /// to the FFI producer. The stream will yield batches as the producer sends them. +/// +/// # Dropping +/// +/// When the stream is dropped, it signals cancellation to the producer. The handler +/// remains valid until the producer calls its `release` callback. This allows the +/// idiomatic DataFusion pattern of dropping streams to cancel queries. pub struct ImportedAsyncDeviceStream { state: Arc>, - /// The handler - kept alive for the lifetime of the stream. - /// Box is used to get a stable address. - _handler: Box, /// Cached schema for the RecordBatchStream trait. schema: Option, /// DataType for converting FFI arrays. @@ -160,25 +167,26 @@ impl ImportedAsyncDeviceStream { /// /// # Safety /// - /// The returned handler pointer is valid for the lifetime of the stream. - /// The FFI producer must not use the handler after the stream is dropped. + /// The returned handler pointer remains valid until the producer calls its `release` + /// callback. When the stream is dropped, it signals cancellation to the producer, + /// but the handler stays valid - callbacks will return error codes to tell the + /// producer to stop and release the handler. pub fn new(prefetch_count: u64) -> (Self, *mut FFI_ArrowAsyncDeviceStreamHandler) { let state = Arc::new(Mutex::new(ImportedStreamState::new(prefetch_count))); - let handler = Box::new(FFI_ArrowAsyncDeviceStreamHandler { + // Handler is heap-allocated and will be freed when producer calls release(). + // This allows the handler to outlive the stream. + let handler_ptr = Box::into_raw(Box::new(FFI_ArrowAsyncDeviceStreamHandler { on_schema: Some(handler_on_schema), on_next_task: Some(handler_on_next_task), on_error: Some(handler_on_error), release: Some(handler_release), producer: null_mut(), private_data: Arc::into_raw(Arc::clone(&state)) as *mut c_void, - }); - - let handler_ptr = handler.as_ref() as *const _ as *mut FFI_ArrowAsyncDeviceStreamHandler; + })); let stream = Self { state, - _handler: handler, schema: None, schema_struct_type: None, }; @@ -310,6 +318,22 @@ impl RecordBatchStream for ImportedAsyncDeviceStream { } } +impl Drop for ImportedAsyncDeviceStream { + fn drop(&mut self) { + let mut state = self.state.lock().unwrap(); + + // Mark as abandoned so callbacks know to return errors + state.abandoned = true; + + // Signal cancellation to the producer + if !state.producer.is_null() { + if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { + unsafe { cancel_fn(state.producer) }; + } + } + } +} + // FFI callback implementations unsafe extern "C" fn handler_on_schema( @@ -329,6 +353,17 @@ unsafe extern "C" fn handler_on_schema( let state_ptr = handler.private_data as *const Mutex; let state_arc = Arc::from_raw(state_ptr); + let abandoned = { + let state = state_arc.lock().unwrap(); + state.abandoned + }; + + // If stream was dropped, tell producer to stop + if abandoned { + let _ = Arc::into_raw(state_arc); + return 1; // Error code signals producer to stop and release + } + let result = { let mut state = state_arc.lock().unwrap(); state.producer = handler.producer; @@ -371,6 +406,17 @@ unsafe extern "C" fn handler_on_next_task( let state_ptr = handler.private_data as *const Mutex; let state_arc = Arc::from_raw(state_ptr); + let abandoned = { + let state = state_arc.lock().unwrap(); + state.abandoned + }; + + // If stream was dropped, tell producer to stop + if abandoned { + let _ = Arc::into_raw(state_arc); + return 1; // Error code signals producer to stop and release + } + { let mut state = state_arc.lock().unwrap(); @@ -434,13 +480,30 @@ unsafe extern "C" fn handler_release(self_: *mut FFI_ArrowAsyncDeviceStreamHandl let handler = &mut *self_; if handler.private_data.is_null() { + // Already released - free the handler struct + // Clear release to prevent Drop from calling us again + handler.release = None; + let _ = Box::from_raw(self_); return; } - // Drop our Arc reference + // Get the state and clear the producer pointer so stream Drop won't use it let state_ptr = handler.private_data as *const Mutex; - let _ = Arc::from_raw(state_ptr); + let state_arc = Arc::from_raw(state_ptr); + + { + let mut state = state_arc.lock().unwrap(); + state.producer = null_mut(); + state.ended = true; + } + + // Drop the Arc reference (will be freed when stream also drops its reference) + drop(state_arc); handler.private_data = null_mut(); + + // Clear release to prevent Drop impl from calling us again, then free + handler.release = None; + let _ = Box::from_raw(self_); } #[cfg(test)] From 740aac84d7ee588510232a6bd4ccd0759526a6f3 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 13:59:14 -0500 Subject: [PATCH 11/16] try to make the implementation avoid locking --- .../import_sendable_record_batch_stream.rs | 325 ++++++++++-------- 1 file changed, 173 insertions(+), 152 deletions(-) diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index 16f404ea7f..fe041f665d 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -44,14 +44,17 @@ use std::ffi::{c_int, c_void, CStr}; use std::pin::Pin; use std::ptr::null_mut; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; use arrow_array::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; use arrow_array::{RecordBatch, StructArray}; use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; use datafusion_common::Result; use datafusion_execution::RecordBatchStream; +use futures::channel::mpsc; +use futures::task::AtomicWaker; use futures::Stream; use crate::extension_ffi::{ @@ -71,62 +74,98 @@ enum StreamMessage { Error(ArrowError), } -/// Shared state between the handler callbacks and the stream. -struct ImportedStreamState { - /// The schema, set by on_schema callback. - schema: Option, - /// Pending messages from callbacks. - messages: Vec, - /// Waker to wake the stream when new data arrives. - waker: Option, - /// Whether the stream has ended. - ended: bool, +/// State shared between callbacks and stream for producer management. +/// This is the only state that requires mutex protection. +struct ProducerState { /// Producer pointer for requesting more data. producer: *mut FFI_ArrowAsyncProducer, /// Number of batches to request at a time (for back-pressure). prefetch_count: u64, - /// Number of outstanding requests. - pending_requests: u64, +} + +// Safety: Producer pointer is only dereferenced while holding the lock +unsafe impl Send for ProducerState {} +unsafe impl Sync for ProducerState {} + +/// Shared state between the handler callbacks and the stream. +struct ImportedStreamState { + /// Sender for messages from callbacks to stream (lock-free). + sender: mpsc::UnboundedSender, + /// Waker for the stream (lock-free). + waker: AtomicWaker, + /// Whether the stream has ended. + ended: AtomicBool, + /// Number of outstanding requests (for back-pressure). + pending_requests: AtomicU64, /// Set to true when the stream is dropped but producer hasn't called release yet. - /// Callbacks should return errors when this is true. - abandoned: bool, + abandoned: AtomicBool, + /// Producer state (needs mutex for FFI calls). + producer_state: Mutex, } -// Safety: We ensure proper synchronization via Mutex +// Safety: All fields are Send+Sync (atomics, Mutex, channel sender). +// The producer pointer inside ProducerState is protected by the Mutex. unsafe impl Send for ImportedStreamState {} +unsafe impl Sync for ImportedStreamState {} impl ImportedStreamState { - fn new(prefetch_count: u64) -> Self { + fn new(prefetch_count: u64, sender: mpsc::UnboundedSender) -> Self { Self { - schema: None, - messages: Vec::new(), - waker: None, - ended: false, - producer: null_mut(), - prefetch_count, - pending_requests: 0, - abandoned: false, + sender, + waker: AtomicWaker::new(), + ended: AtomicBool::new(false), + pending_requests: AtomicU64::new(0), + abandoned: AtomicBool::new(false), + producer_state: Mutex::new(ProducerState { + producer: null_mut(), + prefetch_count, + }), } } fn wake(&self) { - if let Some(ref waker) = self.waker { - waker.wake_by_ref(); - } + self.waker.wake(); } /// Request more data from the producer if needed. - fn maybe_request_more(&mut self) { - if self.ended || self.producer.is_null() { + fn maybe_request_more(&self) { + if self.ended.load(Ordering::Acquire) { + return; + } + + let producer_state = self.producer_state.lock().unwrap(); + if producer_state.producer.is_null() { return; } + let pending = self.pending_requests.load(Ordering::Acquire); + let prefetch = producer_state.prefetch_count; + // Request more when we're running low - if self.pending_requests < self.prefetch_count / 2 { - let to_request = self.prefetch_count - self.pending_requests; - if let Some(request_fn) = unsafe { (*self.producer).request } { - unsafe { request_fn(self.producer, to_request) }; - self.pending_requests += to_request; + if pending < prefetch / 2 { + let to_request = prefetch - pending; + if let Some(request_fn) = unsafe { (*producer_state.producer).request } { + unsafe { request_fn(producer_state.producer, to_request) }; + self.pending_requests.fetch_add(to_request, Ordering::Release); + } + } + } + + fn set_producer(&self, producer: *mut FFI_ArrowAsyncProducer) { + let mut state = self.producer_state.lock().unwrap(); + state.producer = producer; + } + + fn clear_producer(&self) { + let mut state = self.producer_state.lock().unwrap(); + state.producer = null_mut(); + } + + fn cancel(&self) { + let state = self.producer_state.lock().unwrap(); + if !state.producer.is_null() { + if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { + unsafe { cancel_fn(state.producer) }; } } } @@ -142,17 +181,23 @@ impl ImportedStreamState { /// When the stream is dropped, it signals cancellation to the producer. The handler /// remains valid until the producer calls its `release` callback. This allows the /// idiomatic DataFusion pattern of dropping streams to cancel queries. +/// +/// # Performance +/// +/// Uses a lock-free channel for message passing from FFI callbacks to the stream. +/// Only acquires a mutex when requesting more batches from the producer (approximately +/// once per `prefetch_count / 2` batches). pub struct ImportedAsyncDeviceStream { - state: Arc>, + state: Arc, + /// Receiver for messages from callbacks (lock-free). + receiver: mpsc::UnboundedReceiver, /// Cached schema for the RecordBatchStream trait. schema: Option, /// DataType for converting FFI arrays. schema_struct_type: Option, } -// Safety: The handler contains raw pointers but they are only accessed from -// the thread that polls the stream. The Arc> provides synchronization -// for the shared state accessed by callbacks. +// Safety: The state uses atomic operations and mutex-protected producer access. unsafe impl Send for ImportedAsyncDeviceStream {} impl ImportedAsyncDeviceStream { @@ -172,7 +217,8 @@ impl ImportedAsyncDeviceStream { /// but the handler stays valid - callbacks will return error codes to tell the /// producer to stop and release the handler. pub fn new(prefetch_count: u64) -> (Self, *mut FFI_ArrowAsyncDeviceStreamHandler) { - let state = Arc::new(Mutex::new(ImportedStreamState::new(prefetch_count))); + let (sender, receiver) = mpsc::unbounded(); + let state = Arc::new(ImportedStreamState::new(prefetch_count, sender)); // Handler is heap-allocated and will be freed when producer calls release(). // This allows the handler to outlive the stream. @@ -187,6 +233,7 @@ impl ImportedAsyncDeviceStream { let stream = Self { state, + receiver, schema: None, schema_struct_type: None, }; @@ -203,64 +250,64 @@ impl ImportedAsyncDeviceStream { /// Cancel the stream, signaling to the producer to stop sending data. pub fn cancel(&self) { - let state = self.state.lock().unwrap(); - if !state.producer.is_null() { - if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { - unsafe { cancel_fn(state.producer) }; - } - } + self.state.cancel(); } /// Poll for the next message, handling schema initialization. fn poll_next_inner(&mut self, cx: &mut Context<'_>) -> Poll>> { - let mut state = self.state.lock().unwrap(); + // Register waker (lock-free) + self.state.waker.register(cx.waker()); - // Register waker for when new data arrives - state.waker = Some(cx.waker().clone()); - - // Process any pending messages - if let Some(msg) = state.messages.pop() { - match msg { + // Poll the channel for messages first (lock-free) + // This ensures we drain all messages before returning None due to ended flag + match Pin::new(&mut self.receiver).poll_next(cx) { + Poll::Ready(Some(msg)) => match msg { StreamMessage::Schema(schema) => { self.schema_struct_type = Some(DataType::Struct(schema.fields().iter().cloned().collect())); - self.schema = Some(schema.clone()); - state.schema = Some(schema); + self.schema = Some(schema); - // Request initial batch of data - state.maybe_request_more(); + // Request initial batch of data (acquires lock) + self.state.maybe_request_more(); // Continue polling for actual data - drop(state); - return self.poll_next_inner(cx); + self.poll_next_inner(cx) } StreamMessage::Task(mut task) => { - state.pending_requests = state.pending_requests.saturating_sub(1); - state.maybe_request_more(); - drop(state); + // Decrement pending (lock-free) + self.state + .pending_requests + .fetch_sub(1, Ordering::Release); + // Maybe request more (acquires lock only if needed) + self.state.maybe_request_more(); // Extract data from the task let batch = self.extract_batch_from_task(&mut task); - return Poll::Ready(Some(batch)); + Poll::Ready(Some(batch)) } StreamMessage::EndOfStream => { - state.ended = true; - return Poll::Ready(None); + self.state.ended.store(true, Ordering::Release); + Poll::Ready(None) } StreamMessage::Error(e) => { - state.ended = true; - return Poll::Ready(Some(Err(e.into()))); + self.state.ended.store(true, Ordering::Release); + Poll::Ready(Some(Err(e.into()))) + } + }, + Poll::Ready(None) => { + // Channel closed (shouldn't happen in normal operation) + self.state.ended.store(true, Ordering::Release); + Poll::Ready(None) + } + Poll::Pending => { + // No messages in channel - check if stream has ended (lock-free) + if self.state.ended.load(Ordering::Acquire) { + Poll::Ready(None) + } else { + Poll::Pending } } } - - // Check if stream has ended - if state.ended { - return Poll::Ready(None); - } - - // No messages available, wait for callback - Poll::Pending } /// Extract a RecordBatch from an FFI task. @@ -320,17 +367,11 @@ impl RecordBatchStream for ImportedAsyncDeviceStream { impl Drop for ImportedAsyncDeviceStream { fn drop(&mut self) { - let mut state = self.state.lock().unwrap(); + // Mark as abandoned so callbacks know to return errors (lock-free) + self.state.abandoned.store(true, Ordering::Release); - // Mark as abandoned so callbacks know to return errors - state.abandoned = true; - - // Signal cancellation to the producer - if !state.producer.is_null() { - if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { - unsafe { cancel_fn(state.producer) }; - } - } + // Signal cancellation to the producer (acquires lock) + self.state.cancel(); } } @@ -349,38 +390,32 @@ unsafe extern "C" fn handler_on_schema( return 1; } - // Store the producer pointer for later use - let state_ptr = handler.private_data as *const Mutex; + // Get shared state + let state_ptr = handler.private_data as *const ImportedStreamState; let state_arc = Arc::from_raw(state_ptr); - let abandoned = { - let state = state_arc.lock().unwrap(); - state.abandoned - }; - - // If stream was dropped, tell producer to stop - if abandoned { + // If stream was dropped, tell producer to stop (lock-free check) + if state_arc.abandoned.load(Ordering::Acquire) { let _ = Arc::into_raw(state_arc); return 1; // Error code signals producer to stop and release } - let result = { - let mut state = state_arc.lock().unwrap(); - state.producer = handler.producer; + // Store the producer pointer (acquires lock) + state_arc.set_producer(handler.producer); - // Import the schema - let ffi_schema = std::ptr::read(schema); - match Schema::try_from(&ffi_schema) { - Ok(s) => { - state.messages.push(StreamMessage::Schema(Arc::new(s))); - state.wake(); - 0 - } - Err(e) => { - state.messages.push(StreamMessage::Error(e)); - state.wake(); - 1 - } + // Import the schema + let ffi_schema = std::ptr::read(schema); + let result = match Schema::try_from(&ffi_schema) { + Ok(s) => { + // Send through channel (lock-free) + let _ = state_arc.sender.unbounded_send(StreamMessage::Schema(Arc::new(s))); + state_arc.wake(); + 0 + } + Err(e) => { + let _ = state_arc.sender.unbounded_send(StreamMessage::Error(e)); + state_arc.wake(); + 1 } }; @@ -403,33 +438,25 @@ unsafe extern "C" fn handler_on_next_task( return 1; } - let state_ptr = handler.private_data as *const Mutex; + let state_ptr = handler.private_data as *const ImportedStreamState; let state_arc = Arc::from_raw(state_ptr); - let abandoned = { - let state = state_arc.lock().unwrap(); - state.abandoned - }; - - // If stream was dropped, tell producer to stop - if abandoned { + // If stream was dropped, tell producer to stop (lock-free check) + if state_arc.abandoned.load(Ordering::Acquire) { let _ = Arc::into_raw(state_arc); return 1; // Error code signals producer to stop and release } - { - let mut state = state_arc.lock().unwrap(); - - if task.is_null() { - // NULL task signals end of stream - state.messages.push(StreamMessage::EndOfStream); - } else { - // Take ownership of the task by copying it - let task_copy = std::ptr::read(task); - state.messages.push(StreamMessage::Task(task_copy)); - } - state.wake(); + // Send message through channel (lock-free) + if task.is_null() { + // NULL task signals end of stream + let _ = state_arc.sender.unbounded_send(StreamMessage::EndOfStream); + } else { + // Take ownership of the task by copying it + let task_copy = std::ptr::read(task); + let _ = state_arc.sender.unbounded_send(StreamMessage::Task(task_copy)); } + state_arc.wake(); let _ = Arc::into_raw(state_arc); 0 @@ -450,25 +477,22 @@ unsafe extern "C" fn handler_on_error( return; } - let state_ptr = handler.private_data as *const Mutex; + let state_ptr = handler.private_data as *const ImportedStreamState; let state_arc = Arc::from_raw(state_ptr); - { - let mut state = state_arc.lock().unwrap(); + let error_msg = if message.is_null() { + format!("FFI error code {}", code) + } else { + let c_str = CStr::from_ptr(message); + c_str.to_string_lossy().into_owned() + }; - let error_msg = if message.is_null() { - format!("FFI error code {}", code) - } else { - let c_str = CStr::from_ptr(message); - c_str.to_string_lossy().into_owned() - }; - - state - .messages - .push(StreamMessage::Error(ArrowError::CDataInterface(error_msg))); - state.ended = true; - state.wake(); - } + // Send error through channel (lock-free) + let _ = state_arc + .sender + .unbounded_send(StreamMessage::Error(ArrowError::CDataInterface(error_msg))); + state_arc.ended.store(true, Ordering::Release); + state_arc.wake(); let _ = Arc::into_raw(state_arc); } @@ -488,14 +512,11 @@ unsafe extern "C" fn handler_release(self_: *mut FFI_ArrowAsyncDeviceStreamHandl } // Get the state and clear the producer pointer so stream Drop won't use it - let state_ptr = handler.private_data as *const Mutex; + let state_ptr = handler.private_data as *const ImportedStreamState; let state_arc = Arc::from_raw(state_ptr); - { - let mut state = state_arc.lock().unwrap(); - state.producer = null_mut(); - state.ended = true; - } + state_arc.clear_producer(); + state_arc.ended.store(true, Ordering::Release); // Drop the Arc reference (will be freed when stream also drops its reference) drop(state_arc); From 8283e70e2cabb059869c8e411d9ea73e518b9dfd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 14:08:48 -0500 Subject: [PATCH 12/16] safer high level api contract --- .../import_sendable_record_batch_stream.rs | 146 +++++++++++++++--- 1 file changed, 122 insertions(+), 24 deletions(-) diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index fe041f665d..ea3d554fe1 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -29,10 +29,10 @@ //! ```ignore //! use sedona_extension::import_sendable_record_batch_stream::ImportedAsyncDeviceStream; //! -//! let (stream, handler_ptr) = ImportedAsyncDeviceStream::new(16); +//! let (stream, handler) = ImportedAsyncDeviceStream::new(16); //! -//! // Pass handler_ptr to FFI producer... -//! // ffi_producer_start(handler_ptr); +//! // Pass handler to FFI producer... +//! // ffi_producer_start(handler.as_ptr()); //! //! // Consume the stream //! while let Some(batch) = stream.next().await { @@ -99,6 +99,8 @@ struct ImportedStreamState { pending_requests: AtomicU64, /// Set to true when the stream is dropped but producer hasn't called release yet. abandoned: AtomicBool, + /// Set to true when handler_release is called (handler was freed by producer). + handler_released: AtomicBool, /// Producer state (needs mutex for FFI calls). producer_state: Mutex, } @@ -116,6 +118,7 @@ impl ImportedStreamState { ended: AtomicBool::new(false), pending_requests: AtomicU64::new(0), abandoned: AtomicBool::new(false), + handler_released: AtomicBool::new(false), producer_state: Mutex::new(ProducerState { producer: null_mut(), prefetch_count, @@ -200,28 +203,115 @@ pub struct ImportedAsyncDeviceStream { // Safety: The state uses atomic operations and mutex-protected producer access. unsafe impl Send for ImportedAsyncDeviceStream {} +/// RAII wrapper for the FFI handler. +/// +/// This wrapper ensures the handler is properly cleaned up even if the FFI producer +/// never calls `release()`. It provides safe access to the raw pointer for FFI calls. +/// +/// # Usage +/// +/// ```ignore +/// let (stream, handler) = ImportedAsyncDeviceStream::new(16); +/// +/// // Pass raw pointer to FFI producer +/// ffi_producer_start(handler.as_ptr()); +/// +/// // Handler is automatically cleaned up when dropped (if producer didn't release it) +/// ``` +pub struct AsyncDeviceStreamHandler { + /// Raw pointer to the handler (heap-allocated). + ptr: *mut FFI_ArrowAsyncDeviceStreamHandler, + /// Shared state with the stream - needed to check if already released. + state: Arc, +} + +// Safety: The handler pointer is only accessed from one thread at a time. +// The state Arc provides synchronization. +unsafe impl Send for AsyncDeviceStreamHandler {} + +impl AsyncDeviceStreamHandler { + /// Get the raw pointer to pass to FFI code. + /// + /// The pointer remains valid until either: + /// - The FFI producer calls the handler's `release` callback, or + /// - This wrapper is dropped + #[inline] + pub fn as_ptr(&self) -> *mut FFI_ArrowAsyncDeviceStreamHandler { + self.ptr + } +} + +impl Drop for AsyncDeviceStreamHandler { + fn drop(&mut self) { + if self.ptr.is_null() { + return; + } + + // Check via shared state if the handler was already released by the producer. + // We MUST check this before touching self.ptr because handler_release frees it. + if self.state.handler_released.load(Ordering::Acquire) { + // Producer already called release and freed the handler + return; + } + + // Handler was never released by producer - clean it up ourselves. + // This can happen if the producer never connected or crashed. + // + // We need to: + // 1. Mark the stream as ended + // 2. Drop the Arc reference in private_data + // 3. Free the handler struct + + let handler = unsafe { &mut *self.ptr }; + + self.state.ended.store(true, Ordering::Release); + + if !handler.private_data.is_null() { + let state_ptr = handler.private_data as *const ImportedStreamState; + let _ = unsafe { Arc::from_raw(state_ptr) }; + handler.private_data = null_mut(); + } + + // Clear release to prevent FFI_ArrowAsyncDeviceStreamHandler's Drop + // from calling it (which would double-free) + handler.release = None; + + let _ = unsafe { Box::from_raw(self.ptr) }; + } +} + impl ImportedAsyncDeviceStream { /// Create a new imported stream. /// - /// Returns the stream and a pointer to the handler that should be passed to the FFI producer. + /// Returns the stream and an RAII handler wrapper that should be used to pass + /// the handler pointer to the FFI producer. /// /// # Arguments /// /// * `prefetch_count` - Number of batches to request ahead for back-pressure. /// A larger value reduces latency but uses more memory. /// - /// # Safety + /// # Example + /// + /// ```ignore + /// let (stream, handler) = ImportedAsyncDeviceStream::new(16); /// - /// The returned handler pointer remains valid until the producer calls its `release` - /// callback. When the stream is dropped, it signals cancellation to the producer, - /// but the handler stays valid - callbacks will return error codes to tell the - /// producer to stop and release the handler. - pub fn new(prefetch_count: u64) -> (Self, *mut FFI_ArrowAsyncDeviceStreamHandler) { + /// // Pass to FFI producer + /// ffi_producer_start(handler.as_ptr()); + /// + /// // Stream the data + /// while let Some(batch) = stream.next().await { + /// // ... + /// } + /// // Handler is automatically cleaned up when dropped + /// ``` + pub fn new(prefetch_count: u64) -> (Self, AsyncDeviceStreamHandler) { let (sender, receiver) = mpsc::unbounded(); let state = Arc::new(ImportedStreamState::new(prefetch_count, sender)); - // Handler is heap-allocated and will be freed when producer calls release(). - // This allows the handler to outlive the stream. + // Handler is heap-allocated and will be freed when: + // 1. Producer calls release(), or + // 2. AsyncDeviceStreamHandler is dropped (fallback cleanup) let handler_ptr = Box::into_raw(Box::new(FFI_ArrowAsyncDeviceStreamHandler { on_schema: Some(handler_on_schema), on_next_task: Some(handler_on_next_task), @@ -232,13 +322,18 @@ impl ImportedAsyncDeviceStream { })); let stream = Self { - state, + state: Arc::clone(&state), receiver, schema: None, schema_struct_type: None, }; - (stream, handler_ptr) + let handler = AsyncDeviceStreamHandler { + ptr: handler_ptr, + state, + }; + + (stream, handler) } /// Convert this stream into a [`SendableRecordBatchStream`]. @@ -518,6 +613,9 @@ unsafe extern "C" fn handler_release(self_: *mut FFI_ArrowAsyncDeviceStreamHandl state_arc.clear_producer(); state_arc.ended.store(true, Ordering::Release); + // Mark handler as released so AsyncDeviceStreamHandler::drop knows not to free it + state_arc.handler_released.store(true, Ordering::Release); + // Drop the Arc reference (will be freed when stream also drops its reference) drop(state_arc); handler.private_data = null_mut(); @@ -627,7 +725,7 @@ mod tests { let schema = test_schema(); let source_stream = TestStream::new(schema.clone(), vec![]); - let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(4); let consumer_future = async { let mut received = vec![]; @@ -637,7 +735,7 @@ mod tests { (received, consumer.schema()) }; - let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + let producer_future = drive_stream_to_handler(source_stream, handler.as_ptr()); let ((received, result_schema), _) = futures::join!(consumer_future, producer_future); @@ -651,7 +749,7 @@ mod tests { let batch = make_batch(&schema, 1, 5); let source_stream = TestStream::new(schema.clone(), vec![batch.clone()]); - let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(4); let consumer_future = async { let mut received = vec![]; @@ -661,7 +759,7 @@ mod tests { received }; - let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + let producer_future = drive_stream_to_handler(source_stream, handler.as_ptr()); let (received, _) = futures::join!(consumer_future, producer_future); @@ -693,7 +791,7 @@ mod tests { let source_stream = TestStream::new(schema.clone(), batches); - let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(prefetch_count); + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(prefetch_count); let consumer_future = async { let mut received = vec![]; @@ -703,7 +801,7 @@ mod tests { received }; - let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + let producer_future = drive_stream_to_handler(source_stream, handler.as_ptr()); let (received, _) = futures::join!(consumer_future, producer_future); @@ -720,7 +818,7 @@ mod tests { // Error after first batch let source_stream = TestStream::with_error(schema.clone(), batches, 1); - let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(4); + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(4); let consumer_future = async { let mut received = vec![]; @@ -730,7 +828,7 @@ mod tests { received }; - let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + let producer_future = drive_stream_to_handler(source_stream, handler.as_ptr()); let (received, _) = futures::join!(consumer_future, producer_future); @@ -752,7 +850,7 @@ mod tests { let source_stream = TestStream::new(schema.clone(), batches); - let (mut consumer, handler_ptr) = ImportedAsyncDeviceStream::new(2); + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(2); let consumer_future = async { let mut count = 0; @@ -768,7 +866,7 @@ mod tests { count }; - let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + let producer_future = drive_stream_to_handler(source_stream, handler.as_ptr()); let (count, _) = futures::join!(consumer_future, producer_future); From 8cc758de2e85e97787c4bdfbf02eb2fdc591a526 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 14:11:45 -0500 Subject: [PATCH 13/16] explain lock unwraps --- .../src/export_sendable_record_batch_stream.rs | 6 ++++-- .../src/import_sendable_record_batch_stream.rs | 12 ++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/c/sedona-extension/src/export_sendable_record_batch_stream.rs b/c/sedona-extension/src/export_sendable_record_batch_stream.rs index b0cf7678c4..822f7b9c3e 100644 --- a/c/sedona-extension/src/export_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/export_sendable_record_batch_stream.rs @@ -113,11 +113,13 @@ impl ProducerState { } fn register_waker(&self, waker: Waker) { - *self.waker.lock().unwrap() = Some(waker); + // Lock cannot be poisoned: we never panic while holding it + *self.waker.lock().expect("waker mutex poisoned") = Some(waker); } fn wake(&self) { - if let Some(waker) = self.waker.lock().unwrap().take() { + // Lock cannot be poisoned: we never panic while holding it + if let Some(waker) = self.waker.lock().expect("waker mutex poisoned").take() { waker.wake(); } } diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index ea3d554fe1..8cd156d264 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -136,7 +136,8 @@ impl ImportedStreamState { return; } - let producer_state = self.producer_state.lock().unwrap(); + // Lock cannot be poisoned: we never panic while holding it + let producer_state = self.producer_state.lock().expect("producer_state mutex poisoned"); if producer_state.producer.is_null() { return; } @@ -155,17 +156,20 @@ impl ImportedStreamState { } fn set_producer(&self, producer: *mut FFI_ArrowAsyncProducer) { - let mut state = self.producer_state.lock().unwrap(); + // Lock cannot be poisoned: we never panic while holding it + let mut state = self.producer_state.lock().expect("producer_state mutex poisoned"); state.producer = producer; } fn clear_producer(&self) { - let mut state = self.producer_state.lock().unwrap(); + // Lock cannot be poisoned: we never panic while holding it + let mut state = self.producer_state.lock().expect("producer_state mutex poisoned"); state.producer = null_mut(); } fn cancel(&self) { - let state = self.producer_state.lock().unwrap(); + // Lock cannot be poisoned: we never panic while holding it + let state = self.producer_state.lock().expect("producer_state mutex poisoned"); if !state.producer.is_null() { if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { unsafe { cancel_fn(state.producer) }; From 083293378c907545235bf6cf17fdcecc0856ede8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 14:39:10 -0500 Subject: [PATCH 14/16] test some drop behaviour --- .../import_sendable_record_batch_stream.rs | 190 ++++++++++++++---- 1 file changed, 148 insertions(+), 42 deletions(-) diff --git a/c/sedona-extension/src/import_sendable_record_batch_stream.rs b/c/sedona-extension/src/import_sendable_record_batch_stream.rs index 8cd156d264..04104f69ba 100644 --- a/c/sedona-extension/src/import_sendable_record_batch_stream.rs +++ b/c/sedona-extension/src/import_sendable_record_batch_stream.rs @@ -101,6 +101,8 @@ struct ImportedStreamState { abandoned: AtomicBool, /// Set to true when handler_release is called (handler was freed by producer). handler_released: AtomicBool, + /// Set to true when a producer connects (calls on_schema). + producer_connected: AtomicBool, /// Producer state (needs mutex for FFI calls). producer_state: Mutex, } @@ -119,6 +121,7 @@ impl ImportedStreamState { pending_requests: AtomicU64::new(0), abandoned: AtomicBool::new(false), handler_released: AtomicBool::new(false), + producer_connected: AtomicBool::new(false), producer_state: Mutex::new(ProducerState { producer: null_mut(), prefetch_count, @@ -137,7 +140,10 @@ impl ImportedStreamState { } // Lock cannot be poisoned: we never panic while holding it - let producer_state = self.producer_state.lock().expect("producer_state mutex poisoned"); + let producer_state = self + .producer_state + .lock() + .expect("producer_state mutex poisoned"); if producer_state.producer.is_null() { return; } @@ -150,26 +156,37 @@ impl ImportedStreamState { let to_request = prefetch - pending; if let Some(request_fn) = unsafe { (*producer_state.producer).request } { unsafe { request_fn(producer_state.producer, to_request) }; - self.pending_requests.fetch_add(to_request, Ordering::Release); + self.pending_requests + .fetch_add(to_request, Ordering::Release); } } } fn set_producer(&self, producer: *mut FFI_ArrowAsyncProducer) { // Lock cannot be poisoned: we never panic while holding it - let mut state = self.producer_state.lock().expect("producer_state mutex poisoned"); + let mut state = self + .producer_state + .lock() + .expect("producer_state mutex poisoned"); state.producer = producer; + self.producer_connected.store(true, Ordering::Release); } fn clear_producer(&self) { // Lock cannot be poisoned: we never panic while holding it - let mut state = self.producer_state.lock().expect("producer_state mutex poisoned"); + let mut state = self + .producer_state + .lock() + .expect("producer_state mutex poisoned"); state.producer = null_mut(); } fn cancel(&self) { // Lock cannot be poisoned: we never panic while holding it - let state = self.producer_state.lock().expect("producer_state mutex poisoned"); + let state = self + .producer_state + .lock() + .expect("producer_state mutex poisoned"); if !state.producer.is_null() { if let Some(cancel_fn) = unsafe { (*state.producer).cancel } { unsafe { cancel_fn(state.producer) }; @@ -211,17 +228,6 @@ unsafe impl Send for ImportedAsyncDeviceStream {} /// /// This wrapper ensures the handler is properly cleaned up even if the FFI producer /// never calls `release()`. It provides safe access to the raw pointer for FFI calls. -/// -/// # Usage -/// -/// ```ignore -/// let (stream, handler) = ImportedAsyncDeviceStream::new(16); -/// -/// // Pass raw pointer to FFI producer -/// ffi_producer_start(handler.as_ptr()); -/// -/// // Handler is automatically cleaned up when dropped (if producer didn't release it) -/// ``` pub struct AsyncDeviceStreamHandler { /// Raw pointer to the handler (heap-allocated). ptr: *mut FFI_ArrowAsyncDeviceStreamHandler, @@ -258,8 +264,14 @@ impl Drop for AsyncDeviceStreamHandler { return; } - // Handler was never released by producer - clean it up ourselves. - // This can happen if the producer never connected or crashed. + // If a producer has connected, it will call release when done. + // We must NOT free the handler or we'll cause a use-after-free. + if self.state.producer_connected.load(Ordering::Acquire) { + return; + } + + // No producer ever connected - clean it up ourselves. + // This can happen if the handler was never passed to FFI code. // // We need to: // 1. Mark the stream as ended @@ -294,21 +306,6 @@ impl ImportedAsyncDeviceStream { /// /// * `prefetch_count` - Number of batches to request ahead for back-pressure. /// A larger value reduces latency but uses more memory. - /// - /// # Example - /// - /// ```ignore - /// let (stream, handler) = ImportedAsyncDeviceStream::new(16); - /// - /// // Pass to FFI producer - /// ffi_producer_start(handler.as_ptr()); - /// - /// // Stream the data - /// while let Some(batch) = stream.next().await { - /// // ... - /// } - /// // Handler is automatically cleaned up when dropped - /// ``` pub fn new(prefetch_count: u64) -> (Self, AsyncDeviceStreamHandler) { let (sender, receiver) = mpsc::unbounded(); let state = Arc::new(ImportedStreamState::new(prefetch_count, sender)); @@ -374,9 +371,7 @@ impl ImportedAsyncDeviceStream { } StreamMessage::Task(mut task) => { // Decrement pending (lock-free) - self.state - .pending_requests - .fetch_sub(1, Ordering::Release); + self.state.pending_requests.fetch_sub(1, Ordering::Release); // Maybe request more (acquires lock only if needed) self.state.maybe_request_more(); @@ -507,7 +502,9 @@ unsafe extern "C" fn handler_on_schema( let result = match Schema::try_from(&ffi_schema) { Ok(s) => { // Send through channel (lock-free) - let _ = state_arc.sender.unbounded_send(StreamMessage::Schema(Arc::new(s))); + let _ = state_arc + .sender + .unbounded_send(StreamMessage::Schema(Arc::new(s))); state_arc.wake(); 0 } @@ -553,7 +550,9 @@ unsafe extern "C" fn handler_on_next_task( } else { // Take ownership of the task by copying it let task_copy = std::ptr::read(task); - let _ = state_arc.sender.unbounded_send(StreamMessage::Task(task_copy)); + let _ = state_arc + .sender + .unbounded_send(StreamMessage::Task(task_copy)); } state_arc.wake(); @@ -709,10 +708,7 @@ mod tests { impl Stream for TestStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let this = self.as_mut().get_mut(); Poll::Ready(this.batches.pop_front()) } @@ -877,4 +873,114 @@ mod tests { // Should have received at least 3 batches before cancellation assert!(count >= 3); } + + #[tokio::test] + async fn test_drop_stream_while_producing() { + // Test that dropping the consumer stream while the producer is still producing + // doesn't cause crashes or hangs - the producer should stop gracefully. + let schema = test_schema(); + let batches: Vec = (0..100).map(|i| make_batch(&schema, i * 3, 3)).collect(); + let source_stream = TestStream::new(schema.clone(), batches); + + let (consumer, handler) = ImportedAsyncDeviceStream::new(2); + let handler_ptr = handler.as_ptr(); + + // Wrap consumer in Option so we can drop it mid-stream + let consumer = std::sync::Arc::new(tokio::sync::Mutex::new(Some(consumer))); + let consumer_clone = consumer.clone(); + + let consumer_future = async move { + let mut count = 0; + loop { + let mut guard = consumer_clone.lock().await; + let stream = guard.as_mut().unwrap(); + match stream.next().await { + Some(result) => { + result.expect("should not error before drop"); + count += 1; + if count >= 2 { + // Drop the stream after receiving 2 batches + drop(guard.take()); + break; + } + } + None => break, + } + } + count + }; + + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + // Both futures should complete without panic or hang + let (count, _) = futures::join!(consumer_future, producer_future); + assert_eq!(count, 2); + } + + #[tokio::test] + async fn test_drop_handler_while_consuming() { + // Test that dropping the handler wrapper while the consumer is still reading + // doesn't cause crashes. The handler should be properly freed by the producer's + // release callback, and the wrapper's Drop should be a no-op. + let schema = test_schema(); + let batches: Vec = (0..5).map(|i| make_batch(&schema, i * 3, 3)).collect(); + let source_stream = TestStream::new(schema.clone(), batches); + + let (mut consumer, handler) = ImportedAsyncDeviceStream::new(4); + let handler_ptr = handler.as_ptr(); + + // Start producer + let producer_future = drive_stream_to_handler(source_stream, handler_ptr); + + // Consume all batches, but drop the handler wrapper early + let consumer_future = async { + let mut received = vec![]; + // Read one batch first + if let Some(result) = consumer.next().await { + received.push(result); + } + + // Drop the handler wrapper while stream is still active + // This should NOT free the handler since producer hasn't called release yet + drop(handler); + + // Continue consuming - should work fine + while let Some(result) = consumer.next().await { + received.push(result); + } + received + }; + + let (received, _) = futures::join!(consumer_future, producer_future); + + // All 5 batches should be received + assert_eq!(received.len(), 5); + assert!(received.iter().all(|r| r.is_ok())); + } + + #[tokio::test] + async fn test_handler_cleanup_when_producer_never_connects() { + // Test RAII cleanup when the handler is dropped but the producer never called + // any callbacks (never connected). The handler wrapper should clean up properly. + let (_consumer, handler) = ImportedAsyncDeviceStream::new(4); + + // Just drop the handler without ever passing it to a producer + // This should NOT panic or leak memory + drop(handler); + + // Stream should still be usable (though it will never receive data) + // The ended flag should be set by handler drop + } + + #[tokio::test] + async fn test_stream_and_handler_both_dropped_before_producer() { + // Test cleanup when both stream and handler are dropped before any producer connects + let (consumer, handler) = ImportedAsyncDeviceStream::new(4); + + // Drop both without ever starting a producer + drop(consumer); + drop(handler); + + // Should not panic or leak + } } From a8089cec41458c417f92a3f309347c28f0390382 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 5 Jun 2026 14:40:34 -0500 Subject: [PATCH 15/16] deps and ffi --- c/sedona-extension/Cargo.toml | 3 +-- c/sedona-extension/src/extension_ffi.rs | 32 ------------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/c/sedona-extension/Cargo.toml b/c/sedona-extension/Cargo.toml index 5125e895b5..313de3b6e2 100644 --- a/c/sedona-extension/Cargo.toml +++ b/c/sedona-extension/Cargo.toml @@ -29,7 +29,7 @@ rust-version.workspace = true [features] default = [] -async = ["dep:datafusion-execution", "dep:futures", "dep:tokio"] +async = ["dep:datafusion-execution", "dep:futures"] [dependencies] arrow-array = { workspace = true, features = ["ffi"]} @@ -43,7 +43,6 @@ sedona-common = { workspace = true } sedona-expr = { workspace = true } sedona-schema = { workspace = true } sedona-testing = { path = "../../rust/sedona-testing" } -tokio = { workspace = true, optional = true } [dev-dependencies] tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/c/sedona-extension/src/extension_ffi.rs b/c/sedona-extension/src/extension_ffi.rs index c381ac34f8..92f4709399 100644 --- a/c/sedona-extension/src/extension_ffi.rs +++ b/c/sedona-extension/src/extension_ffi.rs @@ -173,38 +173,6 @@ impl From for FFI_ArrowDeviceArray { } } -/// FFI representation of the ArrowDeviceArrayStream from the Arrow C Device Data Interface -/// -/// See: https://arrow.apache.org/docs/format/CDeviceDataInterface.html -#[repr(C)] -pub struct FFI_ArrowDeviceArrayStream { - pub device_type: i32, - pub get_schema: Option< - unsafe extern "C" fn( - self_: *mut FFI_ArrowDeviceArrayStream, - out: *mut FFI_ArrowSchema, - ) -> c_int, - >, - pub get_next: Option< - unsafe extern "C" fn( - self_: *mut FFI_ArrowDeviceArrayStream, - out: *mut FFI_ArrowDeviceArray, - ) -> c_int, - >, - pub get_last_error: - Option *const c_char>, - pub release: Option, - pub private_data: *mut c_void, -} - -impl Drop for FFI_ArrowDeviceArrayStream { - fn drop(&mut self) { - if let Some(releaser) = self.release { - unsafe { releaser(self) }; - } - } -} - /// FFI representation of the ArrowAsyncProducer from the Arrow C Device Data Interface. /// /// This producer-managed object allows consumers to control flow via back-pressure From ddf495ac97f36f0095b21a61f1cb20da7b6b70d9 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 22 Jun 2026 16:24:48 -0500 Subject: [PATCH 16/16] fmt --- c/sedona-extension/src/scalar_kernel.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/c/sedona-extension/src/scalar_kernel.rs b/c/sedona-extension/src/scalar_kernel.rs index ae06243eaf..6c8620400a 100644 --- a/c/sedona-extension/src/scalar_kernel.rs +++ b/c/sedona-extension/src/scalar_kernel.rs @@ -34,7 +34,9 @@ use std::{ str::FromStr, }; -use crate::extension_ffi::{ffi_arrow_schema_is_valid, SedonaCScalarKernel, SedonaCScalarKernelImpl}; +use crate::extension_ffi::{ + ffi_arrow_schema_is_valid, SedonaCScalarKernel, SedonaCScalarKernelImpl, +}; /// Wrapper around a [SedonaCScalarKernel] that implements [SedonaScalarKernel] ///