Skip to content
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package org.apache.spark.sql.connect.service

import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.util.concurrent.{TimeoutException, TimeUnit}

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.sys.process.Process
import scala.util.Random
import scala.util.control.NonFatal

import com.google.common.collect.Lists
import org.scalatest.time.SpanSugar._
Expand All @@ -37,8 +39,10 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper}
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -228,15 +232,117 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
}
}

test("python foreachBatch process: process terminates after query is stopped") {
// scalastyle:off assume
assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
assume(PythonTestDepsChecker.isConnectDepsAvailable)
// scalastyle:on assume
// Best-effort wrapper for cleanup code: failures here are logged to stdout and
// swallowed so they cannot mask a primary test failure. InterruptedException
// restores the current thread's interrupt status before being swallowed; fatal
// errors (OOM, StackOverflow, LinkageError) still propagate.
private def runQuietly(label: String, op: => Unit): Unit = {
  try {
    op
  } catch {
    case _: InterruptedException =>
      // Preserve the interrupt signal for callers further up the stack.
      Thread.currentThread().interrupt()
    case NonFatal(suppressed) =>
      // scalastyle:off println
      println(
        s"===== $label suppressed ${suppressed.getClass.getSimpleName}: " +
          s"${suppressed.getMessage} =====")
      // scalastyle:on println
  }
}

// Equivalent of SparkFunSuite.retry, except retry notices go to stdout so they are
// visible in the live GitHub Actions job log (SparkFunSuite.retry only logs through
// log4j into target/unit-tests.log, which surfaces as a post-run artifact).
private def retryWithVisibleLog(maxAttempts: Int)(body: => Unit): Unit = {
  var attempt = 1
  var succeeded = false
  while (!succeeded) {
    try {
      body
      succeeded = true
    } catch {
      case NonFatal(failure) =>
        // Out of attempts: surface the last failure unchanged.
        if (attempt >= maxAttempts) throw failure
        // scalastyle:off println
        println(
          s"===== Attempt $attempt/$maxAttempts failed " +
            s"(${failure.getClass.getSimpleName}: ${failure.getMessage}); retrying =====")
        // scalastyle:on println
        // A leaked worker from this attempt may still hold sockets/listeners; keep
        // afterEach/beforeEach throwing on that residual state from aborting the
        // retry loop.
        runQuietly("afterEach", afterEach())
        runQuietly("beforeEach", beforeEach())
        attempt += 1
    }
  }
}

// Runs `body` on a fresh daemon thread and waits up to `timeoutMillis` for it to
// finish. On timeout: prints the worker's stack, invokes `onTimeout` to release
// whatever the worker is blocked on, interrupts the worker, waits out a grace
// period so its own `finally` blocks can run, then throws TimeoutException. Any
// throwable raised by `body` within the deadline is rethrown on the calling thread.
private def awaitTestBodyInNewThread(timeoutMillis: Long, onTimeout: () => Unit)(
    body: => Unit): Unit = {
  @volatile var bodyFailure: Throwable = null
  val worker = new Thread(
    () =>
      try {
        body
      } catch {
        case t: Throwable => bodyFailure = t
      },
    s"${getClass.getSimpleName}-testBody-worker")
  worker.setDaemon(true)
  worker.start()
  worker.join(timeoutMillis)

  if (worker.isAlive) {
    // Capture the worker's stack so post-mortem diagnostics can tie the leaked
    // thread to this attempt without a separate jstack run.
    // scalastyle:off println
    println(
      s"===== Test body did not complete within $timeoutMillis ms " +
        s"(thread=${worker.getName}, state=${worker.getState}); stack trace follows =====")
    worker.getStackTrace.foreach(frame => println(s" at $frame"))
    // scalastyle:on println
    // Best-effort: release whatever resource the worker is blocked on so it can
    // unwind its own finally and stop holding global state (SparkConnectService,
    // listeners, ...).
    onTimeout()
    // Wake any interruptible blocking call (e.g. the Thread.join inside
    // StreamExecution.interruptAndAwaitExecutionThreadTermination).
    worker.interrupt()
    // Grace period for the now-unblocked worker to run its own finally
    // (SparkConnectService.stop() then the ~4s settle sleep).
    val gracePeriodMs = 30.seconds.toMillis
    worker.join(gracePeriodMs)
    val timeout = new TimeoutException(
      s"Test body did not complete within $timeoutMillis ms " +
        s"(after a $gracePeriodMs ms post-cleanup grace period)")
    // If the body finished inside the grace window, surface its failure as the cause
    // so a slow assertion failure is not misreported as a pure hang.
    if (!worker.isAlive && bodyFailure != null) timeout.initCause(bodyFailure)
    throw timeout
  } else if (bodyFailure != null) {
    throw bodyFailure
  }
}

private def runPythonForeachBatchTerminationTestBody(sessionHolder: SessionHolder): Unit = {
// Unique query names per attempt: a leaked query from a timed-out attempt may still
// occupy the old name in spark.streams.active.
val suffix = s"_${System.nanoTime()}"
val q1Name = s"foreachBatch_termination_test_q1$suffix"
val q2Name = s"foreachBatch_termination_test_q2$suffix"

// Snapshot listeners before this attempt registers anything so we can scope cleanup and
// assertions to listeners we added -- even if a previous timed-out attempt leaked a worker
// whose own finally is racing with us.
val baselineListeners = spark.streams.listListeners().toSet
var capturedServer: AnyRef = null
var ourNewListeners = Set.empty[StreamingQueryListener]

val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
try {
// A previous timed-out attempt's leaked worker may still hold `started=true`, which
// would make `start()` below a no-op and cause this attempt to share (and later
// re-stop) the stale server. Force-stop first so `start()` creates a fresh instance;
// the identity check in `finally` then distinguishes attempts.
if (SparkConnectService.started) {
runQuietly("stale SparkConnectService.stop()", SparkConnectService.stop())
}
SparkConnectService.start(spark.sparkContext)
// Identity-check the server in `finally`: a previous attempt's leaked finally must
// not tear down a service belonging to a later attempt.
capturedServer = SparkConnectService.server

val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
val (fn1, cleaner1) =
Expand All @@ -249,7 +355,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
.load()
.writeStream
.format("memory")
.queryName("foreachBatch_termination_test_q1")
.queryName(q1Name)
.foreachBatch(fn1)
.start()

Expand All @@ -258,7 +364,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
.load()
.writeStream
.format("memory")
.queryName("foreachBatch_termination_test_q2")
.queryName(q2Name)
.foreachBatch(fn2)
.start()

Expand All @@ -267,6 +373,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
sessionHolder.streamingForeachBatchRunnerCleanerCache
.registerCleanerForQuery(query2, cleaner2)

// The first registerCleanerForQuery lazily registers the cleaner listener. Capture the
// listeners we added so finally only removes ours, not a concurrent attempt's.
ourNewListeners = spark.streams.listListeners().toSet -- baselineListeners

val (runner1, runner2) =
(cleaner1.asInstanceOf[RunnerCleaner].runner, cleaner2.asInstanceOf[RunnerCleaner].runner)

Expand All @@ -288,14 +398,58 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
assert(runner2.isWorkerStopped().get)
}

assert(spark.streams.active.isEmpty) // no running query
assert(spark.streams.listListeners().length == 1) // only process termination listener
// Only assert this attempt's queries stopped; a previous timed-out attempt may have
// leaked queries into spark.streams.active that we cannot synchronously clean up.
assert(!spark.streams.active.exists(q => q.name == q1Name || q.name == q2Name))
// Scoped to this attempt: exactly one new listener (the cleaner listener) should
// have been registered, regardless of any listeners leaked by a prior attempt.
assert(
ourNewListeners.size == 1,
s"expected exactly 1 new listener registered by this attempt, " +
s"got ${ourNewListeners.size}")
} finally {
SparkConnectService.stop()
// Wait for things to calm down.
Thread.sleep(4.seconds.toMillis)
// remove process termination listener
spark.streams.listListeners().foreach(spark.streams.removeListener)
// Only stop the service if it is still the one this attempt started; otherwise a
// previous attempt's leaked finally would tear down the live service of the current
// attempt.
if (capturedServer != null && (SparkConnectService.server eq capturedServer)) {
// Cleanup is best-effort: any failure must not mask the primary failure in the
// try block, and the listener cleanup below must still run.
runQuietly("SparkConnectService.stop()", SparkConnectService.stop())
runQuietly("settle sleep", Thread.sleep(4.seconds.toMillis))
}
// Remove only the listeners this attempt registered; never touch a concurrent
// attempt's process-termination listener. Wrapped in `runQuietly` so a throw here
// cannot mask a primary failure in the try block.
runQuietly("removeListeners", ourNewListeners.foreach(spark.streams.removeListener))
}
}

test("python foreachBatch process: process terminates after query is stopped") {
  // scalastyle:off assume
  assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
  assume(PythonTestDepsChecker.isConnectDepsAvailable)
  // scalastyle:on assume

  // Bound query.stop(): spark.sql.streaming.stopTimeout defaults to 0 (wait
  // forever), which would turn a stuck batch into an unkillable test. 30s fits
  // comfortably under the 2-minute per-attempt cap used below.
  withSQLConf(SQLConf.STREAMING_STOP_TIMEOUT.key -> "30000") {
    retryWithVisibleLog(maxAttempts = 3) {
      // The body runs on a fresh daemon thread so the test thread can recover from
      // a hang in a non-interruptible socket read. The SessionHolder is created out
      // here so the timeout callback can close its Python worker sockets via the
      // cleaner cache; that unblocks the hung dataIn.readInt so the leaked thread's
      // finally can settle before the next retry. The 2-minute cap strictly bounds
      // the original 150-minute hang.
      val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
      val releaseWorkerSockets: () => Unit = () =>
        runQuietly(
          "onTimeout cleanUpAll",
          sessionHolder.streamingForeachBatchRunnerCleanerCache.cleanUpAll())
      awaitTestBodyInNewThread(TimeUnit.MINUTES.toMillis(2), releaseWorkerSockets) {
        runPythonForeachBatchTerminationTestBody(sessionHolder)
      }
    }
  }
}

Expand Down