diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 17402ab5ddb43..cff5f345d2573 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.connect.service
 
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
+import java.util.concurrent.{TimeoutException, TimeUnit}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.sys.process.Process
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import com.google.common.collect.Lists
 import org.scalatest.time.SpanSugar._
@@ -37,8 +39,10 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper}
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
 import org.apache.spark.sql.pipelines.logging.PipelineEvent
+import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.ArrayImplicits._
 
@@ -228,15 +232,117 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
     }
   }
 
-  test("python foreachBatch process: process terminates after query is stopped") {
-    // scalastyle:off assume
-    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
-    assume(PythonTestDepsChecker.isConnectDepsAvailable)
-    // scalastyle:on assume
+  // Log and swallow best-effort cleanup failures so they do not mask a primary test
+  // failure. InterruptedException re-asserts the interrupt flag on the current thread;
+  // fatal errors (OOM, StackOverflow, LinkageError) propagate.
+  private def runQuietly(label: String, op: => Unit): Unit = {
+    try op
+    catch {
+      case _: InterruptedException => Thread.currentThread().interrupt()
+      case NonFatal(t) =>
+        // scalastyle:off println
+        println(s"===== $label suppressed ${t.getClass.getSimpleName}: ${t.getMessage} =====")
+        // scalastyle:on println
+    }
+  }
+
+  // Same semantics as SparkFunSuite.retry, but prints to stdout so retries show up in the
+  // GitHub Actions job log (SparkFunSuite.retry's log4j output only lands in
+  // target/unit-tests.log, surfaced as an artifact rather than in the live log).
+  private def retryWithVisibleLog(maxAttempts: Int)(body: => Unit): Unit = {
+    var attempt = 1
+    var done = false
+    while (!done) {
+      try {
+        body
+        done = true
+      } catch {
+        case NonFatal(t) if attempt >= maxAttempts => throw t
+        case NonFatal(t) =>
+          // scalastyle:off println
+          println(
+            s"===== Attempt $attempt/$maxAttempts failed " +
+              s"(${t.getClass.getSimpleName}: ${t.getMessage}); retrying =====")
+          // scalastyle:on println
+          // A leaked worker from this attempt may still hold sockets/listeners; do not
+          // let afterEach/beforeEach throwing on that residual state abort the retry loop.
+ runQuietly("afterEach", afterEach()) + runQuietly("beforeEach", beforeEach()) + attempt += 1 + } + } + } + + private def awaitTestBodyInNewThread(timeoutMillis: Long, onTimeout: () => Unit)( + body: => Unit): Unit = { + @volatile var error: Throwable = null + val runnable: Runnable = () => { + try { + body + } catch { + case t: Throwable => error = t + } + } + val worker = new Thread(runnable, s"${getClass.getSimpleName}-testBody-worker") + worker.setDaemon(true) + worker.start() + worker.join(timeoutMillis) + if (worker.isAlive) { + // Capture the worker's stack so post-mortem diagnostics can identify which leaked + // thread belongs to which attempt without a separate jstack. + // scalastyle:off println + println( + s"===== Test body did not complete within $timeoutMillis ms " + + s"(thread=${worker.getName}, state=${worker.getState}); stack trace follows =====") + worker.getStackTrace.foreach(frame => println(s" at $frame")) + // scalastyle:on println + // Best-effort: release any resource the worker is blocked on so it can unwind its own + // finally and stop holding global state (SparkConnectService, listeners, ...). + onTimeout() + // Also interrupt the worker so any interruptible blocking call (e.g. the Thread.join + // inside StreamExecution.interruptAndAwaitExecutionThreadTermination) wakes up. + worker.interrupt() + // Grace period for the now-unblocked worker to run its own finally + // (SparkConnectService.stop() then the ~4s settle sleep). + val gracePeriodMs = 30.seconds.toMillis + worker.join(gracePeriodMs) + val te = new TimeoutException( + s"Test body did not complete within $timeoutMillis ms " + + s"(after a $gracePeriodMs ms post-cleanup grace period)") + // If the body finished during the grace window, surface the original failure + // as the cause so a slow assertion failure is not misreported as a pure hang. + if (!worker.isAlive && error != null) te.initCause(error) + throw te + } + if (error != null) throw error + } + + private def runPythonForeachBatchTerminationTestBody(sessionHolder: SessionHolder): Unit = { + // Unique query names per attempt: a leaked query from a timed-out attempt may still + // occupy the old name in spark.streams.active. + val suffix = s"_${System.nanoTime()}" + val q1Name = s"foreachBatch_termination_test_q1$suffix" + val q2Name = s"foreachBatch_termination_test_q2$suffix" + + // Snapshot listeners before this attempt registers anything so we can scope cleanup and + // assertions to listeners we added -- even if a previous timed-out attempt leaked a worker + // whose own finally is racing with us. + val baselineListeners = spark.streams.listListeners().toSet + var capturedServer: AnyRef = null + var ourNewListeners = Set.empty[StreamingQueryListener] - val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) try { + // A previous timed-out attempt's leaked worker may still hold `started=true`, which + // would make `start()` below a no-op and cause this attempt to share (and later + // re-stop) the stale server. Force-stop first so `start()` creates a fresh instance; + // the identity check in `finally` then distinguishes attempts. + if (SparkConnectService.started) { + runQuietly("stale SparkConnectService.stop()", SparkConnectService.stop()) + } SparkConnectService.start(spark.sparkContext) + // Identity-check the server in `finally`: a previous attempt's leaked finally must + // not tear down a service belonging to a later attempt. 
+      capturedServer = SparkConnectService.server
 
       val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
       val (fn1, cleaner1) =
@@ -249,7 +355,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
         .load()
         .writeStream
         .format("memory")
-        .queryName("foreachBatch_termination_test_q1")
+        .queryName(q1Name)
         .foreachBatch(fn1)
         .start()
 
@@ -258,7 +364,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
         .load()
         .writeStream
         .format("memory")
-        .queryName("foreachBatch_termination_test_q2")
+        .queryName(q2Name)
        .foreachBatch(fn2)
         .start()
 
@@ -267,6 +373,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       sessionHolder.streamingForeachBatchRunnerCleanerCache
         .registerCleanerForQuery(query2, cleaner2)
 
+      // The first registerCleanerForQuery lazily registers the cleaner listener. Capture the
+      // listeners we added so finally only removes ours, not a concurrent attempt's.
+      ourNewListeners = spark.streams.listListeners().toSet -- baselineListeners
+
       val (runner1, runner2) =
         (cleaner1.asInstanceOf[RunnerCleaner].runner, cleaner2.asInstanceOf[RunnerCleaner].runner)
 
@@ -288,14 +398,58 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
         assert(runner2.isWorkerStopped().get)
       }
 
-      assert(spark.streams.active.isEmpty) // no running query
-      assert(spark.streams.listListeners().length == 1) // only process termination listener
+      // Only assert this attempt's queries stopped; a previous timed-out attempt may have
+      // leaked queries into spark.streams.active that we cannot synchronously clean up.
+      assert(!spark.streams.active.exists(q => q.name == q1Name || q.name == q2Name))
+      // Scoped to this attempt: exactly one new listener (the cleaner listener) should
+      // have been registered, regardless of any listeners leaked by a prior attempt.
+      assert(
+        ourNewListeners.size == 1,
+        s"expected exactly 1 new listener registered by this attempt, " +
+          s"got ${ourNewListeners.size}")
     } finally {
-      SparkConnectService.stop()
-      // Wait for things to calm down.
-      Thread.sleep(4.seconds.toMillis)
-      // remove process termination listener
-      spark.streams.listListeners().foreach(spark.streams.removeListener)
+      // Only stop the service if it is still the one this attempt started; otherwise a
+      // previous attempt's leaked finally would tear down the live service of the current
+      // attempt.
+      if (capturedServer != null && (SparkConnectService.server eq capturedServer)) {
+        // Cleanup is best-effort: any failure must not mask the primary failure in the
+        // try block, and the listener cleanup below must still run.
+        runQuietly("SparkConnectService.stop()", SparkConnectService.stop())
+        runQuietly("settle sleep", Thread.sleep(4.seconds.toMillis))
+      }
+      // Remove only the listeners this attempt registered; never touch a concurrent
+      // attempt's process-termination listener. Wrapped in `runQuietly` so a throw here
+      // cannot mask a primary failure in the try block.
+      runQuietly("removeListeners", ourNewListeners.foreach(spark.streams.removeListener))
+    }
+  }
+
+  test("python foreachBatch process: process terminates after query is stopped") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
+    assume(PythonTestDepsChecker.isConnectDepsAvailable)
+    // scalastyle:on assume
+
+    // Bound query.stop() so it cannot hang indefinitely: spark.sql.streaming.stopTimeout
+    // defaults to 0 (wait forever), which turns a stuck batch into an unkillable test.
+    // 30s is small enough to fit under the outer per-attempt cap with room to spare.
+    withSQLConf(SQLConf.STREAMING_STOP_TIMEOUT.key -> "30000") {
+      retryWithVisibleLog(maxAttempts = 3) {
+        // Run the body on a fresh daemon thread so the test thread can recover from a
+        // hang in a non-interruptible socket read. SessionHolder is created outside the
+        // body so onTimeout can close its Python worker sockets via cleanerCache; that
+        // unblocks the hung dataIn.readInt so the leaked thread's finally can settle
+        // before the next retry. 2-minute cap strictly bounds the original 150-minute hang.
+        val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+        awaitTestBodyInNewThread(
+          timeoutMillis = TimeUnit.MINUTES.toMillis(2),
+          onTimeout = () =>
+            runQuietly(
+              "onTimeout cleanUpAll",
+              sessionHolder.streamingForeachBatchRunnerCleanerCache.cleanUpAll())) {
+          runPythonForeachBatchTerminationTestBody(sessionHolder)
+        }
+      }
     }
   }
 