diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 72d9fa566deef..15c5229a81980 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -745,6 +745,30 @@ def iter_f(it):
         actual = df.select(g(f(col("id"))).alias("struct")).collect()
         self.assertEqual(expected, actual)
 
+    def test_scalar_iter_udf_struct_input_and_output(self):
+        return_type = StructType([StructField("id", LongType()), StructField("str", StringType())])
+
+        @pandas_udf(return_type, PandasUDFType.SCALAR_ITER)
+        def iter_struct(it):
+            for s in it:
+                if not isinstance(s, pd.DataFrame):
+                    raise TypeError(type(s).__name__)
+                yield pd.DataFrame({"id": s["id"] + 1, "str": s["str"].str.upper()})
+
+        df = self.spark.range(3).select(
+            struct(col("id"), col("id").cast("string").alias("str")).alias("s")
+        )
+        expected = [
+            Row(out=Row(id=1, str="0")),
+            Row(out=Row(id=2, str="1")),
+            Row(out=Row(id=3, str="2")),
+        ]
+
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1}):
+            actual = df.select(iter_struct("s").alias("out")).collect()
+
+        self.assertEqual(expected, actual)
+
     def test_vectorized_udf_wrong_return_type(self):
         with self.quiet():
             self.check_vectorized_udf_wrong_return_type()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3b96f02e04b7b..a2e3aa932b2f6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1021,7 +1021,7 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
         return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF:
         return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
@@ -2344,15 +2344,13 @@ def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
     ):
         ser = ArrowStreamSerializer(write_start_stream=True)
     else:
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
-        df_for_struct = (
-            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-            or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
-        )
+        df_for_struct = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
 
         ser = ArrowStreamPandasUDFSerializer(
             timezone=runner_conf.timezone,
@@ -3222,26 +3220,118 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
         # profiling is not supported for UDF
         return func, None, ser, ser
 
-    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+        import pandas as pd
+        import pyarrow as pa
 
-    if is_scalar_iter or is_map_pandas_iter:
-        # TODO: Better error message for num_udfs != 1
-        if is_scalar_iter:
-            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
-        if is_map_pandas_iter:
-            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
+        assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
+        udf_func, args_offsets, _, return_type = udfs[0]
+        output_schema = StructType([StructField("_0", return_type)])
+        expected_type = pd.DataFrame if isinstance(return_type, StructType) else pd.Series
+        expected_label = (
+            "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
+        )
 
-        arg_offsets, udf = udfs[0]
+        def verify_result(result):
+            if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": f"iterator of {expected_label}",
+                        "actual": type(result).__name__,
+                    },
+                )
+            return result
 
-        def func(_, iterator):  # type: ignore[misc]
+        def verify_element(element):
+            if not isinstance(element, expected_type):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": f"iterator of {expected_label}",
+                        "actual": f"iterator of {type(element).__name__}",
+                    },
+                )
+
+            verify_pandas_result(
+                element,
+                return_type,
+                assign_cols_by_name=True,
+                truncate_return_schema=True,
+            )
+            return element
+
+        def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
             num_input_rows = 0
 
-            def map_batch(batch):
+            def extract_args(batch: pa.RecordBatch):
                 nonlocal num_input_rows
+                pandas_columns = ArrowBatchTransformer.to_pandas(
+                    batch,
+                    timezone=runner_conf.timezone,
+                    struct_in_pandas="dict",
+                    ndarray_as_list=False,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                    df_for_struct=True,
+                )
+                args = tuple(pandas_columns[o] for o in args_offsets)
+                num_input_rows += batch.num_rows
+                return args[0] if len(args) == 1 else args
+
+            args_iter = map(extract_args, data)
+            result_iter = verify_result(udf_func(args_iter))
+
+            num_output_rows = 0
+            for result in map(verify_element, result_iter):
+                num_output_rows += len(result)
+                # Fail fast if the scalar iterator UDF yields more rows than it has consumed.
+                if num_output_rows > num_input_rows:
+                    raise PySparkRuntimeError(
+                        errorClass="OUTPUT_EXCEEDS_INPUT_ROWS",
+                        messageParameters={},
+                    )
+                yield PandasToArrowConversion.convert(
+                    [result],
+                    output_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )
+
+            try:
+                next(args_iter)
+            except StopIteration:
+                pass
+            else:
+                raise PySparkRuntimeError(
+                    errorClass="INPUT_NOT_FULLY_CONSUMED",
+                    messageParameters={},
+                )
+
+            if num_output_rows != num_input_rows:
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_ROWS_MISMATCH",
+                    messageParameters={
+                        "output_length": str(num_output_rows),
+                        "input_length": str(num_input_rows),
+                    },
+                )
+
+        # profiling is not supported for UDF
+        return func, None, ser, ser
+
+    if eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
+        # TODO: Better error message for num_udfs != 1
+        assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
+
+        arg_offsets, udf = udfs[0]
+
+        def func(_, iterator):  # type: ignore[misc]
+            def map_batch(batch):
                 udf_args = [batch[offset] for offset in arg_offsets]
-                num_input_rows += len(udf_args[0])
                 if len(udf_args) == 1:
                     return udf_args[0]
                 else:
@@ -3250,40 +3340,9 @@ def map_batch(batch):
             iterator = map(map_batch, iterator)
             result_iter = udf(iterator)
 
-            num_output_rows = 0
             for result_batch, result_type in result_iter:
-                num_output_rows += len(result_batch)
-                # This check is for Scalar Iterator UDF to fail fast.
-                # The length of the entire input can only be explicitly known
-                # by consuming the input iterator in user side. Therefore,
-                # it's very unlikely the output length is higher than
-                # input length.
-                if is_scalar_iter and num_output_rows > num_input_rows:
-                    raise PySparkRuntimeError(
-                        errorClass="OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={}
-                    )
                 yield (result_batch, result_type)
 
-            if is_scalar_iter:
-                try:
-                    next(iterator)
-                except StopIteration:
-                    pass
-                else:
-                    raise PySparkRuntimeError(
-                        errorClass="INPUT_NOT_FULLY_CONSUMED",
-                        messageParameters={},
-                    )
-
-                if num_output_rows != num_input_rows:
-                    raise PySparkRuntimeError(
-                        errorClass="RESULT_ROWS_MISMATCH",
-                        messageParameters={
-                            "output_length": str(num_output_rows),
-                            "input_length": str(num_input_rows),
-                        },
-                    )
-
         # profiling is not supported for UDF
         return func, None, ser, ser