24 changes: 24 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -745,6 +745,30 @@ def iter_f(it):
actual = df.select(g(f(col("id"))).alias("struct")).collect()
self.assertEqual(expected, actual)

+    def test_scalar_iter_udf_struct_input_and_output(self):
+        return_type = StructType([StructField("id", LongType()), StructField("str", StringType())])
+
+        @pandas_udf(return_type, PandasUDFType.SCALAR_ITER)
+        def iter_struct(it):
+            for s in it:
+                if not isinstance(s, pd.DataFrame):
+                    raise TypeError(type(s).__name__)
+                yield pd.DataFrame({"id": s["id"] + 1, "str": s["str"].str.upper()})
+
+        df = self.spark.range(3).select(
+            struct(col("id"), col("id").cast("string").alias("str")).alias("s")
+        )
+        expected = [
+            Row(out=Row(id=1, str="0")),
+            Row(out=Row(id=2, str="1")),
+            Row(out=Row(id=3, str="2")),
+        ]
+
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1}):
+            actual = df.select(iter_struct("s").alias("out")).collect()
+
+        self.assertEqual(expected, actual)
+
def test_vectorized_udf_wrong_return_type(self):
with self.quiet():
self.check_vectorized_udf_wrong_return_type()
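
For context on what the new test pins down: after this change, a SCALAR_ITER pandas UDF receives a struct argument as a pandas.DataFrame (one column per field) and may yield pandas.DataFrames for a struct return type. Below is a minimal standalone sketch of the same behavior using the type-hint API instead of the legacy PandasUDFType argument; it assumes a local SparkSession, the function name bump is illustrative, and the expected output in the comment assumes the patched behavior.

from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, struct

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Iterator[pd.DataFrame] -> Iterator[pd.DataFrame] hints select the SCALAR_ITER
# evaluation type; a struct return type means the UDF yields DataFrames.
@pandas_udf("id long, str string")
def bump(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for s in batches:
        # The struct argument arrives as a pandas DataFrame, not a Series of dicts.
        yield pd.DataFrame({"id": s["id"] + 1, "str": s["str"].str.upper()})

df = spark.range(3).select(
    struct(col("id"), col("id").cast("string").alias("str")).alias("s")
)
df.select(bump("s").alias("out")).show()
# +------+
# |   out|
# +------+
# |{1, 0}|
# |{2, 1}|
# |{3, 2}|
# +------+
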
155 changes: 107 additions & 48 deletions python/pyspark/worker.py
@@ -1021,7 +1021,7 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
return func, args_offsets, kwargs_offsets, return_type
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type
elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF:
return func, args_offsets, kwargs_offsets, return_type
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
@@ -2344,15 +2344,13 @@ def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
):
ser = ArrowStreamSerializer(write_start_stream=True)
else:
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
-        df_for_struct = (
-            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-            or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
-        )
+        df_for_struct = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

ser = ArrowStreamPandasUDFSerializer(
timezone=runner_conf.timezone,
@@ -3222,26 +3220,118 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
# profiling is not supported for UDF
return func, None, ser, ser

-    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+        import pandas as pd
+        import pyarrow as pa

-    if is_scalar_iter or is_map_pandas_iter:
# TODO: Better error message for num_udfs != 1
-        if is_scalar_iter:
-            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
-        if is_map_pandas_iter:
-            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
+        assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
+        udf_func, args_offsets, _, return_type = udfs[0]
+        output_schema = StructType([StructField("_0", return_type)])
+        expected_type = pd.DataFrame if isinstance(return_type, StructType) else pd.Series
+        expected_label = (
+            "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
+        )

-        arg_offsets, udf = udfs[0]
+        def verify_result(result):
+            if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": f"iterator of {expected_label}",
+                        "actual": type(result).__name__,
+                    },
+                )
+            return result

-        def func(_, iterator):  # type: ignore[misc]
+        def verify_element(element):
+            if not isinstance(element, expected_type):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": f"iterator of {expected_label}",
+                        "actual": f"iterator of {type(element).__name__}",
+                    },
+                )
+
+            verify_pandas_result(
+                element,
+                return_type,
+                assign_cols_by_name=True,
+                truncate_return_schema=True,
+            )
+            return element
+
+        def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
num_input_rows = 0

-            def map_batch(batch):
+            def extract_args(batch: pa.RecordBatch):
nonlocal num_input_rows
+                pandas_columns = ArrowBatchTransformer.to_pandas(
+                    batch,
+                    timezone=runner_conf.timezone,
+                    struct_in_pandas="dict",
+                    ndarray_as_list=False,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                    df_for_struct=True,
+                )
+                args = tuple(pandas_columns[o] for o in args_offsets)
+                num_input_rows += batch.num_rows
+                return args[0] if len(args) == 1 else args
+
+            args_iter = map(extract_args, data)
+            result_iter = verify_result(udf_func(args_iter))

+            num_output_rows = 0
+            for result in map(verify_element, result_iter):
+                num_output_rows += len(result)
+                # Fail fast if the scalar iterator UDF yields more rows than it has consumed.
+                if num_output_rows > num_input_rows:
+                    raise PySparkRuntimeError(
+                        errorClass="OUTPUT_EXCEEDS_INPUT_ROWS",
+                        messageParameters={},
+                    )
+                yield PandasToArrowConversion.convert(
+                    [result],
+                    output_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )

+            try:
+                next(args_iter)
+            except StopIteration:
+                pass
+            else:
+                raise PySparkRuntimeError(
+                    errorClass="INPUT_NOT_FULLY_CONSUMED",
+                    messageParameters={},
+                )
+
+            if num_output_rows != num_input_rows:
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_ROWS_MISMATCH",
+                    messageParameters={
+                        "output_length": str(num_output_rows),
+                        "input_length": str(num_input_rows),
+                    },
+                )

+        # profiling is not supported for UDF
+        return func, None, ser, ser

+    if eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
+        # TODO: Better error message for num_udfs != 1
+        assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."

+        arg_offsets, udf = udfs[0]

+        def func(_, iterator):  # type: ignore[misc]
+            def map_batch(batch):
udf_args = [batch[offset] for offset in arg_offsets]
-                num_input_rows += len(udf_args[0])
if len(udf_args) == 1:
return udf_args[0]
else:
@@ -3250,40 +3340,9 @@ def map_batch(batch):
iterator = map(map_batch, iterator)
result_iter = udf(iterator)

-            num_output_rows = 0
-            for result_batch, result_type in result_iter:
-                num_output_rows += len(result_batch)
-                # This check is for Scalar Iterator UDF to fail fast.
-                # The length of the entire input can only be explicitly known
-                # by consuming the input iterator in user side. Therefore,
-                # it's very unlikely the output length is higher than
-                # input length.
-                if is_scalar_iter and num_output_rows > num_input_rows:
-                    raise PySparkRuntimeError(
-                        errorClass="OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={}
-                    )
-                yield (result_batch, result_type)
-
-            if is_scalar_iter:
-                try:
-                    next(iterator)
-                except StopIteration:
-                    pass
-                else:
-                    raise PySparkRuntimeError(
-                        errorClass="INPUT_NOT_FULLY_CONSUMED",
-                        messageParameters={},
-                    )
-
-                if num_output_rows != num_input_rows:
-                    raise PySparkRuntimeError(
-                        errorClass="RESULT_ROWS_MISMATCH",
-                        messageParameters={
-                            "output_length": str(num_output_rows),
-                            "input_length": str(num_input_rows),
-                        },
-                    )

# profiling is not supported for UDF
return func, None, ser, ser
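
Taken together, the new scalar-iterator branch enforces three invariants: streamed output may never get ahead of the input the UDF has consumed so far, the input iterator must be fully drained, and the total output row count must equal the total input row count. A framework-free sketch of that protocol follows; the names run_scalar_iter and extract are illustrative and the batches are plain lists, not Arrow record batches or the worker's actual API.

from typing import Callable, Iterator, List

def run_scalar_iter(
    udf: Callable[[Iterator[list]], Iterator[list]], batches: List[list]
) -> Iterator[list]:
    # Illustrative model of the worker's scalar-iter checks, not Spark code.
    num_input_rows = 0

    def extract(batch: list) -> list:
        nonlocal num_input_rows
        num_input_rows += len(batch)  # counted lazily, as the UDF pulls input
        return batch

    args_iter = map(extract, iter(batches))

    num_output_rows = 0
    for result in udf(args_iter):
        num_output_rows += len(result)
        # Fail fast: output cannot exceed what the UDF has consumed so far.
        if num_output_rows > num_input_rows:
            raise RuntimeError("OUTPUT_EXCEEDS_INPUT_ROWS")
        yield result

    # The UDF must drain its input iterator ...
    sentinel = object()
    if next(args_iter, sentinel) is not sentinel:
        raise RuntimeError("INPUT_NOT_FULLY_CONSUMED")
    # ... and emit exactly one output row per input row overall.
    if num_output_rows != num_input_rows:
        raise RuntimeError(f"RESULT_ROWS_MISMATCH: {num_output_rows} != {num_input_rows}")

# An identity UDF satisfies the protocol; one that drops or invents rows would
# trip INPUT_NOT_FULLY_CONSUMED, RESULT_ROWS_MISMATCH, or the fail-fast check.
assert list(run_scalar_iter(lambda it: it, [[1, 2], [3]])) == [[1, 2], [3]]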
