diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
index d546eebf4c1b..bde8f028c57e 100644
--- a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
+++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
@@ -47,6 +47,8 @@ case class PaimonBatchWrite(
 
   protected val metricRegistry = SparkMetricRegistry()
 
+  @volatile private var commitStarted: Boolean = false
+
   protected val batchWriteBuilder: BatchWriteBuilder = {
     val builder = table.newBatchWriteBuilder()
     overwritePartitions.foreach(partitions => builder.withOverwrite(partitions.asJava))
@@ -68,6 +70,7 @@
   override def useCommitCoordinator(): Boolean = false
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    commitStarted = true
     logInfo(s"Committing to table ${table.name()}")
     val batchTableCommit = batchWriteBuilder.newCommit()
     batchTableCommit.withMetricRegistry(metricRegistry)
@@ -107,7 +110,19 @@
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
-    // TODO clean uncommitted files
+    if (commitStarted) {
+      logWarning(s"Skip abort cleanup for table ${table.name()} because commit has already started")
+      return
+    }
+
+    logInfo(s"Aborting write to table ${table.name()}")
+    val batchTableCommit = batchWriteBuilder.newCommit()
+    try {
+      val commitMessages = WriteTaskResult.merge(messages.filter(_ != null))
+      batchTableCommit.abort(commitMessages.asJava)
+    } finally {
+      batchTableCommit.close()
+    }
   }
 
   private def buildDeletedCommitMessage(
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
index 92aeae031343..1f2abae0b01c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonBatchWrite.scala
@@ -47,6 +47,8 @@ case class PaimonBatchWrite(
 
   protected val metricRegistry = SparkMetricRegistry()
 
+  @volatile private var commitStarted: Boolean = false
+
   protected val batchWriteBuilder: BatchWriteBuilder = {
     val builder = table.newBatchWriteBuilder()
     overwritePartitions.foreach(partitions => builder.withOverwrite(partitions.asJava))
@@ -68,6 +70,7 @@
   override def useCommitCoordinator(): Boolean = false
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    commitStarted = true
     logInfo(s"Committing to table ${table.name()}")
     val batchTableCommit = batchWriteBuilder.newCommit()
     batchTableCommit.withMetricRegistry(metricRegistry)
@@ -107,7 +110,19 @@
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
-    // TODO clean uncommitted files
+    if (commitStarted) {
+      logWarning(s"Skip abort cleanup for table ${table.name()} because commit has already started")
+      return
+    }
+
+    logInfo(s"Aborting write to table ${table.name()}")
+    val batchTableCommit = batchWriteBuilder.newCommit()
+    try {
+      val commitMessages = WriteTaskResult.merge(messages.filter(_ != null))
+      batchTableCommit.abort(commitMessages.asJava)
+    } finally {
+      batchTableCommit.close()
+    }
   }
 
   private def buildDeletedCommitMessage(
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala
index b2ae78f1ce5e..7496c1785431 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkWriteITCase.scala
@@ -23,16 +23,21 @@ import org.apache.paimon.CoreOptions.BucketFunctionType
 import org.apache.paimon.catalog.Identifier
 import org.apache.paimon.schema.Schema
 import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.spark.write.{PaimonBatchWrite, WriteTaskResult}
+import org.apache.paimon.table.sink.CommitMessageImpl
 import org.apache.paimon.types.DataTypes
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.jupiter.api.Assertions
 
 import java.sql.Timestamp
 import java.time.LocalDateTime
 
+import scala.collection.JavaConverters._
+
 class SparkWriteWithNoExtensionITCase extends SparkWriteITCase {
 
   /** Disable the spark extension. */
@@ -272,6 +277,79 @@ class SparkWriteITCase extends PaimonSparkTestBase {
     }
   }
 
+  test("Paimon Write: abort cleans uncommitted files") {
+    withTable("T") {
+      spark.sql(
+        "CREATE TABLE T (id INT, data INT) TBLPROPERTIES ('bucket' = '1', 'bucket-key' = 'id')")
+
+      val table = loadTable("T")
+      val sparkSchema = spark.table("T").schema
+      val batchWrite = PaimonBatchWrite(table, sparkSchema, sparkSchema, None, None)
+      val dataWriter = batchWrite.createBatchWriterFactory(null).createWriter(0, 0L)
+
+      dataWriter.write(new GenericInternalRow(Array[Any](1, 10)))
+      val writerCommitMessage = dataWriter.commit()
+      val dataFilePaths = dataFilePathsFromWriteTaskResult(table, writerCommitMessage)
+
+      assertThat(dataFilePaths.size).isGreaterThan(0)
+      dataFilePaths.foreach(path => assertThat(table.fileIO().exists(path)).isTrue)
+
+      batchWrite.abort(Array(writerCommitMessage, null))
+
+      dataFilePaths.foreach(path => assertThat(table.fileIO().exists(path)).isFalse)
+      assertThat(table.latestSnapshot()).isEmpty
+    }
+  }
+
+  test("Paimon Write: abort skips cleanup after commit starts") {
+    withTable("T") {
+      spark.sql(
+        "CREATE TABLE T (id INT, data INT) TBLPROPERTIES ('bucket' = '1', 'bucket-key' = 'id')")
+
+      val table = loadTable("T")
+      val sparkSchema = spark.table("T").schema
+      val batchWrite = PaimonBatchWrite(table, sparkSchema, sparkSchema, None, None)
+      val dataWriter = batchWrite.createBatchWriterFactory(null).createWriter(0, 0L)
+
+      dataWriter.write(new GenericInternalRow(Array[Any](1, 10)))
+      val writerCommitMessage = dataWriter.commit()
+      val dataFilePaths = dataFilePathsFromWriteTaskResult(table, writerCommitMessage)
+
+      assertThat(dataFilePaths.size).isGreaterThan(0)
+      dataFilePaths.foreach(path => assertThat(table.fileIO().exists(path)).isTrue)
+
+      try {
+        batchWrite.commit(Array(writerCommitMessage))
+      } catch {
+        case _: Throwable =>
+      }
+      assertThat(table.latestSnapshot()).isPresent
+
+      batchWrite.abort(Array(writerCommitMessage, null))
+
+      dataFilePaths.foreach(path => assertThat(table.fileIO().exists(path)).isTrue)
+      checkAnswer(spark.sql("SELECT * FROM T"), Row(1, 10) :: Nil)
+    }
+  }
+
+  private def dataFilePathsFromWriteTaskResult(
+      table: org.apache.paimon.table.FileStoreTable,
+      writerCommitMessage: org.apache.spark.sql.connector.write.WriterCommitMessage) = {
+    WriteTaskResult.merge(Seq(writerCommitMessage)).flatMap {
+      case commitMessage: CommitMessageImpl =>
+        val pathFactory = table
+          .store()
+          .pathFactory()
+          .createDataFilePathFactory(commitMessage.partition(), commitMessage.bucket())
+        commitMessage
+          .newFilesIncrement()
+          .newFiles()
+          .asScala
+          .map(pathFactory.toPath)
+      case _ => Seq.empty
+    }
+  }
+
   test("Paimon write: write table with timestamp3 bucket key") {
     withTable("t") {
       // create timestamp3 table using table api