Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTra
.toString
}

/**
 * Splits a combined change feed state into one serialized continuation state per feed range.
 *
 * @param changeFeedState the combined change feed state to split
 * @param feedRanges the normalized feed ranges to extract effective states for
 * @return one serialized (string form) change feed state per requested range
 */
def extractChangeFeedStateForRanges
(
  changeFeedState: ChangeFeedState,
  feedRanges: Seq[NormalizedRange]
): Seq[String] = {
  // Convert the normalized ranges into the Java range representation expected by the SDK
  val effectiveRanges = feedRanges.map(toCosmosRange).asJava
  // Single SDK call for all ranges, then serialize each per-range state back to a string
  val perRangeStates = changeFeedState.extractForEffectiveRanges(effectiveRanges)
  perRangeStates.asScala.map(state => state.toString)
}

def parseChangeFeedState(changeFeedStateJsonString: String): ChangeFeedState = {
assert(!Strings.isNullOrWhiteSpace(changeFeedStateJsonString), s"Argument 'changeFeedStateJsonString' must not be null or empty.")
ChangeFeedState.fromString(changeFeedStateJsonString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,16 @@ private class ChangeFeedBatch
// Latest offset above has the EndLsn specified based on the point-in-time latest offset
// For batch mode instead we need to reset it so that the change feed will get fully drained
val parsedInitialOffset = SparkBridgeImplementationInternal.parseChangeFeedState(initialOffsetJson)
// (diff residue: the block below is the PREVIOUS implementation, removed by this change —
// it extracted the continuation state once per partition, i.e. one SDK call per feed range)
val inputPartitions = latestOffset
.inputPartitions
.get
.map(partition => partition
.withContinuationState(
SparkBridgeImplementationInternal
.extractChangeFeedStateForRange(parsedInitialOffset, partition.feedRange),
clearEndLsn = !hasBatchCheckpointLocation))
.map(_.asInstanceOf[InputPartition])
// New implementation: extract continuation states for ALL feed ranges in a single call.
// NOTE(review): `.get` assumes inputPartitions is always defined on the latest offset at
// this point — TODO confirm the planner guarantees that invariant.
val partitions = latestOffset.inputPartitions.get
val continuationStates = SparkBridgeImplementationInternal
.extractChangeFeedStateForRanges(parsedInitialOffset, partitions.map(_.feedRange))
// Pair each partition with its extracted state. NOTE(review): the zip relies on
// extractChangeFeedStateForRanges returning states in the same order as the ranges
// passed in — verify that ordering guarantee in the SDK implementation.
val inputPartitions = partitions
.zip(continuationStates)
.map { case (partition, continuationState) =>
partition
// clearEndLsn: only reset the EndLsn when no batch checkpoint location exists,
// so batch reads fully drain the change feed (see comment at top of this block)
.withContinuationState(continuationState, clearEndLsn = !hasBatchCheckpointLocation)
.asInstanceOf[InputPartition]
}

log.logInfo(s"<-- planInputPartitions $batchId (creating ${inputPartitions.length} partitions)")
inputPartitions
Expand Down
Loading
Loading