From d8c72f5b40a1d389d9aca51b9f8f92e7111a7650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Tue, 26 May 2026 15:05:00 +0200 Subject: [PATCH 1/5] [swiftsrc2cpg] De-sugar optional binding constructs --- .../swiftsrc2cpg/astcreation/AstCreator.scala | 6 +- .../astcreation/AstCreatorHelper.scala | 371 +++++++++++++++++- .../astcreation/AstForExprSyntaxCreator.scala | 143 +++++-- .../astcreation/AstForStmtSyntaxCreator.scala | 253 ++++++++++-- .../astcreation/AstNodeBuilder.scala | 5 - .../passes/ast/AvailabilityQueryTests.scala | 17 +- .../swiftsrc2cpg/passes/ast/GuardTests.scala | 202 +++++++++- .../passes/ast/GuardTopLevelTests.scala | 17 +- .../passes/ast/StatementTests.scala | 361 ++++++++++++++++- .../swiftsrc2cpg/passes/ast/WhileTests.scala | 6 +- 10 files changed, 1289 insertions(+), 92 deletions(-) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala index 905c1c905e01..84085d061847 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala @@ -58,9 +58,8 @@ class AstCreator( fileContent.foreach(fileNode.content(_)) val namespaceBlock = globalNamespaceBlock() methodAstParentStack.push(namespaceBlock) - val astForFakeMethod = - astInFakeMethod(namespaceBlock.fullName, parserResult.filename, parserResult.ast) - val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(astForFakeMethod)) + val astForFakeMethod = astInFakeMethod(namespaceBlock.fullName, parserResult.filename, parserResult.ast) + val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(astForFakeMethod)) Ast.storeInDiffGraph(ast, diffGraph) scope.createVariableReferenceLinks(diffGraph, parserResult.filename) diffGraph @@ -125,4 +124,5 @@ class AstCreator( case _ => shortenCode(new String(code)).stripLineEnd } } + } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index cc2fef462721..3af9bb4e5a0b 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -3,7 +3,10 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, EvaluationStrategies, PropertyNames} +import io.joern.x2cpg.datastructures.Stack.* +import io.joern.x2cpg.datastructures.VariableScopeManager +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* import org.apache.commons.lang3.StringUtils object AstCreatorHelper { @@ -147,21 +150,54 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As otherElements(indexOfGuardStmt).asInstanceOf[CodeBlockItemSyntax].item.asInstanceOf[GuardStmtSyntax] val elementsAfterGuard = otherElements.slice(indexOfGuardStmt + 1, otherElements.size) - val code = this.code(guardStmt) - val ifNode = controlStructureNode(guardStmt, ControlStructureTypes.IF, code) - val conditionAst = astForNode(guardStmt.conditions) + val code = this.code(guardStmt) + val ifNode = controlStructureNode(guardStmt, ControlStructureTypes.IF, code) + + // Apply optional binding desugaring for guard let + // First, create the block that will hold the unwrapped variables + val thenBlockNode = if (elementsAfterGuard.nonEmpty) blockNode(elementsAfterGuard.head) else blockNode(guardStmt) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + + val (conditionAst, unwrapAsts) = handleOptionalBindingConditions( + guardStmt.conditions.children, + onAllSimple = simpleBindings => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) + val unwraps = buildUnwrapAssignments(bindingInfos) + (condAst, unwraps) + }, + onMixed = (simpleBindings, tupleBindings) => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + (condAst, unwraps) + }, + onPartial = (simpleBindings, tupleBindings, otherConditions) => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos, otherConditions) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + (condAst, unwraps) + }, + onStandard = () => { + val condAst = astForNode(guardStmt.conditions) + (condAst, List.empty) + } + ) + + val allThenChildren = unwrapAsts ++ astsForBlockElements(elementsAfterGuard) ++ deferElementsAstsOrdered + scope.popScope() + localAstParentStack.pop() - val thenAst = astsForBlockElements(elementsAfterGuard) ++ deferElementsAstsOrdered match { + val thenAst = allThenChildren match { case Nil => - blockAst(blockNode(guardStmt), List.empty) - case blockElement :: Nil => + blockAst(thenBlockNode, List.empty) + case blockElement :: Nil if unwrapAsts.isEmpty => blockElement case blockChildren => - val block = blockNode(elementsAfterGuard.head) - blockAst(block, blockChildren) + blockAst(thenBlockNode, blockChildren) } val elseAst = astForNode(guardStmt.body) - setOrderExplicitly(elseAst, 3) val ifAst = ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) astsForBlockElements(elementsBeforeGuard) :+ ifAst @@ -516,4 +552,319 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } + /** Checks if a pattern is tuple-like (direct TuplePatternSyntax or ExpressionPatternSyntax wrapping TupleExprSyntax). + * Used to detect tuple patterns in optional binding conditions. + */ + protected def isTupleLikePattern(pattern: PatternSyntax): Boolean = pattern match { + case _: TuplePatternSyntax => true + case ep: ExpressionPatternSyntax if ep.expression.isInstanceOf[TupleExprSyntax] => true + case _ => false + } + + /** Information about an optional binding for desugaring if-let/while-let constructs. + * + * @param localName + * The name of the local variable to be created in the then/body block + * @param tmpName + * Optional temp variable name (Some if has initializer, None otherwise) + * @param binding + * The original OptionalBindingConditionSyntax node + */ + protected case class BindingInfo( + localName: String, + tmpName: Option[String], + binding: OptionalBindingConditionSyntax, + isWildcard: Boolean + ) + + /** Collects binding information from optional binding conditions for desugaring. + * + * @param bindings + * The optional binding condition nodes + * @return + * Sequence of BindingInfo with local name, optional temp name, binding node, and wildcard flag + */ + protected def collectBindingInfos(bindings: Seq[OptionalBindingConditionSyntax]): Seq[BindingInfo] = { + bindings.map { binding => + val (localName, isWildcard) = binding.pattern match { + case ident: IdentifierPatternSyntax => (code(ident.identifier), false) + case _ => (scopeLocalUniqueName("wildcard"), true) + } + val tmpName = binding.initializer.map(_ => scopeLocalUniqueName("tmp")) + BindingInfo(localName, tmpName, binding, isWildcard) + } + } + + /** Analyzes conditions to determine optional binding desugaring strategy. + * + * @param conditions + * The condition elements to analyze + * @param onAllSimple + * Handler for all simple optional bindings + * @param onMixed + * Handler for mixed simple and tuple optional bindings + * @param onPartial + * Handler for partial desugaring (simple bindings + other conditions) + * @param onStandard + * Handler for standard conditions (no simple bindings to desugar) + * @return + * The result from the appropriate handler + */ + protected def handleOptionalBindingConditions[T]( + conditions: Iterable[ConditionElementSyntax], + onAllSimple: Seq[OptionalBindingConditionSyntax] => T, + onMixed: (Seq[OptionalBindingConditionSyntax], Seq[OptionalBindingConditionSyntax]) => T, + onPartial: ( + Seq[OptionalBindingConditionSyntax], + Seq[OptionalBindingConditionSyntax], + Seq[ConditionElementSyntax] + ) => T, + onStandard: () => T + ): T = { + val conditionsSeq = conditions.toSeq + val optionalBindings = conditionsSeq.collect { + case condElem if condElem.condition.isInstanceOf[OptionalBindingConditionSyntax] => + condElem.condition.asInstanceOf[OptionalBindingConditionSyntax] + } + + val (simpleBindings, tupleBindings) = optionalBindings.partition(binding => !isTupleLikePattern(binding.pattern)) + val otherConditions = + conditionsSeq.filterNot(condElem => condElem.condition.isInstanceOf[OptionalBindingConditionSyntax]) + + if (simpleBindings.isEmpty) { + // No simple bindings to desugar + onStandard() + } else if (otherConditions.isEmpty && tupleBindings.isEmpty) { + // All conditions are simple optional bindings + onAllSimple(simpleBindings) + } else if (otherConditions.isEmpty) { + // All conditions are optional bindings (mixed simple + tuple) + onMixed(simpleBindings, tupleBindings) + } else { + // Partial: simple bindings + other conditions (and maybe tuple bindings) + onPartial(simpleBindings, tupleBindings, otherConditions) + } + } + + /** Combines multiple nil check ASTs with logical AND operator. If only one check exists, returns it directly. + * + * @param node + * The control structure node for creating operator nodes + * @param nilCheckAsts + * The nil check ASTs to combine + * @return + * Single AST representing all checks combined with && + */ + private def combineNilChecksWithAnd(node: SwiftNode, nilCheckAsts: Seq[Ast]): Ast = { + if (nilCheckAsts.size == 1) { + nilCheckAsts.head + } else { + nilCheckAsts.reduce { (left, right) => + val leftCode = left.root.map(codeOf).getOrElse("") + val rightCode = right.root.map(codeOf).getOrElse("") + val andCode = s"$leftCode && $rightCode" + val andCallNode = + createStaticCallNode(node, andCode, Operators.logicalAnd, Operators.logicalAnd, Defines.Bool) + callAst(andCallNode, List(left, right)) + } + } + } + + /** Builds the condition AST for optional binding constructs (if-let/while-let). If any binding has an initializer, + * creates a block with temp variable assignments and nil checks. Otherwise creates a simple combined nil check. + * + * @param node + * The control structure node (IfExprSyntax or WhileStmtSyntax) + * @param bindingInfos + * Information about each optional binding + * @param additionalConditions + * Additional condition elements to AND with the nil checks + * @return + * Condition AST (either block or direct nil check) + */ + protected def buildOptionalBindingCondition( + node: SwiftNode, + bindingInfos: Seq[BindingInfo], + additionalConditions: Seq[ConditionElementSyntax] = Seq.empty + ): Ast = { + val hasAnyInitializer = bindingInfos.exists(_.tmpName.isDefined) + + if (hasAnyInitializer) { + // Create condition block with assignments and checks + val condBlockNode = blockNode(node) + scope.pushNewBlockScope(condBlockNode) + localAstParentStack.push(condBlockNode) + + val assignmentAsts = bindingInfos.flatMap { info => + info.tmpName.map { tmpName => + val tmpLocalNode = localNode(info.binding, tmpName, tmpName, Defines.Any).order(0) + diffGraph.addEdge(condBlockNode, tmpLocalNode, EdgeTypes.AST) + scope.addVariable(tmpName, tmpLocalNode, Defines.Any, VariableScopeManager.ScopeType.BlockScope) + + val tmpIdentNode = identifierNode(info.binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdentNode, Defines.Any, EvaluationStrategies.BY_REFERENCE) + val initAst = astForNode(info.binding.initializer.get.value) + createAssignmentCallAst( + info.binding, + Ast(tmpIdentNode), + initAst, + s"$tmpName = ${code(info.binding.initializer.get.value)}" + ) + } + } + + // Create nil checks for all bindings + val nilCheckAsts = bindingInfos.map { info => + val (checkName, checkNode) = info.tmpName match { + case Some(tmpName) => + val tmpIdentForCheck = identifierNode(info.binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdentForCheck, Defines.Any, EvaluationStrategies.BY_REFERENCE) + (tmpName, Ast(tmpIdentForCheck)) + case None => + (info.localName, astForNode(info.binding.pattern)) + } + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"$checkName != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(checkNode, Ast(nilNode))) + } + + // Combine nil checks with additional conditions using && + val additionalConditionAsts = additionalConditions.map(condElem => astForNode(condElem.condition)) + val allChecks = nilCheckAsts ++ additionalConditionAsts + val combinedCheckAst = combineNilChecksWithAnd(node, allChecks) + + scope.popScope() + localAstParentStack.pop() + + blockAst(condBlockNode, assignmentAsts.toList :+ combinedCheckAst) + } else { + // All bindings have no initializer: create combined != nil check without block + val nilCheckAsts = bindingInfos.map { info => + val patternAst = astForNode(info.binding.pattern) + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"${info.localName} != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(patternAst, Ast(nilNode))) + } + + val additionalConditionAsts = additionalConditions.map(condElem => astForNode(condElem.condition)) + val allChecks = nilCheckAsts ++ additionalConditionAsts + combineNilChecksWithAnd(node, allChecks) + } + } + + /** Builds the body AST with optional unwrapping assignments prepended. For bindings with initializers, creates locals + * and unwrapping assignments in the body block. + * + * @param bodyNode + * The body syntax node for creating the block + * @param bodyStatements + * The statements to include in the body + * @param bindingInfos + * Information about each optional binding + * @return + * Body AST with unwrapping assignments (if needed) followed by original body statements + */ + protected def buildBodyWithUnwrapping( + bodyNode: SwiftNode, + bodyStatements: Iterable[SwiftNode], + bindingInfos: Seq[BindingInfo] + ): Ast = { + val bindingsWithInitializer = bindingInfos.filter(info => info.tmpName.isDefined && !info.isWildcard) + + if (bindingsWithInitializer.nonEmpty) { + val bodyBlockNode = blockNode(bodyNode) + scope.pushNewBlockScope(bodyBlockNode) + localAstParentStack.push(bodyBlockNode) + + val unwrapAsts = bindingsWithInitializer.map { info => + val binding = info.binding + val tmpName = info.tmpName.get + + val typeFullName = + binding.typeAnnotation.fold(Defines.Any)(typeAnn => AstCreatorHelper.cleanType(code(typeAnn.`type`))) + val kind = code(binding.bindingSpecifier) + val scopeType = + if (kind == "let") VariableScopeManager.ScopeType.BlockScope + else VariableScopeManager.ScopeType.MethodScope + + val tpeFromTypeMap = fullnameProvider.typeFullname(binding.pattern).getOrElse(typeFullName) + registerType(tpeFromTypeMap) + val localNode_ = localNode(binding, info.localName, info.localName, tpeFromTypeMap).order(0) + scope.addVariable(info.localName, localNode_, tpeFromTypeMap, scopeType) + diffGraph.addEdge(bodyBlockNode, localNode_, EdgeTypes.AST) + + val localIdentNode = identifierNode(binding, info.localName, info.localName, tpeFromTypeMap) + scope.addVariableReference(info.localName, localIdentNode, tpeFromTypeMap, EvaluationStrategies.BY_REFERENCE) + + val tmpIdent = identifierNode(binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdent, Defines.Any, EvaluationStrategies.BY_REFERENCE) + createAssignmentCallAst(binding, Ast(localIdentNode), Ast(tmpIdent), s"${info.localName} = $tmpName") + } + + val bodyAsts = bodyStatements.map(astForNode).toList + + scope.popScope() + localAstParentStack.pop() + + if (unwrapAsts.isEmpty && bodyAsts.isEmpty) { + // Empty body - return empty block without creating a nested block + Ast(bodyBlockNode) + } else { + blockAst(bodyBlockNode, unwrapAsts.toList ++ bodyAsts) + } + } else { + // All bindings have no initializer: just use original body + astForNode(bodyNode) + } + } + + /** Builds unwrap assignment ASTs for optional binding info without wrapping in a body block. Used for guard let where + * bindings need to be in parent scope. + * + * @param bindingInfos + * The binding information + * @return + * List of assignment ASTs for unwrapping + */ + protected def buildUnwrapAssignments(bindingInfos: Seq[BindingInfo]): List[Ast] = { + val bindingsWithInitializer = bindingInfos.filter(info => info.tmpName.isDefined && !info.isWildcard) + + bindingsWithInitializer.map { info => + val binding = info.binding + val tmpName = info.tmpName.get + + val typeFullName = + binding.typeAnnotation.fold(Defines.Any)(typeAnn => AstCreatorHelper.cleanType(code(typeAnn.`type`))) + val kind = code(binding.bindingSpecifier) + val scopeType = + if (kind == "let") VariableScopeManager.ScopeType.BlockScope + else VariableScopeManager.ScopeType.MethodScope + + val tpeFromTypeMap = fullnameProvider.typeFullname(binding.pattern).getOrElse(typeFullName) + registerType(tpeFromTypeMap) + val localNode_ = localNode(binding, info.localName, info.localName, tpeFromTypeMap).order(0) + scope.addVariable(info.localName, localNode_, tpeFromTypeMap, scopeType) + diffGraph.addEdge(localAstParentStack.head, localNode_, EdgeTypes.AST) + + val localIdentNode = identifierNode(binding, info.localName, info.localName, tpeFromTypeMap) + scope.addVariableReference(info.localName, localIdentNode, tpeFromTypeMap, EvaluationStrategies.BY_REFERENCE) + + val tmpIdent = identifierNode(binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdent, Defines.Any, EvaluationStrategies.BY_REFERENCE) + createAssignmentCallAst(binding, Ast(localIdentNode), Ast(tmpIdent), s"${info.localName} = $tmpName") + }.toList + } + } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index 5dfc7efa6f4b..69fdf46d3eed 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -8,7 +8,7 @@ import io.joern.x2cpg.datastructures.VariableScopeManager import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall} +import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall, NewControlStructure} import scala.annotation.{tailrec, unused} @@ -498,15 +498,114 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForIfExprSyntax(node: IfExprSyntax): Ast = { - val code = this.code(node) - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) - val conditionAstRaw = astForNode(node.conditions) - val conditionAst = conditionAstRaw.root match { - case Some(_) => conditionAstRaw - case None => blockAst(blockNode(node.conditions), List.empty) - } - val thenAst = astForNode(node.body) - val elseAst = node.elseBody.map(astForNode) + val code = this.code(node) + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) + + handleOptionalBindingConditions( + node.conditions.children, + onAllSimple = simpleBindings => astForIfLetExprSyntax(node, ifNode, simpleBindings, node.body, node.elseBody), + onMixed = (simpleBindings, tupleBindings) => + astForIfLetExprSyntaxMixed(node, ifNode, simpleBindings, tupleBindings, node.body, node.elseBody), + onPartial = (simpleBindings, tupleBindings, otherConditions) => + astForIfLetExprSyntaxPartial( + node, + ifNode, + simpleBindings, + tupleBindings, + otherConditions, + node.body, + node.elseBody + ), + onStandard = () => { + val conditionAst = astForNode(node.conditions) + val thenAst = astForNode(node.body) + val elseAst = node.elseBody.map(astForNode) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + ) + } + + /** Handles Swift optional binding (if-let) constructs. + * + * De-sugars `if let baz = foo() { body }` into: + * + * Condition: { let 0 = foo(); 0 != nil } + * + * Then block: { let baz = 0; body } + * + * For multiple bindings `if let a = foo(), let b = bar() { body }`: + * + * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * + * Then block: { a = 0; b = 1; body } + * + * For mixed cases with/without initializers `if let a = foo(), let b { body }`: + * + * Condition: { 0 = foo(); 0 != nil && b != nil } + * + * Then block: { a = 0; body } + */ + private def astForIfLetExprSyntax( + node: IfExprSyntax, + ifNode: NewControlStructure, + optionalBindings: Seq[OptionalBindingConditionSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val thenAst = buildBodyWithUnwrapping(thenBody, thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `if let a = foo(), let (b, c) = bar() { body }` into: + * + * Condition: { 0 = foo(); 0 != nil } + * + * Then block: { let a = 0; let (b, c) = bar(); body } + */ + private def astForIfLetExprSyntaxMixed( + node: IfExprSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val thenAst = buildBodyWithUnwrapping(thenBody, tupleBindings ++ thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `if let a = foo(), #unavailable(...) { body }` into: + * + * Condition: { 0 = foo(); 0 != nil && #unavailable(...) } + * + * Then block: { let a = 0; body } + */ + private def astForIfLetExprSyntaxPartial( + node: IfExprSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + val thenAst = buildBodyWithUnwrapping(thenBody, tupleBindings ++ thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) } @@ -723,7 +822,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * identifier/field-identifier nodes so the resulting AST can be safely used as an argument without node-sharing * issues. */ - protected def createFieldAccessChain(baseName: String, fields: List[String], node: SwiftNode): Ast = { + private def createFieldAccessChain(baseName: String, fields: List[String], node: SwiftNode): Ast = { val baseAst = Ast(identifierNode(node, baseName)) fields.foldLeft(baseAst) { (accAst, field) => createFieldAccessCallAst(node, accAst, fieldIdentifierNode(node, field, field)) @@ -826,7 +925,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { */ /** Creates an instanceOf check for an IsTypePatternSyntax against a subject field access. */ - protected def astForIsTypePatternInTupleContext( + private def astForIsTypePatternInTupleContext( isType: IsTypePatternSyntax, subjectAst: Ast, subjectCode: String, @@ -843,7 +942,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Creates an equality check for an expression pattern against a subject field access. */ - protected def astForExpressionPatternInTupleContext( + private def astForExpressionPatternInTupleContext( ep: ExpressionPatternSyntax, subjectAst: Ast, subjectCode: String, @@ -856,7 +955,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Creates a variable binding assignment for a pattern element against a subject field access. */ - protected def astForBindingInTupleContext( + private def astForBindingInTupleContext( varName: String, subjectAst: Ast, subjectCode: String, @@ -903,7 +1002,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Determines whether an expression inside a tuple represents a binding (let/var pattern). */ - protected def isBindingExpression(expr: ExprSyntax): Boolean = expr match { + private def isBindingExpression(expr: ExprSyntax): Boolean = expr match { case p: PatternExprSyntax => p.pattern match { case _: ValueBindingPatternSyntax => true @@ -914,7 +1013,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Dispatches a PatternSyntax inside a tuple context to the appropriate de-sugaring. */ - protected def astsForPatternInTupleContext( + private def astsForPatternInTupleContext( pattern: PatternSyntax, subjectAst: Ast, subjectCode: String, @@ -967,7 +1066,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * - `DeclReferenceExprSyntax` (`a` in `case let (a, b):`) * - `PatternExprSyntax(ValueBindingPatternSyntax(IdentifierPatternSyntax))` (`var a` in `case (var a, var b):`) */ - protected def extractBindingName(expr: ExprSyntax): String = { + private def extractBindingName(expr: ExprSyntax): String = { expr match { case d: DeclReferenceExprSyntax => code(d) case p: PatternExprSyntax => @@ -1056,7 +1155,6 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { val condIdentNode = identifierNode(node, subjectTmpName, subjectTmpName, Defines.Tuple) scope.addVariableReference(subjectTmpName, condIdentNode, Defines.Tuple, EvaluationStrategies.BY_REFERENCE) val condAst = Ast(condIdentNode) - setOrderExplicitly(condAst, 1) val switchBlockNode = blockNode(node).order(2) scope.pushNewBlockScope(switchBlockNode) @@ -1073,15 +1171,10 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { blockAst(outerBlockNode, List(subjectAssignAst, switchAstResult)) } else { - // The semantics of switch statement children is partially defined by their order value. - // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 - // because the semantics of it is already indicated via the condition edge. - val switchNode = controlStructureNode(node, ControlStructureTypes.SWITCH, code(node)) - + val switchNode = controlStructureNode(node, ControlStructureTypes.SWITCH, code(node)) val switchExpressionAst = astForNode(node.subject) - setOrderExplicitly(switchExpressionAst, 1) - val blockNode_ = blockNode(node).order(2) + val blockNode_ = blockNode(node) scope.pushNewBlockScope(blockNode_) localAstParentStack.push(blockNode_) val casesAsts = cases.flatMap(astsForSwitchCase(_, None)) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala index c5aca3b607ae..07413bf8e3d9 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala @@ -2,16 +2,11 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* import io.joern.x2cpg -import io.joern.x2cpg.Ast -import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.nodes.NewJumpLabel -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewJumpLabel} import scala.annotation.unused @@ -561,11 +556,127 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForGuardStmtSyntax(node: GuardStmtSyntax): Ast = { - val code = this.code(node) - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) - val conditionAst = astForNode(node.conditions) - val thenAst = blockAst(blockNode(node), List.empty) - val elseAst = astForNode(node.body) + val code = this.code(node) + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) + + handleOptionalBindingConditions( + node.conditions.children, + onAllSimple = simpleBindings => astForGuardLetStmtSyntax(node, ifNode, simpleBindings), + onMixed = + (simpleBindings, tupleBindings) => astForGuardLetStmtSyntaxMixed(node, ifNode, simpleBindings, tupleBindings), + onPartial = (simpleBindings, tupleBindings, otherConditions) => + astForGuardLetStmtSyntaxPartial(node, ifNode, simpleBindings, tupleBindings, otherConditions), + onStandard = () => { + val conditionAst = astForNode(node.conditions) + val thenAst = blockAst(blockNode(node), List.empty) + val elseAst = astForNode(node.body) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) + } + ) + } + + /** Handles Swift optional binding (guard-let) constructs. + * + * De-sugars `guard let x = foo() else { exit }` into: + * + * Condition: { let 0 = foo(); 0 != nil } + * + * Then block: { let x = 0 } + * + * Else block: { exit } + * + * For multiple bindings `guard let a = foo(), let b = bar() else { exit }`: + * + * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * + * Then block: { a = 0; b = 1 } + * + * For mixed cases with/without initializers `guard let a = foo(), let b else { exit }`: + * + * Condition: { 0 = foo(); 0 != nil && b != nil } + * + * Then block: { a = 0 } + */ + private def astForGuardLetStmtSyntax( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + optionalBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + + // For guard, the then block contains the unwrapped bindings in the parent scope + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `guard let a = foo(), let (b, c) = bar() else { exit }` into: + * + * Condition: { 0 = foo(); 0 != nil } + * + * Then block: { let a = 0; let (b, c) = bar() } + */ + private def astForGuardLetStmtSyntaxMixed( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `guard let a = foo(), someCondition else { exit }` into: + * + * Condition: { 0 = foo(); 0 != nil && someCondition } + * + * Then block: { let a = 0 } + */ + private def astForGuardLetStmtSyntaxPartial( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) } @@ -585,11 +696,9 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { private def astForRepeatStmtSyntax(node: RepeatStmtSyntax): Ast = { val code = this.code(node) // In Swift, a repeat-while loop is semantically the same as a C do-while loop - val doNode = controlStructureNode(node, ControlStructureTypes.DO, code) - val conditionAst = astForNode(node.condition) - val bodyAst = astForNode(node.body) - setOrderExplicitly(conditionAst, 1) - setOrderExplicitly(bodyAst, 2) + val doNode = controlStructureNode(node, ControlStructureTypes.DO, code) + val conditionAst = astForNode(node.condition) + val bodyAst = astForNode(node.body) val astWithChildren = controlStructureAst(doNode, Option(conditionAst), Seq(bodyAst), placeConditionLast = true) bodyAst.root match { case Some(bodyRoot) => astWithChildren.withDoBodyEdge(doNode, bodyRoot) @@ -624,13 +733,113 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForWhileStmtSyntax(node: WhileStmtSyntax): Ast = { - val code = this.code(node) - val conditionAst = astForNode(node.conditions) - val bodyAst = astForNode(node.body) + val code = this.code(node) + + handleOptionalBindingConditions( + node.conditions.children, + onAllSimple = simpleBindings => astForWhileLetStmtSyntax(node, simpleBindings), + onMixed = (simpleBindings, tupleBindings) => astForWhileLetStmtSyntaxMixed(node, simpleBindings, tupleBindings), + onPartial = (simpleBindings, tupleBindings, otherConditions) => + astForWhileLetStmtSyntaxPartial(node, simpleBindings, tupleBindings, otherConditions), + onStandard = () => { + val conditionAst = astForNode(node.conditions) + val bodyAst = astForNode(node.body) + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code), + lineNumber = line(node), + columnNumber = column(node) + ) + } + ) + } + + /** Handles Swift optional binding (while-let) constructs. + * + * De-sugars `while let item = iterator.next() { body }` into: + * + * Condition: { let 0 = iterator.next(); 0 != nil } + * + * Loop body: { let item = 0; body } + * + * For multiple bindings `while let a = foo(), let b = bar() { body }`: + * + * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * + * Loop body: { a = 0; b = 1; body } + * + * For mixed cases with/without initializers `while let a = foo(), let b { body }`: + * + * Condition: { 0 = foo(); 0 != nil && b != nil } + * + * Loop body: { a = 0; body } + */ + private def astForWhileLetStmtSyntax( + node: WhileStmtSyntax, + optionalBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val bodyAst = buildBodyWithUnwrapping(node.body, node.body.statements.children, bindingInfos) + + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code(node)), + lineNumber = line(node), + columnNumber = column(node) + ) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `while let a = foo(), let (b, c) = bar() { body }` into: + * + * Condition: { 0 = foo(); 0 != nil } + * + * Loop body: { let a = 0; let (b, c) = bar(); body } + */ + private def astForWhileLetStmtSyntaxMixed( + node: WhileStmtSyntax, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val bodyAst = buildBodyWithUnwrapping(node.body, tupleBindings ++ node.body.statements.children, bindingInfos) + whileAst( Option(conditionAst), Seq(bodyAst), - code = Option(code), + code = Option(code(node)), + lineNumber = line(node), + columnNumber = column(node) + ) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `while let a = foo(), someCondition { body }` into: + * + * Condition: { 0 = foo(); 0 != nil && someCondition } + * + * Loop body: { let a = 0; body } + */ + private def astForWhileLetStmtSyntaxPartial( + node: WhileStmtSyntax, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + val bodyAst = buildBodyWithUnwrapping(node.body, tupleBindings ++ node.body.statements.children, bindingInfos) + + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code(node)), lineNumber = line(node), columnNumber = column(node) ) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala index 41748c633180..08336eaf8827 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala @@ -1,7 +1,6 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* -import io.joern.x2cpg import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* @@ -9,10 +8,6 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Opera trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def setOrderExplicitly(ast: Ast, order: Int): Unit = { - ast.root.foreach { case expr: ExpressionNew => expr.order = order } - } - protected def codeOf(node: NewNode): String = node match { case astNodeNew: AstNodeNew => astNodeNew.code case _ => "" diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index 854266e6f441..139d8ae78e81 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -43,11 +43,22 @@ class AvailabilityQueryTests extends SwiftSrc2CpgSuite { val cpg = code("if let _ = Optional(5), #unavailable(OSX 10.52, *) {}") val List(ifControlStructure) = cpg.controlStructure.isIf.l ifControlStructure.whenTrue.astChildren shouldBe empty - val List(assignment, unavailable) = ifControlStructure.condition.astChildren.isCall.l - assignment.code shouldBe "let _ = Optional(5)" - assignment.name shouldBe Operators.assignment + + // Condition is desugared: { 0 = Optional(5); 0 != nil && #unavailable(...) } + val List(condBlock) = ifControlStructure.condition.isBlock.l + val List(tmpLocal) = condBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + + val List(assignment) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assignment.code shouldBe s"$tmpName = Optional(5)" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(nilCheck, unavailable) = andCheck.argument.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"$tmpName != nil" unavailable.code shouldBe "#unavailable(OSX 10.52, *)" unavailable.name shouldBe "#unavailable" + ifControlStructure.whenFalse shouldBe empty } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index 52f99a6f5309..b8a5ee65c4eb 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -3,9 +3,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite - import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class GuardTests extends SwiftSrc2CpgSuite { @@ -116,7 +114,8 @@ class GuardTests extends SwiftSrc2CpgSuite { |} |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkAge").block.l - methodBlock.local.name.l shouldBe List("myAge", "age") + // After desugaring: age in method block, 0 in condition block, myAge in then block + methodBlock.local.name.l should contain("age") val List(call) = methodBlock.astChildren.isCall.l call.order shouldBe 1 call.code shouldBe "var age: Int? = 22" @@ -124,12 +123,15 @@ class GuardTests extends SwiftSrc2CpgSuite { guardIf.order shouldBe 2 guardIf.code should startWith("guard let myAge = age") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("let myAge = age") - guardIf.whenTrue.code.l shouldBe List("print(\"My age is \\(myAge)\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Age is undefined\")", - "return" - ) + // Condition is desugared to block with temp assignment and nil check + val condBlock = guardIf.condition.isBlock.l + condBlock should not be empty + condBlock.head.local.name.l.exists(_.startsWith("")) shouldBe true + // Then block should have myAge local + val thenBlock = guardIf.whenTrue.isBlock.l + thenBlock should not be empty + thenBlock.head.local.name.l should contain("myAge") + guardIf.whenFalse.astChildren.code.l shouldBe List("print(\"Age is undefined\")", "return") } "testGuard6" in { @@ -208,6 +210,188 @@ class GuardTests extends SwiftSrc2CpgSuite { guardIfB.whenFalse.astChildren.code.l shouldBe List("print(\"else b\")", "return") } + "testGuardLet" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | guard let value = optionalValue else { + | return + | } + | print(value) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { 0 = optionalValue; 0 != nil } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(tmpLocal) = condBlock.astChildren.isLocal.l + tmpLocal.name should startWith("") + val tmpName = tmpLocal.name + + val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + condAssign.code shouldBe s"$tmpName = optionalValue" + condAssign.argument(1).code shouldBe tmpName + condAssign.argument(2).code shouldBe "optionalValue" + + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"$tmpName != nil" + condCheck.argument(1).code shouldBe tmpName + condCheck.argument(2).code shouldBe "nil" + + // Then block: { let value = 0 } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(valueLocal) = thenBlock.astChildren.isLocal.l + valueLocal.name shouldBe "value" + + val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + thenAssign.code shouldBe s"value = $tmpName" + thenAssign.argument(1).code shouldBe "value" + thenAssign.argument(2).code shouldBe tmpName + } + + "testGuardLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | guard let optionalValue else { + | return + | } + | print(optionalValue) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // For "guard let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = guardIf.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Then branch: empty block (no new local or assignment for optionalValue) + val thenNodes = guardIf.whenTrue.l + thenNodes should not be empty + + // Verify no new local named "optionalValue" was created in the then branch + val localsInThen = thenNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) + localsInThen shouldBe empty + + // Verify no assignment to optionalValue in the then branch + val assignmentsInThen = thenNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) + val optionalValueAssignments = assignmentsInThen.filter(_.argument(1).code == "optionalValue") + optionalValueAssignments shouldBe empty + } + + "testGuardLetMultipleBindings" in { + val cpg = code(""" + |func test() { + | guard let a = foo(), let b = bar() else { + | return + | } + | print(a, b) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(tmp0, tmp1) = condBlock.astChildren.isLocal.l + tmp0.name shouldBe "0" + tmp1.name shouldBe "1" + + val List(tmp0Assign, tmp1Assign) = condBlock.astChildren.assignment.l + tmp0Assign.code shouldBe s"${tmp0.name} = foo()" + tmp1Assign.code shouldBe s"${tmp1.name} = bar()" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + tmp0Check.code shouldBe s"${tmp0.name} != nil" + tmp1Check.code shouldBe s"${tmp1.name} != nil" + + // Then block: { let a = 0; let b = 1 } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + thenBlock.astChildren.isLocal.name.sorted shouldBe List("a", "b") + + val List(aAssignment, bAssignment) = thenBlock.astChildren.assignment.l + aAssignment.code shouldBe s"a = ${tmp0.name}" + bAssignment.code shouldBe s"b = ${tmp1.name}" + } + + "testGuardLetMixedWithAndWithoutInitializer" in { + val cpg = code(""" + |func test(existing: Int?) { + | guard let a = foo(), let existing else { + | return + | } + | print(a, existing) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { 0 = foo(); 0 != nil && existing != nil } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(tmpLocal) = condBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + tmpName should startWith("") + + val List(tmpAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + tmpAssign.code shouldBe s"$tmpName = foo()" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(tmpCheck) = andCheck.arguments(1).isCall.nameExact(Operators.notEquals).l + val List(existingCheck) = andCheck.arguments(2).isCall.nameExact(Operators.notEquals).l + tmpCheck.code shouldBe s"$tmpName != nil" + existingCheck.code shouldBe "existing != nil" + + // Then block: { let a = 0 } (no assignment for 'existing') + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + thenAssign.code shouldBe s"a = $tmpName" + } + + "testGuardLetWithOtherConditions" in { + val cpg = code(""" + |func test(flag: Bool) { + | guard let a = foo(), flag else { + | return + | } + | print(a) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { 0 = foo(); 0 != nil && flag } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(tmpLocal) = condBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val arguments = andCheck.argument.l + arguments should have size 2 + + // One should be the nil check, the other should be the flag identifier + val nilChecks = + arguments.collect { case c if c.isCall => c }.flatMap(_.ast.isCall.nameExact(Operators.notEquals).l) + val flags = arguments.collect { case i if i.isIdentifier => i }.flatMap(_.ast.isIdentifier.nameExact("flag").l) + + nilChecks.code.l should contain(s"$tmpName != nil") + flags should not be empty + } + } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala index f7d0ec63bbef..4b4916f57aea 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala @@ -18,12 +18,21 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { val List(globalBlock) = cpg.method.nameExact("").block.l val List(localA) = globalBlock.local.nameExact("a").l localA.typeFullName shouldBe "Swift.Int" - val List(localB) = globalBlock.local.nameExact("b").l - localB.typeFullName shouldBe "ANY" - val assigns = cpg.call.nameExact(Operators.assignment).code.l - assigns shouldBe List("let a: Int? = 1", "let b = a") + + // After desugaring, b is in the guard's then block, not the global block val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l guardIf.code should startWith("guard let b = a else") + + // Check that desugaring created the temp variable and nil check in condition + val condBlock = guardIf.condition.isBlock.l + condBlock should not be empty + condBlock.head.local.name.l.exists(_.startsWith("")) shouldBe true + + // Check that b local is in the then block + val thenBlock = guardIf.whenTrue.isBlock.l + thenBlock should not be empty + val List(localB) = thenBlock.head.local.nameExact("b").l + localB.typeFullName shouldBe "ANY" } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala index ce6d1a5ca332..820f1308248e 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala @@ -1,24 +1,51 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite - import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class StatementTests extends SwiftSrc2CpgSuite { "StatementTests" should { - "testIf" in { + "testIfLet" in { val cpg = code(""" - |if let baz {} - |if let self = self {} + |func test() { + | if let baz = optionalValue { + | print(baz) + | } + |} |""".stripMargin) - val ifs = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).code.l - ifs shouldBe List("if let baz {}", "if let self = self {}") - cpg.local.name.l should contain allOf ("baz", "self") - cpg.call.codeExact("let self = self").l should not be empty + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { 0 = optionalValue; 0 != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(tmpLocal) = condBlock.astChildren.isLocal.l + tmpLocal.name should startWith("") + val tmpName = tmpLocal.name + + val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + condAssign.code shouldBe s"$tmpName = optionalValue" + condAssign.argument(1).code shouldBe tmpName + condAssign.argument(2).code shouldBe "optionalValue" + + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"$tmpName != nil" + condCheck.argument(1).code shouldBe tmpName + condCheck.argument(2).code shouldBe "nil" + + // Then block: { let baz = 0; print(baz) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(bazLocal) = thenBlock.astChildren.isLocal.l + bazLocal.name shouldBe "baz" + + val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + thenAssign.code shouldBe s"baz = $tmpName" + thenAssign.argument(1).code shouldBe "baz" + thenAssign.argument(2).code shouldBe tmpName } "testDoCatch" in { @@ -120,6 +147,322 @@ class StatementTests extends SwiftSrc2CpgSuite { ifNode.code should startWith("if true") } + "testIfLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | if let optionalValue { + | print(optionalValue) + | } + |} + |""".stripMargin) + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // For "if let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = ifNode.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Then branch: just the original body, no new local or assignment for optionalValue + // The body might be a block or another structure depending on CodeBlockSyntax handling + val thenNodes = ifNode.whenTrue.l + thenNodes should not be empty + + // Verify no new local named "optionalValue" was created in the then branch + val localsInThen = thenNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) + localsInThen shouldBe empty + + // Verify no assignment to optionalValue in the then branch + val assignmentsInThen = thenNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) + val optionalValueAssignments = assignmentsInThen.filter(_.argument(1).code == "optionalValue") + optionalValueAssignments shouldBe empty + } + + "testWhileLet" in { + val cpg = code(""" + |func test() { + | while let item = iterator.next() { + | print(item) + | } + |} + |""".stripMargin) + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: desugared to { 0 = iterator.next(); 0 != nil } + val List(condBlock) = whileNode.condition.isBlock.l + + val List(tmpLocal) = condBlock.astChildren.isLocal.l + tmpLocal.name should startWith("") + val tmpName = tmpLocal.name + + val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + condAssign.code shouldBe s"$tmpName = iterator.next()" + condAssign.argument(1).code shouldBe tmpName + condAssign.argument(2).code shouldBe "iterator.next()" + + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"$tmpName != nil" + condCheck.argument(1).code shouldBe tmpName + condCheck.argument(2).code shouldBe "nil" + + // Loop body: { let item = 0; print(item) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + val List(itemLocal) = bodyBlock.astChildren.isLocal.l + itemLocal.name shouldBe "item" + + val List(bodyAssign) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).l + bodyAssign.code shouldBe s"item = $tmpName" + bodyAssign.argument(1).code shouldBe "item" + bodyAssign.argument(2).code shouldBe tmpName + } + + "testWhileLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | while let optionalValue { + | print(optionalValue) + | } + |} + |""".stripMargin) + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // For "while let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = whileNode.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Loop body: just the original body, no new local or assignment for optionalValue + val bodyNodes = whileNode.whenTrue.l + bodyNodes should not be empty + + // Verify no new local named "optionalValue" was created in the loop body + val localsInBody = bodyNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) + localsInBody shouldBe empty + + // Verify no assignment to optionalValue in the loop body + val assignmentsInBody = bodyNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) + val optionalValueAssignments = assignmentsInBody.filter(_.argument(1).code == "optionalValue") + optionalValueAssignments shouldBe empty + } + + "testIfLetMultiple" in { + val cpg = code(""" + |func test() { + | if let a = Optional(1), let b = Optional(2) { + | print(a, b) + | } + |} + |""".stripMargin) + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { 0 = Optional(1); 1 = Optional(2); 0 != nil && 1 != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l + tmp1Local.name should startWith("") + tmp2Local.name should startWith("") + val tmp1Name = tmp1Local.name + val tmp2Name = tmp2Local.name + + val List(assign1, assign2) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assign1.code shouldBe s"$tmp1Name = Optional(1)" + assign2.code shouldBe s"$tmp2Name = Optional(2)" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"$tmp1Name != nil" + check2.code shouldBe s"$tmp2Name != nil" + + // Then block: { a = 0; b = 1; print(a, b) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(aLocal, bLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + + val List(unwrapA, unwrapB) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp1Name" + unwrapB.code shouldBe s"b = $tmp2Name" + } + + "testIfLetMixed" in { + val cpg = code(""" + |func test(opt2: Int?) { + | if let a = Optional(1), let opt2 { + | print(a, opt2) + | } + |} + |""".stripMargin) + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { 0 = Optional(1); 0 != nil && opt2 != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(tmp1Local) = condBlock.astChildren.isLocal.l + val tmp1Name = tmp1Local.name + + val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assign1.code shouldBe s"$tmp1Name = Optional(1)" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"$tmp1Name != nil" + check2.code shouldBe "opt2 != nil" + + // Then block: { a = 0; print(a, opt2) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp1Name" + } + + "testWhileLetMultiple" in { + val cpg = code(""" + |func test() { + | while let a = iterator1.next(), let b = iterator2.next() { + | print(a, b) + | } + |} + |""".stripMargin) + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: { 0 = iterator1.next(); 1 = iterator2.next(); 0 != nil && 1 != nil } + val List(condBlock) = whileNode.condition.isBlock.l + + val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l + val tmp1Name = tmp1Local.name + val tmp2Name = tmp2Local.name + + val List(assign1, assign2) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assign1.code shouldBe s"$tmp1Name = iterator1.next()" + assign2.code shouldBe s"$tmp2Name = iterator2.next()" + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"$tmp1Name != nil" + check2.code shouldBe s"$tmp2Name != nil" + + // Loop body: { a = 0; b = 1; print(a, b) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + val List(aLocal, bLocal) = bodyBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + + val List(unwrapA, unwrapB) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp1Name" + unwrapB.code shouldBe s"b = $tmp2Name" + } + + "testIfLetMixedWithTuplePattern" in { + val cpg = code(""" + |func test() { + | if let a = foo(), let (b, c) = bar() { + | print(a, b, c) + | } + |} + |""".stripMargin) + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { 0 = foo(); 0 != nil } (tuple pattern excluded from condition) + val List(condBlock) = ifNode.condition.isBlock.l + + val List(tmp1Local) = condBlock.astChildren.isLocal.l + val tmp1Name = tmp1Local.name + + val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assign1.code shouldBe s"$tmp1Name = foo()" + assign1.argument(1).code shouldBe tmp1Name + assign1.argument(2).code shouldBe "foo()" + + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"$tmp1Name != nil" + check1.argument(1).code shouldBe tmp1Name + check1.argument(2).code shouldBe "nil" + + // Then block: { a = 0; let (b, c) = bar(); print(a, b, c) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + // First child: unwrapping assignment for 'a' + val List(aLocal) = thenBlock.astChildren.isLocal.nameExact("a").l + aLocal.name shouldBe "a" + + val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp1Name").l + unwrapA.argument(1).code shouldBe "a" + unwrapA.argument(2).code shouldBe tmp1Name + + // Second child: tuple binding block for (b, c) = bar() + // The tuple binding creates a nested block with tmp assignment + tuple destructuring + val List(tupleBindingBlock) = thenBlock.astChildren.isBlock.l + + // Inside the tuple binding block, there should be an assignment involving bar() + val List(barAssignment) = tupleBindingBlock.astChildren.isCall.nameExact(Operators.assignment).code(".*bar.*").l + barAssignment.code should include("bar()") + } + + "testWhileLetMixedWithTuplePattern" in { + val cpg = code(""" + |func test() { + | while let a = foo(), let (b, c) = bar() { + | print(a, b, c) + | } + |} + |""".stripMargin) + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: { 0 = foo(); 0 != nil } (tuple pattern excluded from condition) + val List(condBlock) = whileNode.condition.isBlock.l + + val List(tmp1Local) = condBlock.astChildren.isLocal.l + val tmp1Name = tmp1Local.name + + val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + assign1.code shouldBe s"$tmp1Name = foo()" + assign1.argument(1).code shouldBe tmp1Name + assign1.argument(2).code shouldBe "foo()" + + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"$tmp1Name != nil" + check1.argument(1).code shouldBe tmp1Name + check1.argument(2).code shouldBe "nil" + + // Loop body: { a = 0; let (b, c) = bar(); print(a, b, c) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + // First child: unwrapping assignment for 'a' + val List(aLocal) = bodyBlock.astChildren.isLocal.nameExact("a").l + aLocal.name shouldBe "a" + + val List(unwrapA) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp1Name").l + unwrapA.argument(1).code shouldBe "a" + unwrapA.argument(2).code shouldBe tmp1Name + + // Second child: tuple binding block for (b, c) = bar() + // The tuple binding creates a nested block with tmp assignment + tuple destructuring + val List(tupleBindingBlock) = bodyBlock.astChildren.isBlock.l + + // Inside the tuple binding block, there should be an assignment involving bar() + val List(barAssignment) = tupleBindingBlock.astChildren.isCall.nameExact(Operators.assignment).code(".*bar.*").l + barAssignment.code should include("bar()") + } + } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala index daeed30abc40..af636bc19b1b 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala @@ -68,7 +68,8 @@ class WhileTests extends SwiftSrc2CpgSuite { inside(controlStruct.condition.l) { case List(cndNode: Literal) => cndNode.code shouldBe "true" } - controlStruct.whenTrue.code.l shouldBe List("""print("Endless Loop")""") + controlStruct.condition.code.l shouldBe List("true") + controlStruct.doBodyOut.code.l shouldBe List("""print("Endless Loop")""") controlStruct.lineNumber shouldBe Some(2) controlStruct.columnNumber shouldBe Some(1) } @@ -94,7 +95,8 @@ class WhileTests extends SwiftSrc2CpgSuite { n.name shouldBe "n" n.order shouldBe 2 } - controlStruct.whenTrue.astChildren.code.l shouldBe List("print(i)", "i = i + 1") + controlStruct.condition.code.l shouldBe List("i <= n") + controlStruct.doBodyOut.astChildren.code.l shouldBe List("print(i)", "i = i + 1") controlStruct.lineNumber shouldBe Some(2) controlStruct.columnNumber shouldBe Some(1) } From 0a6564064795e9d0702a1428ee3a2fe832285022 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Thu, 28 May 2026 15:20:04 +0200 Subject: [PATCH 2/5] Addressed review feedback --- .../astcreation/AstCreatorHelper.scala | 108 +++++++----------- .../astcreation/AstForExprSyntaxCreator.scala | 6 +- .../passes/ast/AvailabilityQueryTests.scala | 10 +- .../swiftsrc2cpg/passes/ast/GuardTests.scala | 67 +++++------ .../passes/ast/GuardTopLevelTests.scala | 20 ++-- .../passes/ast/StatementTests.scala | 86 +++++++------- 6 files changed, 130 insertions(+), 167 deletions(-) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index 3af9bb4e5a0b..b1ce167719c2 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -154,7 +154,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val ifNode = controlStructureNode(guardStmt, ControlStructureTypes.IF, code) // Apply optional binding desugaring for guard let - // First, create the block that will hold the unwrapped variables + // Create the block that will hold the unwrapped variables (blockNode argument is only used for location info) val thenBlockNode = if (elementsAfterGuard.nonEmpty) blockNode(elementsAfterGuard.head) else blockNode(guardStmt) scope.pushNewBlockScope(thenBlockNode) localAstParentStack.push(thenBlockNode) @@ -562,13 +562,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } /** Information about an optional binding for desugaring if-let/while-let constructs. + * + * Contains the CPG-level names for optional binding desugaring. The localName is the unwrapped variable name that + * appears in the then/body block. The tmpName (when present) holds the optional value from the initializer and is + * used for the nil check in the condition block, ensuring single evaluation of the initializer expression. * * @param localName - * The name of the local variable to be created in the then/body block + * CPG name for the unwrapped variable in the then/body block (e.g., "a" in `if let a = foo()`) * @param tmpName - * Optional temp variable name (Some if has initializer, None otherwise) + * CPG name for temporary holding the optional value in condition (e.g., "0" in the nil check) * @param binding - * The original OptionalBindingConditionSyntax node + * The source-level OptionalBindingConditionSyntax node + * @param isWildcard + * True if the pattern is a wildcard (e.g., `if let _ = foo()`), requiring generated name */ protected case class BindingInfo( localName: String, @@ -577,13 +583,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As isWildcard: Boolean ) - /** Collects binding information from optional binding conditions for desugaring. - * - * @param bindings - * The optional binding condition nodes - * @return - * Sequence of BindingInfo with local name, optional temp name, binding node, and wildcard flag - */ protected def collectBindingInfos(bindings: Seq[OptionalBindingConditionSyntax]): Seq[BindingInfo] = { bindings.map { binding => val (localName, isWildcard) = binding.pattern match { @@ -595,21 +594,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } - /** Analyzes conditions to determine optional binding desugaring strategy. - * - * @param conditions - * The condition elements to analyze - * @param onAllSimple - * Handler for all simple optional bindings - * @param onMixed - * Handler for mixed simple and tuple optional bindings - * @param onPartial - * Handler for partial desugaring (simple bindings + other conditions) - * @param onStandard - * Handler for standard conditions (no simple bindings to desugar) - * @return - * The result from the appropriate handler - */ protected def handleOptionalBindingConditions[T]( conditions: Iterable[ConditionElementSyntax], onAllSimple: Seq[OptionalBindingConditionSyntax] => T, @@ -690,51 +674,57 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val hasAnyInitializer = bindingInfos.exists(_.tmpName.isDefined) if (hasAnyInitializer) { - // Create condition block with assignments and checks val condBlockNode = blockNode(node) scope.pushNewBlockScope(condBlockNode) localAstParentStack.push(condBlockNode) - val assignmentAsts = bindingInfos.flatMap { info => - info.tmpName.map { tmpName => + bindingInfos.foreach { info => + info.tmpName.foreach { tmpName => val tmpLocalNode = localNode(info.binding, tmpName, tmpName, Defines.Any).order(0) diffGraph.addEdge(condBlockNode, tmpLocalNode, EdgeTypes.AST) scope.addVariable(tmpName, tmpLocalNode, Defines.Any, VariableScopeManager.ScopeType.BlockScope) - - val tmpIdentNode = identifierNode(info.binding, tmpName, tmpName, Defines.Any) - scope.addVariableReference(tmpName, tmpIdentNode, Defines.Any, EvaluationStrategies.BY_REFERENCE) - val initAst = astForNode(info.binding.initializer.get.value) - createAssignmentCallAst( - info.binding, - Ast(tmpIdentNode), - initAst, - s"$tmpName = ${code(info.binding.initializer.get.value)}" - ) } } - // Create nil checks for all bindings val nilCheckAsts = bindingInfos.map { info => - val (checkName, checkNode) = info.tmpName match { + info.tmpName match { case Some(tmpName) => + val tmpIdentNode = identifierNode(info.binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdentNode, Defines.Any, EvaluationStrategies.BY_REFERENCE) + val initAst = astForNode(info.binding.initializer.get.value) + val assignAst = createAssignmentCallAst( + info.binding, + Ast(tmpIdentNode), + initAst, + s"$tmpName = ${code(info.binding.initializer.get.value)}" + ) + val tmpIdentForCheck = identifierNode(info.binding, tmpName, tmpName, Defines.Any) scope.addVariableReference(tmpName, tmpIdentForCheck, Defines.Any, EvaluationStrategies.BY_REFERENCE) - (tmpName, Ast(tmpIdentForCheck)) + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"($tmpName = ${code(info.binding.initializer.get.value)}) != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(assignAst, Ast(nilNode))) + case None => - (info.localName, astForNode(info.binding.pattern)) + val patternAst = astForNode(info.binding.pattern) + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"${info.localName} != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(patternAst, Ast(nilNode))) } - val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) - val checkCallNode = createStaticCallNode( - info.binding, - s"$checkName != nil", - Operators.notEquals, - Operators.notEquals, - Defines.Bool - ) - callAst(checkCallNode, List(checkNode, Ast(nilNode))) } - // Combine nil checks with additional conditions using && val additionalConditionAsts = additionalConditions.map(condElem => astForNode(condElem.condition)) val allChecks = nilCheckAsts ++ additionalConditionAsts val combinedCheckAst = combineNilChecksWithAnd(node, allChecks) @@ -742,9 +732,8 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As scope.popScope() localAstParentStack.pop() - blockAst(condBlockNode, assignmentAsts.toList :+ combinedCheckAst) + blockAst(condBlockNode, List(combinedCheckAst)) } else { - // All bindings have no initializer: create combined != nil check without block val nilCheckAsts = bindingInfos.map { info => val patternAst = astForNode(info.binding.pattern) val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) @@ -825,19 +814,10 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As blockAst(bodyBlockNode, unwrapAsts.toList ++ bodyAsts) } } else { - // All bindings have no initializer: just use original body astForNode(bodyNode) } } - /** Builds unwrap assignment ASTs for optional binding info without wrapping in a body block. Used for guard let where - * bindings need to be in parent scope. - * - * @param bindingInfos - * The binding information - * @return - * List of assignment ASTs for unwrapping - */ protected def buildUnwrapAssignments(bindingInfos: Seq[BindingInfo]): List[Ast] = { val bindingsWithInitializer = bindingInfos.filter(info => info.tmpName.isDefined && !info.isWildcard) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index 69fdf46d3eed..cc39a977a048 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -529,19 +529,19 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `if let baz = foo() { body }` into: * - * Condition: { let 0 = foo(); 0 != nil } + * Condition: { let 0; (0 = foo()) != nil } * * Then block: { let baz = 0; body } * * For multiple bindings `if let a = foo(), let b = bar() { body }`: * - * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * Condition: { let 0; let 1; (0 = foo()) != nil && (1 = bar()) != nil } * * Then block: { a = 0; b = 1; body } * * For mixed cases with/without initializers `if let a = foo(), let b { body }`: * - * Condition: { 0 = foo(); 0 != nil && b != nil } + * Condition: { let 0; (0 = foo()) != nil && b != nil } * * Then block: { a = 0; body } */ diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index 139d8ae78e81..3c9a4b345d9c 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -44,21 +44,21 @@ class AvailabilityQueryTests extends SwiftSrc2CpgSuite { val List(ifControlStructure) = cpg.controlStructure.isIf.l ifControlStructure.whenTrue.astChildren shouldBe empty - // Condition is desugared: { 0 = Optional(5); 0 != nil && #unavailable(...) } + // Condition is desugared: { let 0; (0 = Optional(5)) != nil && #unavailable(...) } val List(condBlock) = ifControlStructure.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l val tmpName = tmpLocal.name - val List(assignment) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l - assignment.code shouldBe s"$tmpName = Optional(5)" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(nilCheck, unavailable) = andCheck.argument.isCall.l nilCheck.name shouldBe Operators.notEquals - nilCheck.code shouldBe s"$tmpName != nil" + nilCheck.code shouldBe s"($tmpName = Optional(5)) != nil" unavailable.code shouldBe "#unavailable(OSX 10.52, *)" unavailable.name shouldBe "#unavailable" + val List(assignment) = nilCheck.argument.assignment.l + assignment.code shouldBe s"$tmpName = Optional(5)" + ifControlStructure.whenFalse shouldBe empty } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index b8a5ee65c4eb..d448e49712eb 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -4,6 +4,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class GuardTests extends SwiftSrc2CpgSuite { @@ -222,23 +223,21 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: desugared to { 0 = optionalValue; 0 != nil } + // Condition: desugared to { let 0; (0 = optionalValue) != nil } val List(condBlock) = guardIf.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l tmpLocal.name should startWith("") val tmpName = tmpLocal.name - val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"($tmpName = optionalValue) != nil" + + val List(condAssign) = condCheck.argument.isCall.nameExact(Operators.assignment).l condAssign.code shouldBe s"$tmpName = optionalValue" condAssign.argument(1).code shouldBe tmpName condAssign.argument(2).code shouldBe "optionalValue" - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - condCheck.code shouldBe s"$tmpName != nil" - condCheck.argument(1).code shouldBe tmpName - condCheck.argument(2).code shouldBe "nil" - // Then block: { let value = 0 } val List(thenBlock) = guardIf.whenTrue.isBlock.l @@ -271,18 +270,12 @@ class GuardTests extends SwiftSrc2CpgSuite { condCheck.argument(1).code shouldBe "optionalValue" condCheck.argument(2).code shouldBe "nil" - // Then branch: empty block (no new local or assignment for optionalValue) - val thenNodes = guardIf.whenTrue.l - thenNodes should not be empty - - // Verify no new local named "optionalValue" was created in the then branch - val localsInThen = thenNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) - localsInThen shouldBe empty - - // Verify no assignment to optionalValue in the then branch - val assignmentsInThen = thenNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) - val optionalValueAssignments = assignmentsInThen.filter(_.argument(1).code == "optionalValue") - optionalValueAssignments shouldBe empty + // Then branch: print(optionalValue) but no new local or assignment for optionalValue + inside(guardIf.whenTrue.l) { case List(thenBlock) => + thenBlock.astChildren.isLocal.nameExact("optionalValue").l shouldBe empty + val assignments = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + assignments.filter(_.argument(1).code == "optionalValue") shouldBe empty + } } "testGuardLetMultipleBindings" in { @@ -304,14 +297,10 @@ class GuardTests extends SwiftSrc2CpgSuite { tmp0.name shouldBe "0" tmp1.name shouldBe "1" - val List(tmp0Assign, tmp1Assign) = condBlock.astChildren.assignment.l - tmp0Assign.code shouldBe s"${tmp0.name} = foo()" - tmp1Assign.code shouldBe s"${tmp1.name} = bar()" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - tmp0Check.code shouldBe s"${tmp0.name} != nil" - tmp1Check.code shouldBe s"${tmp1.name} != nil" + tmp0Check.code shouldBe s"(${tmp0.name} = foo()) != nil" + tmp1Check.code shouldBe s"(${tmp1.name} = bar()) != nil" // Then block: { let a = 0; let b = 1 } val List(thenBlock) = guardIf.whenTrue.isBlock.l @@ -335,22 +324,22 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { 0 = foo(); 0 != nil && existing != nil } + // Condition: { let 0; (0 = foo()) != nil && existing != nil } val List(condBlock) = guardIf.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l val tmpName = tmpLocal.name tmpName should startWith("") - val List(tmpAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l - tmpAssign.code shouldBe s"$tmpName = foo()" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(tmpCheck) = andCheck.arguments(1).isCall.nameExact(Operators.notEquals).l val List(existingCheck) = andCheck.arguments(2).isCall.nameExact(Operators.notEquals).l - tmpCheck.code shouldBe s"$tmpName != nil" + tmpCheck.code shouldBe s"($tmpName = foo()) != nil" existingCheck.code shouldBe "existing != nil" + val List(tmpAssign) = tmpCheck.argument.isCall.nameExact(Operators.assignment).l + tmpAssign.code shouldBe s"$tmpName = foo()" + // Then block: { let a = 0 } (no assignment for 'existing') val List(thenBlock) = guardIf.whenTrue.isBlock.l @@ -380,16 +369,14 @@ class GuardTests extends SwiftSrc2CpgSuite { val tmpName = tmpLocal.name val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l - val arguments = andCheck.argument.l - arguments should have size 2 - - // One should be the nil check, the other should be the flag identifier - val nilChecks = - arguments.collect { case c if c.isCall => c }.flatMap(_.ast.isCall.nameExact(Operators.notEquals).l) - val flags = arguments.collect { case i if i.isIdentifier => i }.flatMap(_.ast.isIdentifier.nameExact("flag").l) - - nilChecks.code.l should contain(s"$tmpName != nil") - flags should not be empty + inside(andCheck.argument.l) { + case List(nilCheck: Call, flag: Identifier) => + nilCheck.code shouldBe s"($tmpName = foo()) != nil" + flag.name shouldBe "flag" + case List(flag: Identifier, nilCheck: Call) => + nilCheck.code shouldBe s"($tmpName = foo()) != nil" + flag.name shouldBe "flag" + } } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala index 4b4916f57aea..b3cd67e47ca2 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala @@ -14,6 +14,7 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { val cpg = code(""" |let a: Int? = 1 |guard let b = a else {} + |print(b) |""".stripMargin) val List(globalBlock) = cpg.method.nameExact("").block.l val List(localA) = globalBlock.local.nameExact("a").l @@ -24,15 +25,18 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { guardIf.code should startWith("guard let b = a else") // Check that desugaring created the temp variable and nil check in condition - val condBlock = guardIf.condition.isBlock.l - condBlock should not be empty - condBlock.head.local.name.l.exists(_.startsWith("")) shouldBe true - - // Check that b local is in the then block - val thenBlock = guardIf.whenTrue.isBlock.l - thenBlock should not be empty - val List(localB) = thenBlock.head.local.nameExact("b").l + val List(condBlock) = guardIf.condition.isBlock.l + condBlock.local.name.l.exists(_.startsWith("")) shouldBe true + + // Check that b local is in the then block along with code that follows the guard + val List(thenBlock) = guardIf.whenTrue.isBlock.l + val List(localB) = thenBlock.local.nameExact("b").l localB.typeFullName shouldBe "ANY" + + // Verify the print(b) call is also in the then block (code following guard) + val List(printCall) = thenBlock.astChildren.isCall.nameExact("print").l + val List(bArg) = printCall.argument.isIdentifier.l + bArg.name shouldBe "b" } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala index 820f1308248e..fc44ae64612f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala @@ -19,23 +19,21 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: desugared to { 0 = optionalValue; 0 != nil } + // Condition: desugared to { let 0; (0 = optionalValue) != nil } val List(condBlock) = ifNode.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l tmpLocal.name should startWith("") val tmpName = tmpLocal.name - val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"($tmpName = optionalValue) != nil" + + val List(condAssign) = condCheck.argument.assignment.l condAssign.code shouldBe s"$tmpName = optionalValue" condAssign.argument(1).code shouldBe tmpName condAssign.argument(2).code shouldBe "optionalValue" - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - condCheck.code shouldBe s"$tmpName != nil" - condCheck.argument(1).code shouldBe tmpName - condCheck.argument(2).code shouldBe "nil" - // Then block: { let baz = 0; print(baz) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -192,23 +190,21 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: desugared to { 0 = iterator.next(); 0 != nil } + // Condition: desugared to { let 0; (0 = iterator.next()) != nil } val List(condBlock) = whileNode.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l tmpLocal.name should startWith("") val tmpName = tmpLocal.name - val List(condAssign) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"($tmpName = iterator.next()) != nil" + + val List(condAssign) = condCheck.argument.assignment.l condAssign.code shouldBe s"$tmpName = iterator.next()" condAssign.argument(1).code shouldBe tmpName condAssign.argument(2).code shouldBe "iterator.next()" - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - condCheck.code shouldBe s"$tmpName != nil" - condCheck.argument(1).code shouldBe tmpName - condCheck.argument(2).code shouldBe "nil" - // Loop body: { let item = 0; print(item) } val List(bodyBlock) = whileNode.whenTrue.isBlock.l @@ -265,7 +261,7 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { 0 = Optional(1); 1 = Optional(2); 0 != nil && 1 != nil } + // Condition: { let 0; let 1; (0 = Optional(1)) != nil && (1 = Optional(2)) != nil } val List(condBlock) = ifNode.condition.isBlock.l val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l @@ -274,14 +270,14 @@ class StatementTests extends SwiftSrc2CpgSuite { val tmp1Name = tmp1Local.name val tmp2Name = tmp2Local.name - val List(assign1, assign2) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l - assign1.code shouldBe s"$tmp1Name = Optional(1)" - assign2.code shouldBe s"$tmp2Name = Optional(2)" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"$tmp1Name != nil" - check2.code shouldBe s"$tmp2Name != nil" + check1.code shouldBe s"($tmp1Name = Optional(1)) != nil" + check2.code shouldBe s"($tmp2Name = Optional(2)) != nil" + + val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l + assign1.code shouldBe s"$tmp1Name = Optional(1)" + assign2.code shouldBe s"$tmp2Name = Optional(2)" // Then block: { a = 0; b = 1; print(a, b) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -306,20 +302,20 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { 0 = Optional(1); 0 != nil && opt2 != nil } + // Condition: { let 0; (0 = Optional(1)) != nil && opt2 != nil } val List(condBlock) = ifNode.condition.isBlock.l val List(tmp1Local) = condBlock.astChildren.isLocal.l val tmp1Name = tmp1Local.name - val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l - assign1.code shouldBe s"$tmp1Name = Optional(1)" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"$tmp1Name != nil" + check1.code shouldBe s"($tmp1Name = Optional(1)) != nil" check2.code shouldBe "opt2 != nil" + val List(assign1) = check1.argument.assignment.l + assign1.code shouldBe s"$tmp1Name = Optional(1)" + // Then block: { a = 0; print(a, opt2) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -341,21 +337,21 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: { 0 = iterator1.next(); 1 = iterator2.next(); 0 != nil && 1 != nil } + // Condition: { let 0; let 1; (0 = iterator1.next()) != nil && (1 = iterator2.next()) != nil } val List(condBlock) = whileNode.condition.isBlock.l val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l val tmp1Name = tmp1Local.name val tmp2Name = tmp2Local.name - val List(assign1, assign2) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l - assign1.code shouldBe s"$tmp1Name = iterator1.next()" - assign2.code shouldBe s"$tmp2Name = iterator2.next()" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"$tmp1Name != nil" - check2.code shouldBe s"$tmp2Name != nil" + check1.code shouldBe s"($tmp1Name = iterator1.next()) != nil" + check2.code shouldBe s"($tmp2Name = iterator2.next()) != nil" + + val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l + assign1.code shouldBe s"$tmp1Name = iterator1.next()" + assign2.code shouldBe s"$tmp2Name = iterator2.next()" // Loop body: { a = 0; b = 1; print(a, b) } val List(bodyBlock) = whileNode.whenTrue.isBlock.l @@ -380,22 +376,20 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { 0 = foo(); 0 != nil } (tuple pattern excluded from condition) + // Condition: { let 0; (0 = foo()) != nil } (tuple pattern excluded from condition) val List(condBlock) = ifNode.condition.isBlock.l val List(tmp1Local) = condBlock.astChildren.isLocal.l val tmp1Name = tmp1Local.name - val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp1Name = foo()) != nil" + + val List(assign1) = check1.argument.assignment.l assign1.code shouldBe s"$tmp1Name = foo()" assign1.argument(1).code shouldBe tmp1Name assign1.argument(2).code shouldBe "foo()" - val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"$tmp1Name != nil" - check1.argument(1).code shouldBe tmp1Name - check1.argument(2).code shouldBe "nil" - // Then block: { a = 0; let (b, c) = bar(); print(a, b, c) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -427,22 +421,20 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: { 0 = foo(); 0 != nil } (tuple pattern excluded from condition) + // Condition: { let 0; (0 = foo()) != nil } (tuple pattern excluded from condition) val List(condBlock) = whileNode.condition.isBlock.l val List(tmp1Local) = condBlock.astChildren.isLocal.l val tmp1Name = tmp1Local.name - val List(assign1) = condBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp1Name = foo()) != nil" + + val List(assign1) = check1.argument.assignment.l assign1.code shouldBe s"$tmp1Name = foo()" assign1.argument(1).code shouldBe tmp1Name assign1.argument(2).code shouldBe "foo()" - val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"$tmp1Name != nil" - check1.argument(1).code shouldBe tmp1Name - check1.argument(2).code shouldBe "nil" - // Loop body: { a = 0; let (b, c) = bar(); print(a, b, c) } val List(bodyBlock) = whileNode.whenTrue.isBlock.l From aad1f85b84f1f1142e4a99cd11bfe6dfe45e9003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Thu, 28 May 2026 19:34:54 +0200 Subject: [PATCH 3/5] more fixes / test cleanups --- .../astcreation/AstCreatorHelper.scala | 3 - .../passes/ast/AvailabilityQueryTests.scala | 1 - .../swiftsrc2cpg/passes/ast/GuardTests.scala | 322 +++++++++++++----- .../passes/ast/GuardTopLevelTests.scala | 44 ++- 4 files changed, 274 insertions(+), 96 deletions(-) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index b1ce167719c2..dba67f865c4a 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -682,7 +682,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As info.tmpName.foreach { tmpName => val tmpLocalNode = localNode(info.binding, tmpName, tmpName, Defines.Any).order(0) diffGraph.addEdge(condBlockNode, tmpLocalNode, EdgeTypes.AST) - scope.addVariable(tmpName, tmpLocalNode, Defines.Any, VariableScopeManager.ScopeType.BlockScope) } } @@ -699,8 +698,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As s"$tmpName = ${code(info.binding.initializer.get.value)}" ) - val tmpIdentForCheck = identifierNode(info.binding, tmpName, tmpName, Defines.Any) - scope.addVariableReference(tmpName, tmpIdentForCheck, Defines.Any, EvaluationStrategies.BY_REFERENCE) val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) val checkCallNode = createStaticCallNode( info.binding, diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index 3c9a4b345d9c..f49020ff987a 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -4,7 +4,6 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class AvailabilityQueryTests extends SwiftSrc2CpgSuite { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index d448e49712eb..e3d2ff7a204e 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -19,12 +19,11 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("noConditionNoElse").block.l val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 1 guardIf.code shouldBe "guard {} else {}" guardIf.controlStructureType shouldBe ControlStructureTypes.IF guardIf.condition.code.l shouldBe List("0") - guardIf.whenTrue.astChildren.code.l shouldBe empty - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe empty + guardIf.whenTrue.astChildren shouldBe empty + methodBlock.astChildren.isControlStructure.whenFalse.astChildren shouldBe empty } "testGuard2" in { @@ -37,16 +36,26 @@ class GuardTests extends SwiftSrc2CpgSuite { | } | print(i) | i = i + 1 - |} + |} |""".stripMargin) val List(whileBlock) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).whenTrue.l val List(guardIf) = whileBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 1 - guardIf.code should startWith("guard i % 2 == 0") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("i % 2 == 0") - guardIf.whenTrue.astChildren.code.l shouldBe List("print(i)", "i = i + 1") - whileBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List("i = i + 1", "continue") + + val List(condition) = guardIf.condition.l + condition.code shouldBe "i % 2 == 0" + + val List(thenPrint, thenAssign) = guardIf.whenTrue.astChildren.isCall.l + thenPrint.code shouldBe "print(i)" + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe "i = i + 1" + + val List(elseAssign) = guardIf.whenFalse.astChildren.isCall.l + elseAssign.name shouldBe Operators.assignment + elseAssign.code shouldBe "i = i + 1" + + val List(elseContinue) = guardIf.whenFalse.astChildren.isControlStructure.l + elseContinue.code shouldBe "continue" } "testGuard3" in { @@ -62,18 +71,22 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkOddEven").block.l val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var number = 24" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard number % 2 == 0") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("number % 2 == 0") - guardIf.whenTrue.code.l shouldBe List("print(\"Even Number\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Odd Number\")", - "return" - ) + + val List(condition) = guardIf.condition.l + condition.code shouldBe "number % 2 == 0" + + val List(thenPrint) = guardIf.whenTrue.isCall.l + thenPrint.code shouldBe "print(\"Even Number\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.code shouldBe "print(\"Odd Number\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard4" in { @@ -89,18 +102,23 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkJobEligibility").block.l val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var age = 33" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard age >= 18, age <= 40") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.astChildren.code.l shouldBe List("age >= 18", "age <= 40") - guardIf.whenTrue.code.l shouldBe List("print(\"You are eligible for this job\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Not Eligible for Job\")", - "return" - ) + + val List(cond1, cond2) = guardIf.condition.astChildren.l + cond1.code shouldBe "age >= 18" + cond2.code shouldBe "age <= 40" + + val List(thenPrint) = guardIf.whenTrue.isCall.l + thenPrint.code shouldBe "print(\"You are eligible for this job\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.code shouldBe "print(\"Not Eligible for Job\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard5" in { @@ -116,23 +134,50 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkAge").block.l // After desugaring: age in method block, 0 in condition block, myAge in then block - methodBlock.local.name.l should contain("age") + // TODO(BUG): local is incorrectly appearing as direct child of method block + // It should only be in the condition block. For now, filter it out. + val List(ageLocal) = methodBlock.astChildren.isLocal.filterNot(_.name.startsWith("")).l + ageLocal.name shouldBe "age" + val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var age: Int? = 22" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard let myAge = age") guardIf.controlStructureType shouldBe ControlStructureTypes.IF + // Condition is desugared to block with temp assignment and nil check - val condBlock = guardIf.condition.isBlock.l - condBlock should not be empty - condBlock.head.local.name.l.exists(_.startsWith("")) shouldBe true + val List(condBlock) = guardIf.condition.isBlock.l + val List(tmpLocal) = condBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + tmpName should startWith("") + + val List(nilCheck) = condBlock.astChildren.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = age) != nil" + + val List(assignment) = nilCheck.argument.isCall.l + assignment.name shouldBe Operators.assignment + assignment.code shouldBe s"$tmpName = age" + + val List(nilLit) = nilCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" + // Then block should have myAge local - val thenBlock = guardIf.whenTrue.isBlock.l - thenBlock should not be empty - thenBlock.head.local.name.l should contain("myAge") - guardIf.whenFalse.astChildren.code.l shouldBe List("print(\"Age is undefined\")", "return") + val List(thenBlock) = guardIf.whenTrue.isBlock.l + val List(myAgeLocal) = thenBlock.astChildren.isLocal.l + myAgeLocal.name shouldBe "myAge" + + val List(myAgeAssign, printCall) = thenBlock.astChildren.isCall.l + myAgeAssign.name shouldBe Operators.assignment + myAgeAssign.code shouldBe s"myAge = $tmpName" + printCall.code shouldBe "print(\"My age is \\(myAge)\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.name shouldBe "print" + elsePrint.code shouldBe "print(\"Age is undefined\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard6" in { @@ -154,23 +199,38 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("multipleLinear").block.l val List(callA, callB) = methodBlock.astChildren.isCall.l - callA.order shouldBe 1 callA.code shouldBe "var a = true" - callB.order shouldBe 2 callB.code shouldBe "var b = true" + val List(guardIfA) = methodBlock.astChildren.isControlStructure.l - guardIfA.order shouldBe 3 - guardIfA.code should startWith("guard a else") guardIfA.controlStructureType shouldBe ControlStructureTypes.IF - guardIfA.condition.code.l shouldBe List("a") - guardIfA.whenFalse.astChildren.code.l shouldBe List("print(\"else a\")", "return") - guardIfA.whenTrue.astChildren.isCall.code.l shouldBe List("print(\"a\")") + + val List(condA) = guardIfA.condition.l + condA.code shouldBe "a" + + val List(elsePrintA) = guardIfA.whenFalse.astChildren.isCall.l + elsePrintA.code shouldBe "print(\"else a\")" + + val List(elseReturnA) = guardIfA.whenFalse.astChildren.isReturn.l + elseReturnA.code shouldBe "return" + + val List(thenPrintA) = guardIfA.whenTrue.astChildren.isCall.l + thenPrintA.code shouldBe "print(\"a\")" + val List(guardIfB) = guardIfA.whenTrue.astChildren.isControlStructure.l - guardIfB.code should startWith("guard b else") guardIfB.controlStructureType shouldBe ControlStructureTypes.IF - guardIfB.condition.code.l shouldBe List("b") - guardIfB.whenTrue.code.l shouldBe List("print(\"b\")") - guardIfB.whenFalse.astChildren.code.l shouldBe List("print(\"else b\")", "return") + + val List(condB) = guardIfB.condition.l + condB.code shouldBe "b" + + val List(thenPrintB) = guardIfB.whenTrue.isCall.l + thenPrintB.code shouldBe "print(\"b\")" + + val List(elsePrintB) = guardIfB.whenFalse.astChildren.isCall.l + elsePrintB.code shouldBe "print(\"else b\")" + + val List(elseReturnB) = guardIfB.whenFalse.astChildren.isReturn.l + elseReturnB.code shouldBe "return" } "testGuard7" in { @@ -192,23 +252,38 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("multipleNested").block.l val List(callA, callB) = methodBlock.astChildren.isCall.l - callA.order shouldBe 1 callA.code shouldBe "var a = true" - callB.order shouldBe 2 callB.code shouldBe "var b = true" + val List(guardIfA) = methodBlock.astChildren.isControlStructure.l - guardIfA.order shouldBe 3 - guardIfA.code should startWith("guard a else") guardIfA.controlStructureType shouldBe ControlStructureTypes.IF - guardIfA.condition.code.l shouldBe List("a") - guardIfA.whenFalse.astChildren.isCall.code.l shouldBe List("print(\"else a\")") - guardIfA.whenTrue.isCall.code.l shouldBe List("print(\"a\")") + + val List(condA) = guardIfA.condition.l + condA.code shouldBe "a" + + val List(thenPrintA) = guardIfA.whenTrue.isCall.l + thenPrintA.code shouldBe "print(\"a\")" + + val List(elsePrintA) = guardIfA.whenFalse.astChildren.isCall.l + elsePrintA.code shouldBe "print(\"else a\")" + val List(guardIfB) = guardIfA.whenFalse.astChildren.isControlStructure.l - guardIfB.code should startWith("guard b else") guardIfB.controlStructureType shouldBe ControlStructureTypes.IF - guardIfB.condition.code.l shouldBe List("b") - guardIfB.whenTrue.astChildren.code.l shouldBe List("print(\"b\")", "return") - guardIfB.whenFalse.astChildren.code.l shouldBe List("print(\"else b\")", "return") + + val List(condB) = guardIfB.condition.l + condB.code shouldBe "b" + + val List(thenPrintB) = guardIfB.whenTrue.astChildren.isCall.l + thenPrintB.code shouldBe "print(\"b\")" + + val List(thenReturnB) = guardIfB.whenTrue.astChildren.isReturn.l + thenReturnB.code shouldBe "return" + + val List(nestedElsePrintB) = guardIfB.whenFalse.astChildren.isCall.l + nestedElsePrintB.code shouldBe "print(\"else b\")" + + val List(nestedElseReturnB) = guardIfB.whenFalse.astChildren.isReturn.l + nestedElseReturnB.code shouldBe "return" } "testGuardLet" in { @@ -227,27 +302,41 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(condBlock) = guardIf.condition.isBlock.l val List(tmpLocal) = condBlock.astChildren.isLocal.l - tmpLocal.name should startWith("") - val tmpName = tmpLocal.name + val tmpName = tmpLocal.name + tmpName should startWith("") - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + val List(condCheck) = condBlock.astChildren.isCall.l + condCheck.name shouldBe Operators.notEquals condCheck.code shouldBe s"($tmpName = optionalValue) != nil" - val List(condAssign) = condCheck.argument.isCall.nameExact(Operators.assignment).l + val List(condAssign) = condCheck.argument.isCall.l + condAssign.name shouldBe Operators.assignment condAssign.code shouldBe s"$tmpName = optionalValue" - condAssign.argument(1).code shouldBe tmpName - condAssign.argument(2).code shouldBe "optionalValue" - // Then block: { let value = 0 } + val List(tmpArg, optValueArg) = condAssign.argument.l + tmpArg.code shouldBe tmpName + optValueArg.code shouldBe "optionalValue" + + val List(nilLit) = condCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" + + // Then block: { let value = 0; print(value) } val List(thenBlock) = guardIf.whenTrue.isBlock.l val List(valueLocal) = thenBlock.astChildren.isLocal.l valueLocal.name shouldBe "value" - val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment thenAssign.code shouldBe s"value = $tmpName" - thenAssign.argument(1).code shouldBe "value" - thenAssign.argument(2).code shouldBe tmpName + printCall.code shouldBe "print(value)" + + val List(valueArg, tmpRefArg) = thenAssign.argument.l + valueArg.code shouldBe "value" + tmpRefArg.code shouldBe tmpName + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" } "testGuardLetWithoutInitializer" in { @@ -272,7 +361,7 @@ class GuardTests extends SwiftSrc2CpgSuite { // Then branch: print(optionalValue) but no new local or assignment for optionalValue inside(guardIf.whenTrue.l) { case List(thenBlock) => - thenBlock.astChildren.isLocal.nameExact("optionalValue").l shouldBe empty + thenBlock.astChildren.isLocal.nameExact("optionalValue") shouldBe empty val assignments = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l assignments.filter(_.argument(1).code == "optionalValue") shouldBe empty } @@ -297,19 +386,45 @@ class GuardTests extends SwiftSrc2CpgSuite { tmp0.name shouldBe "0" tmp1.name shouldBe "1" - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l - val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd + + val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.l + tmp0Check.name shouldBe Operators.notEquals tmp0Check.code shouldBe s"(${tmp0.name} = foo()) != nil" + tmp1Check.name shouldBe Operators.notEquals tmp1Check.code shouldBe s"(${tmp1.name} = bar()) != nil" - // Then block: { let a = 0; let b = 1 } + val List(tmp0Assign) = tmp0Check.argument.isCall.l + tmp0Assign.name shouldBe Operators.assignment + tmp0Assign.code shouldBe s"${tmp0.name} = foo()" + + val List(tmp0Nil) = tmp0Check.argument.isLiteral.l + tmp0Nil.code shouldBe "nil" + + val List(tmp1Assign) = tmp1Check.argument.isCall.l + tmp1Assign.name shouldBe Operators.assignment + tmp1Assign.code shouldBe s"${tmp1.name} = bar()" + + val List(tmp1Nil) = tmp1Check.argument.isLiteral.l + tmp1Nil.code shouldBe "nil" + + // Then block: { let a = 0; let b = 1; print(a, b) } val List(thenBlock) = guardIf.whenTrue.isBlock.l - thenBlock.astChildren.isLocal.name.sorted shouldBe List("a", "b") + val List(aLocal, bLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" - val List(aAssignment, bAssignment) = thenBlock.astChildren.assignment.l + val List(aAssignment, bAssignment, printCall) = thenBlock.astChildren.isCall.l + aAssignment.name shouldBe Operators.assignment aAssignment.code shouldBe s"a = ${tmp0.name}" + bAssignment.name shouldBe Operators.assignment bAssignment.code shouldBe s"b = ${tmp1.name}" + printCall.code shouldBe "print(a, b)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" } "testGuardLetMixedWithAndWithoutInitializer" in { @@ -331,23 +446,42 @@ class GuardTests extends SwiftSrc2CpgSuite { val tmpName = tmpLocal.name tmpName should startWith("") - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l - val List(tmpCheck) = andCheck.arguments(1).isCall.nameExact(Operators.notEquals).l - val List(existingCheck) = andCheck.arguments(2).isCall.nameExact(Operators.notEquals).l + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd + + val List(tmpCheck, existingCheck) = andCheck.argument.isCall.l + tmpCheck.name shouldBe Operators.notEquals tmpCheck.code shouldBe s"($tmpName = foo()) != nil" + existingCheck.name shouldBe Operators.notEquals existingCheck.code shouldBe "existing != nil" - val List(tmpAssign) = tmpCheck.argument.isCall.nameExact(Operators.assignment).l + val List(tmpAssign) = tmpCheck.argument.isCall.l + tmpAssign.name shouldBe Operators.assignment tmpAssign.code shouldBe s"$tmpName = foo()" - // Then block: { let a = 0 } (no assignment for 'existing') + val List(tmpNil) = tmpCheck.argument.isLiteral.l + tmpNil.code shouldBe "nil" + + val List(existingIdent) = existingCheck.argument.isIdentifier.l + existingIdent.name shouldBe "existing" + existingIdent.code shouldBe "existing" + + val List(existingNil) = existingCheck.argument.isLiteral.l + existingNil.code shouldBe "nil" + + // Then block: { let a = 0; print(a, existing) } (no assignment for 'existing') val List(thenBlock) = guardIf.whenTrue.isBlock.l val List(aLocal) = thenBlock.astChildren.isLocal.l aLocal.name shouldBe "a" - val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment thenAssign.code shouldBe s"a = $tmpName" + printCall.code shouldBe "print(a, existing)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" } "testGuardLetWithOtherConditions" in { @@ -367,16 +501,38 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(tmpLocal) = condBlock.astChildren.isLocal.l val tmpName = tmpLocal.name + tmpName should startWith("") + + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l inside(andCheck.argument.l) { case List(nilCheck: Call, flag: Identifier) => + nilCheck.name shouldBe Operators.notEquals nilCheck.code shouldBe s"($tmpName = foo()) != nil" flag.name shouldBe "flag" + flag.code shouldBe "flag" case List(flag: Identifier, nilCheck: Call) => + nilCheck.name shouldBe Operators.notEquals nilCheck.code shouldBe s"($tmpName = foo()) != nil" flag.name shouldBe "flag" + flag.code shouldBe "flag" } + + // Then block: { let a = 0; print(a) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe s"a = $tmpName" + printCall.name shouldBe "print" + printCall.code shouldBe "print(a)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala index b3cd67e47ca2..e26b0396b830 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala @@ -5,6 +5,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class GuardTopLevelTests extends SwiftSrc2CpgSuite { @@ -17,26 +18,51 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { |print(b) |""".stripMargin) val List(globalBlock) = cpg.method.nameExact("").block.l - val List(localA) = globalBlock.local.nameExact("a").l + val List(localA) = globalBlock.local.l + localA.name shouldBe "a" localA.typeFullName shouldBe "Swift.Int" // After desugaring, b is in the guard's then block, not the global block - val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - guardIf.code should startWith("guard let b = a else") + val List(guardIf: ControlStructure) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + guardIf.code shouldBe "guard let b = a else {}" + guardIf.controlStructureType shouldBe ControlStructureTypes.IF // Check that desugaring created the temp variable and nil check in condition val List(condBlock) = guardIf.condition.isBlock.l - condBlock.local.name.l.exists(_.startsWith("")) shouldBe true + val List(tmpLocal) = condBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + tmpName should startWith("") + + val List(nilCheck) = condBlock.astChildren.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = a) != nil" + + val List(assignment) = nilCheck.argument.isCall.l + assignment.name shouldBe Operators.assignment + assignment.code shouldBe s"$tmpName = a" + + val List(tmpArg, aArg) = assignment.argument.l + tmpArg.code shouldBe tmpName + aArg.code shouldBe "a" + + val List(nilLit) = nilCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" // Check that b local is in the then block along with code that follows the guard val List(thenBlock) = guardIf.whenTrue.isBlock.l - val List(localB) = thenBlock.local.nameExact("b").l - localB.typeFullName shouldBe "ANY" + val List(localB) = thenBlock.local.l + localB.name shouldBe "b" // Verify the print(b) call is also in the then block (code following guard) - val List(printCall) = thenBlock.astChildren.isCall.nameExact("print").l - val List(bArg) = printCall.argument.isIdentifier.l - bArg.name shouldBe "b" + val List(bAssign, printCall) = thenBlock.astChildren.isCall.l + bAssign.name shouldBe Operators.assignment + bAssign.code shouldBe s"b = $tmpName" + printCall.code shouldBe "print(b)" + + val List(bArg) = printCall.argument.l + bArg.code shouldBe "b" + + guardIf.whenFalse.astChildren shouldBe empty } } From ee5bf5f772a7c40611a0d8238d70609c951ad193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Fri, 29 May 2026 10:47:42 +0200 Subject: [PATCH 4/5] fixed local creation scoping --- .../astcreation/AstCreatorHelper.scala | 30 ++-- .../astcreation/AstForExprSyntaxCreator.scala | 6 +- .../astcreation/AstForStmtSyntaxCreator.scala | 23 +-- .../passes/ast/AvailabilityQueryTests.scala | 15 +- .../swiftsrc2cpg/passes/ast/GuardTests.scala | 72 ++++----- .../passes/ast/GuardTopLevelTests.scala | 9 +- .../passes/ast/StatementTests.scala | 138 ++++++++++-------- 7 files changed, 155 insertions(+), 138 deletions(-) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index dba67f865c4a..3a00ebbd497f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -155,31 +155,38 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As // Apply optional binding desugaring for guard let // Create the block that will hold the unwrapped variables (blockNode argument is only used for location info) - val thenBlockNode = if (elementsAfterGuard.nonEmpty) blockNode(elementsAfterGuard.head) else blockNode(guardStmt) - scope.pushNewBlockScope(thenBlockNode) - localAstParentStack.push(thenBlockNode) + val thenBlockNode = + if (elementsAfterGuard.nonEmpty) blockNode(elementsAfterGuard.head) else blockNode(guardStmt) val (conditionAst, unwrapAsts) = handleOptionalBindingConditions( guardStmt.conditions.children, onAllSimple = simpleBindings => { val bindingInfos = collectBindingInfos(simpleBindings) val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) - val unwraps = buildUnwrapAssignments(bindingInfos) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) (condAst, unwraps) }, onMixed = (simpleBindings, tupleBindings) => { val bindingInfos = collectBindingInfos(simpleBindings) val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) - val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) (condAst, unwraps) }, onPartial = (simpleBindings, tupleBindings, otherConditions) => { val bindingInfos = collectBindingInfos(simpleBindings) val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos, otherConditions) - val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) (condAst, unwraps) }, onStandard = () => { + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) val condAst = astForNode(guardStmt.conditions) (condAst, List.empty) } @@ -674,17 +681,18 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val hasAnyInitializer = bindingInfos.exists(_.tmpName.isDefined) if (hasAnyInitializer) { - val condBlockNode = blockNode(node) - scope.pushNewBlockScope(condBlockNode) - localAstParentStack.push(condBlockNode) - bindingInfos.foreach { info => info.tmpName.foreach { tmpName => val tmpLocalNode = localNode(info.binding, tmpName, tmpName, Defines.Any).order(0) - diffGraph.addEdge(condBlockNode, tmpLocalNode, EdgeTypes.AST) + diffGraph.addEdge(localAstParentStack.head, tmpLocalNode, EdgeTypes.AST) + scope.addVariable(tmpName, tmpLocalNode, Defines.Any, VariableScopeManager.ScopeType.BlockScope) } } + val condBlockNode = blockNode(node) + scope.pushNewBlockScope(condBlockNode) + localAstParentStack.push(condBlockNode) + val nilCheckAsts = bindingInfos.map { info => info.tmpName match { case Some(tmpName) => diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index cc39a977a048..89114e49f7bd 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -529,19 +529,19 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `if let baz = foo() { body }` into: * - * Condition: { let 0; (0 = foo()) != nil } + * Condition: { (0 = foo()) != nil } * * Then block: { let baz = 0; body } * * For multiple bindings `if let a = foo(), let b = bar() { body }`: * - * Condition: { let 0; let 1; (0 = foo()) != nil && (1 = bar()) != nil } + * Condition: { (0 = foo()) != nil && (1 = bar()) != nil } * * Then block: { a = 0; b = 1; body } * * For mixed cases with/without initializers `if let a = foo(), let b { body }`: * - * Condition: { let 0; (0 = foo()) != nil && b != nil } + * Condition: { (0 = foo()) != nil && b != nil } * * Then block: { a = 0; body } */ diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala index 07413bf8e3d9..debc4c1643d5 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala @@ -556,30 +556,15 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForGuardStmtSyntax(node: GuardStmtSyntax): Ast = { - val code = this.code(node) - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) - - handleOptionalBindingConditions( - node.conditions.children, - onAllSimple = simpleBindings => astForGuardLetStmtSyntax(node, ifNode, simpleBindings), - onMixed = - (simpleBindings, tupleBindings) => astForGuardLetStmtSyntaxMixed(node, ifNode, simpleBindings, tupleBindings), - onPartial = (simpleBindings, tupleBindings, otherConditions) => - astForGuardLetStmtSyntaxPartial(node, ifNode, simpleBindings, tupleBindings, otherConditions), - onStandard = () => { - val conditionAst = astForNode(node.conditions) - val thenAst = blockAst(blockNode(node), List.empty) - val elseAst = astForNode(node.body) - ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) - } - ) + // This is already handled in AstCreatorHelper.astsForBlockElements + Ast() } /** Handles Swift optional binding (guard-let) constructs. * * De-sugars `guard let x = foo() else { exit }` into: * - * Condition: { let 0 = foo(); 0 != nil } + * Condition: { 0 = foo(); 0 != nil } * * Then block: { let x = 0 } * @@ -759,7 +744,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `while let item = iterator.next() { body }` into: * - * Condition: { let 0 = iterator.next(); 0 != nil } + * Condition: { 0 = iterator.next(); 0 != nil } * * Loop body: { let item = 0; body } * diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index f49020ff987a..eabf474adc38 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -39,15 +39,18 @@ class AvailabilityQueryTests extends SwiftSrc2CpgSuite { } "testAvailabilityQuery5b" in { - val cpg = code("if let _ = Optional(5), #unavailable(OSX 10.52, *) {}") + val cpg = code("if let _ = Optional(5), #unavailable(OSX 10.52, *) {}") + val List(methodBlock) = cpg.method.nameExact("").block.l + // After desugaring: 0 in global method block + val List(tmpLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + val List(ifControlStructure) = cpg.controlStructure.isIf.l ifControlStructure.whenTrue.astChildren shouldBe empty - // Condition is desugared: { let 0; (0 = Optional(5)) != nil && #unavailable(...) } - val List(condBlock) = ifControlStructure.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - + // Condition is desugared: { (0 = Optional(5)) != nil && #unavailable(...) } + val List(condBlock) = ifControlStructure.condition.isBlock.l val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(nilCheck, unavailable) = andCheck.argument.isCall.l nilCheck.name shouldBe Operators.notEquals diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index e3d2ff7a204e..0529d7ce1a00 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -133,11 +133,11 @@ class GuardTests extends SwiftSrc2CpgSuite { |} |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkAge").block.l - // After desugaring: age in method block, 0 in condition block, myAge in then block - // TODO(BUG): local is incorrectly appearing as direct child of method block - // It should only be in the condition block. For now, filter it out. - val List(ageLocal) = methodBlock.astChildren.isLocal.filterNot(_.name.startsWith("")).l + // After desugaring: age, 0 in method block, myAge in then block + val List(tmpLocal, ageLocal) = methodBlock.astChildren.isLocal.l ageLocal.name shouldBe "age" + val tmpName = tmpLocal.name + tmpName should startWith("") val List(call) = methodBlock.astChildren.isCall.l call.code shouldBe "var age: Int? = 22" @@ -147,11 +147,7 @@ class GuardTests extends SwiftSrc2CpgSuite { // Condition is desugared to block with temp assignment and nil check val List(condBlock) = guardIf.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - tmpName should startWith("") - - val List(nilCheck) = condBlock.astChildren.isCall.l + val List(nilCheck) = condBlock.astChildren.isCall.l nilCheck.name shouldBe Operators.notEquals nilCheck.code shouldBe s"($tmpName = age) != nil" @@ -295,16 +291,16 @@ class GuardTests extends SwiftSrc2CpgSuite { | print(value) |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + tmpName should startWith("") val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: desugared to { let 0; (0 = optionalValue) != nil } + // Condition: desugared to { (0 = optionalValue) != nil } val List(condBlock) = guardIf.condition.isBlock.l - - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - tmpName should startWith("") - val List(condCheck) = condBlock.astChildren.isCall.l condCheck.name shouldBe Operators.notEquals condCheck.code shouldBe s"($tmpName = optionalValue) != nil" @@ -376,35 +372,38 @@ class GuardTests extends SwiftSrc2CpgSuite { | print(a, b) |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + val tmp1Name = tmp1Local.name + tmp1Name shouldBe "1" val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l // Condition: desugared to { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } val List(condBlock) = guardIf.condition.isBlock.l - val List(tmp0, tmp1) = condBlock.astChildren.isLocal.l - tmp0.name shouldBe "0" - tmp1.name shouldBe "1" - val List(andCheck) = condBlock.astChildren.isCall.l andCheck.name shouldBe Operators.logicalAnd val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.l tmp0Check.name shouldBe Operators.notEquals - tmp0Check.code shouldBe s"(${tmp0.name} = foo()) != nil" + tmp0Check.code shouldBe s"($tmp0Name = foo()) != nil" tmp1Check.name shouldBe Operators.notEquals - tmp1Check.code shouldBe s"(${tmp1.name} = bar()) != nil" + tmp1Check.code shouldBe s"($tmp1Name = bar()) != nil" val List(tmp0Assign) = tmp0Check.argument.isCall.l tmp0Assign.name shouldBe Operators.assignment - tmp0Assign.code shouldBe s"${tmp0.name} = foo()" + tmp0Assign.code shouldBe s"$tmp0Name = foo()" val List(tmp0Nil) = tmp0Check.argument.isLiteral.l tmp0Nil.code shouldBe "nil" val List(tmp1Assign) = tmp1Check.argument.isCall.l tmp1Assign.name shouldBe Operators.assignment - tmp1Assign.code shouldBe s"${tmp1.name} = bar()" + tmp1Assign.code shouldBe s"$tmp1Name = bar()" val List(tmp1Nil) = tmp1Check.argument.isLiteral.l tmp1Nil.code shouldBe "nil" @@ -418,9 +417,9 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(aAssignment, bAssignment, printCall) = thenBlock.astChildren.isCall.l aAssignment.name shouldBe Operators.assignment - aAssignment.code shouldBe s"a = ${tmp0.name}" + aAssignment.code shouldBe s"a = $tmp0Name" bAssignment.name shouldBe Operators.assignment - bAssignment.code shouldBe s"b = ${tmp1.name}" + bAssignment.code shouldBe s"b = $tmp1Name" printCall.code shouldBe "print(a, b)" val List(elseReturn) = guardIf.whenFalse.l @@ -436,16 +435,17 @@ class GuardTests extends SwiftSrc2CpgSuite { | print(a, existing) |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmpName = tmpLocal.name + tmpName shouldBe "0" val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { let 0; (0 = foo()) != nil && existing != nil } + // Condition: { (0 = foo()) != nil && existing != nil } val List(condBlock) = guardIf.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - tmpName should startWith("") - val List(andCheck) = condBlock.astChildren.isCall.l andCheck.name shouldBe Operators.logicalAnd @@ -493,17 +493,17 @@ class GuardTests extends SwiftSrc2CpgSuite { | print(a) |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmpName = tmpLocal.name + tmpName shouldBe "0" val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l // Condition: { 0 = foo(); 0 != nil && flag } val List(condBlock) = guardIf.condition.isBlock.l - - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - tmpName should startWith("") - - val List(andCheck) = condBlock.astChildren.isCall.l + val List(andCheck) = condBlock.astChildren.isCall.l andCheck.name shouldBe Operators.logicalAnd inside(andCheck.argument.l) { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala index e26b0396b830..3073e58457c4 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala @@ -18,7 +18,11 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { |print(b) |""".stripMargin) val List(globalBlock) = cpg.method.nameExact("").block.l - val List(localA) = globalBlock.local.l + + // After desugaring: 0 and `a` in global block + val List(tmpLocal, localA) = globalBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" localA.name shouldBe "a" localA.typeFullName shouldBe "Swift.Int" @@ -29,9 +33,6 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { // Check that desugaring created the temp variable and nil check in condition val List(condBlock) = guardIf.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - val tmpName = tmpLocal.name - tmpName should startWith("") val List(nilCheck) = condBlock.astChildren.isCall.l nilCheck.name shouldBe Operators.notEquals diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala index fc44ae64612f..8cab869386a9 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala @@ -16,16 +16,18 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal, optionalValueLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + optionalValueLocal.name shouldBe "optionalValue" val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: desugared to { let 0; (0 = optionalValue) != nil } + // Condition: desugared to { (0 = optionalValue) != nil } val List(condBlock) = ifNode.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - tmpLocal.name should startWith("") - val tmpName = tmpLocal.name - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l condCheck.code shouldBe s"($tmpName = optionalValue) != nil" @@ -187,16 +189,18 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal, iteratorLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + iteratorLocal.name shouldBe "iterator" val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: desugared to { let 0; (0 = iterator.next()) != nil } + // Condition: desugared to { (0 = iterator.next()) != nil } val List(condBlock) = whileNode.condition.isBlock.l - val List(tmpLocal) = condBlock.astChildren.isLocal.l - tmpLocal.name should startWith("") - val tmpName = tmpLocal.name - val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l condCheck.code shouldBe s"($tmpName = iterator.next()) != nil" @@ -258,26 +262,27 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.local.nameNot("self").l + val tmp0Name = tmp0Local.name + val tmp1Name = tmp1Local.name + tmp0Name shouldBe "0" + tmp1Name shouldBe "1" val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { let 0; let 1; (0 = Optional(1)) != nil && (1 = Optional(2)) != nil } + // Condition: { (0 = Optional(1)) != nil && (1 = Optional(2)) != nil } val List(condBlock) = ifNode.condition.isBlock.l - val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l - tmp1Local.name should startWith("") - tmp2Local.name should startWith("") - val tmp1Name = tmp1Local.name - val tmp2Name = tmp2Local.name - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"($tmp1Name = Optional(1)) != nil" - check2.code shouldBe s"($tmp2Name = Optional(2)) != nil" + check1.code shouldBe s"($tmp0Name = Optional(1)) != nil" + check2.code shouldBe s"($tmp1Name = Optional(2)) != nil" val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l - assign1.code shouldBe s"$tmp1Name = Optional(1)" - assign2.code shouldBe s"$tmp2Name = Optional(2)" + assign1.code shouldBe s"$tmp0Name = Optional(1)" + assign2.code shouldBe s"$tmp1Name = Optional(2)" // Then block: { a = 0; b = 1; print(a, b) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -287,8 +292,8 @@ class StatementTests extends SwiftSrc2CpgSuite { bLocal.name shouldBe "b" val List(unwrapA, unwrapB) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l - unwrapA.code shouldBe s"a = $tmp1Name" - unwrapB.code shouldBe s"b = $tmp2Name" + unwrapA.code shouldBe s"a = $tmp0Name" + unwrapB.code shouldBe s"b = $tmp1Name" } "testIfLetMixed" in { @@ -299,22 +304,24 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmp0Local) = methodBlock.local.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { let 0; (0 = Optional(1)) != nil && opt2 != nil } + // Condition: { (0 = Optional(1)) != nil && opt2 != nil } val List(condBlock) = ifNode.condition.isBlock.l - val List(tmp1Local) = condBlock.astChildren.isLocal.l - val tmp1Name = tmp1Local.name - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"($tmp1Name = Optional(1)) != nil" + check1.code shouldBe s"($tmp0Name = Optional(1)) != nil" check2.code shouldBe "opt2 != nil" val List(assign1) = check1.argument.assignment.l - assign1.code shouldBe s"$tmp1Name = Optional(1)" + assign1.code shouldBe s"$tmp0Name = Optional(1)" // Then block: { a = 0; print(a, opt2) } val List(thenBlock) = ifNode.whenTrue.isBlock.l @@ -323,7 +330,7 @@ class StatementTests extends SwiftSrc2CpgSuite { aLocal.name shouldBe "a" val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l - unwrapA.code shouldBe s"a = $tmp1Name" + unwrapA.code shouldBe s"a = $tmp0Name" } "testWhileLetMultiple" in { @@ -334,24 +341,27 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.local.nameNot("iterator1", "iterator2").l + val tmp0Name = tmp0Local.name + val tmp1Name = tmp1Local.name + tmp0Name shouldBe "0" + tmp1Name shouldBe "1" val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: { let 0; let 1; (0 = iterator1.next()) != nil && (1 = iterator2.next()) != nil } + // Condition: { (0 = iterator1.next()) != nil && (1 = iterator2.next()) != nil } val List(condBlock) = whileNode.condition.isBlock.l - val List(tmp1Local, tmp2Local) = condBlock.astChildren.isLocal.l - val tmp1Name = tmp1Local.name - val tmp2Name = tmp2Local.name - val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"($tmp1Name = iterator1.next()) != nil" - check2.code shouldBe s"($tmp2Name = iterator2.next()) != nil" + check1.code shouldBe s"($tmp0Name = iterator1.next()) != nil" + check2.code shouldBe s"($tmp1Name = iterator2.next()) != nil" val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l - assign1.code shouldBe s"$tmp1Name = iterator1.next()" - assign2.code shouldBe s"$tmp2Name = iterator2.next()" + assign1.code shouldBe s"$tmp0Name = iterator1.next()" + assign2.code shouldBe s"$tmp1Name = iterator2.next()" // Loop body: { a = 0; b = 1; print(a, b) } val List(bodyBlock) = whileNode.whenTrue.isBlock.l @@ -361,8 +371,8 @@ class StatementTests extends SwiftSrc2CpgSuite { bLocal.name shouldBe "b" val List(unwrapA, unwrapB) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).l - unwrapA.code shouldBe s"a = $tmp1Name" - unwrapB.code shouldBe s"b = $tmp2Name" + unwrapA.code shouldBe s"a = $tmp0Name" + unwrapB.code shouldBe s"b = $tmp1Name" } "testIfLetMixedWithTuplePattern" in { @@ -373,21 +383,26 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + // FIXME: because optional tuple bindings are not handled yet the default handling + // (which is wrong) will not create correct locals and the arguments to print end up + // creating locals in the method block. + val List(tmp0Local) = methodBlock.local.nameNot("self", "b", "c").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { let 0; (0 = foo()) != nil } (tuple pattern excluded from condition) + // Condition: { (0 = foo()) != nil } (tuple pattern excluded from condition) val List(condBlock) = ifNode.condition.isBlock.l - val List(tmp1Local) = condBlock.astChildren.isLocal.l - val tmp1Name = tmp1Local.name - val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"($tmp1Name = foo()) != nil" + check1.code shouldBe s"($tmp0Name = foo()) != nil" val List(assign1) = check1.argument.assignment.l - assign1.code shouldBe s"$tmp1Name = foo()" - assign1.argument(1).code shouldBe tmp1Name + assign1.code shouldBe s"$tmp0Name = foo()" + assign1.argument(1).code shouldBe tmp0Name assign1.argument(2).code shouldBe "foo()" // Then block: { a = 0; let (b, c) = bar(); print(a, b, c) } @@ -397,9 +412,9 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(aLocal) = thenBlock.astChildren.isLocal.nameExact("a").l aLocal.name shouldBe "a" - val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp1Name").l + val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp0Name").l unwrapA.argument(1).code shouldBe "a" - unwrapA.argument(2).code shouldBe tmp1Name + unwrapA.argument(2).code shouldBe tmp0Name // Second child: tuple binding block for (b, c) = bar() // The tuple binding creates a nested block with tmp assignment + tuple destructuring @@ -418,21 +433,26 @@ class StatementTests extends SwiftSrc2CpgSuite { | } |} |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + // FIXME: because optional tuple bindings are not handled yet the default handling + // (which is wrong) will not create correct locals and the arguments to print end up + // creating locals in the method block. + val List(tmp0Local) = methodBlock.local.nameNot("self", "b", "c").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l - // Condition: { let 0; (0 = foo()) != nil } (tuple pattern excluded from condition) + // Condition: { (0 = foo()) != nil } (tuple pattern excluded from condition) val List(condBlock) = whileNode.condition.isBlock.l - val List(tmp1Local) = condBlock.astChildren.isLocal.l - val tmp1Name = tmp1Local.name - val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l - check1.code shouldBe s"($tmp1Name = foo()) != nil" + check1.code shouldBe s"($tmp0Name = foo()) != nil" val List(assign1) = check1.argument.assignment.l - assign1.code shouldBe s"$tmp1Name = foo()" - assign1.argument(1).code shouldBe tmp1Name + assign1.code shouldBe s"$tmp0Name = foo()" + assign1.argument(1).code shouldBe tmp0Name assign1.argument(2).code shouldBe "foo()" // Loop body: { a = 0; let (b, c) = bar(); print(a, b, c) } @@ -442,9 +462,9 @@ class StatementTests extends SwiftSrc2CpgSuite { val List(aLocal) = bodyBlock.astChildren.isLocal.nameExact("a").l aLocal.name shouldBe "a" - val List(unwrapA) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp1Name").l + val List(unwrapA) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp0Name").l unwrapA.argument(1).code shouldBe "a" - unwrapA.argument(2).code shouldBe tmp1Name + unwrapA.argument(2).code shouldBe tmp0Name // Second child: tuple binding block for (b, c) = bar() // The tuple binding creates a nested block with tmp assignment + tuple destructuring From 0a39619023a5e68fa015fda0c0f843bfd7a6241c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Fri, 29 May 2026 15:36:27 +0200 Subject: [PATCH 5/5] added test for 3 let --- .../astcreation/AstCreatorHelper.scala | 11 ++-- .../astcreation/AstForExprSyntaxCreator.scala | 4 +- .../astcreation/AstForStmtSyntaxCreator.scala | 20 +++--- .../swiftsrc2cpg/passes/ast/GuardTests.scala | 64 ++++++++++++++++++- 4 files changed, 80 insertions(+), 19 deletions(-) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index 3a00ebbd497f..3bc33608af91 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -193,6 +193,8 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As ) val allThenChildren = unwrapAsts ++ astsForBlockElements(elementsAfterGuard) ++ deferElementsAstsOrdered + + // Closing the scope opened at the handleOptionalBindingConditions handler scope.popScope() localAstParentStack.pop() @@ -651,11 +653,10 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As nilCheckAsts.head } else { nilCheckAsts.reduce { (left, right) => - val leftCode = left.root.map(codeOf).getOrElse("") - val rightCode = right.root.map(codeOf).getOrElse("") - val andCode = s"$leftCode && $rightCode" - val andCallNode = - createStaticCallNode(node, andCode, Operators.logicalAnd, Operators.logicalAnd, Defines.Bool) + val leftCode = left.root.map(codeOf).getOrElse("") + val rightCode = right.root.map(codeOf).getOrElse("") + val andCode = s"($leftCode) && ($rightCode)" + val andCallNode = createStaticCallNode(node, andCode, Operators.logicalAnd, Operators.logicalAnd, Defines.Bool) callAst(andCallNode, List(left, right)) } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index 89114e49f7bd..6a2b86912731 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -564,7 +564,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `if let a = foo(), let (b, c) = bar() { body }` into: * - * Condition: { 0 = foo(); 0 != nil } + * Condition: { (0 = foo()) != nil } * * Then block: { let a = 0; let (b, c) = bar(); body } */ @@ -588,7 +588,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `if let a = foo(), #unavailable(...) { body }` into: * - * Condition: { 0 = foo(); 0 != nil && #unavailable(...) } + * Condition: { ((0 = foo()) != nil) && #unavailable(...) } * * Then block: { let a = 0; body } */ diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala index debc4c1643d5..ffc667163a5d 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala @@ -564,7 +564,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `guard let x = foo() else { exit }` into: * - * Condition: { 0 = foo(); 0 != nil } + * Condition: { (0 = foo()) != nil } * * Then block: { let x = 0 } * @@ -572,13 +572,13 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * For multiple bindings `guard let a = foo(), let b = bar() else { exit }`: * - * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * Condition: { ((0 = foo()) != nil) && ((1 = bar()) != nil) } * * Then block: { a = 0; b = 1 } * * For mixed cases with/without initializers `guard let a = foo(), let b else { exit }`: * - * Condition: { 0 = foo(); 0 != nil && b != nil } + * Condition: { ((0 = foo()) != nil) && (b != nil) } * * Then block: { a = 0 } */ @@ -608,7 +608,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `guard let a = foo(), let (b, c) = bar() else { exit }` into: * - * Condition: { 0 = foo(); 0 != nil } + * Condition: { (0 = foo()) != nil } * * Then block: { let a = 0; let (b, c) = bar() } */ @@ -638,7 +638,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `guard let a = foo(), someCondition else { exit }` into: * - * Condition: { 0 = foo(); 0 != nil && someCondition } + * Condition: { ((0 = foo()) != nil) && someCondition } * * Then block: { let a = 0 } */ @@ -744,19 +744,19 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `while let item = iterator.next() { body }` into: * - * Condition: { 0 = iterator.next(); 0 != nil } + * Condition: { (0 = iterator.next()) != nil } * * Loop body: { let item = 0; body } * * For multiple bindings `while let a = foo(), let b = bar() { body }`: * - * Condition: { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + * Condition: { ((0 = foo()) != nil) && ((1 = bar()) != nil) } * * Loop body: { a = 0; b = 1; body } * * For mixed cases with/without initializers `while let a = foo(), let b { body }`: * - * Condition: { 0 = foo(); 0 != nil && b != nil } + * Condition: { ((0 = foo()) != nil) && (b != nil) } * * Loop body: { a = 0; body } */ @@ -781,7 +781,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `while let a = foo(), let (b, c) = bar() { body }` into: * - * Condition: { 0 = foo(); 0 != nil } + * Condition: { (0 = foo()) != nil } * * Loop body: { let a = 0; let (b, c) = bar(); body } */ @@ -807,7 +807,7 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * * De-sugars `while let a = foo(), someCondition { body }` into: * - * Condition: { 0 = foo(); 0 != nil && someCondition } + * Condition: { ((0 = foo()) != nil) && someCondition } * * Loop body: { let a = 0; body } */ diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index 0529d7ce1a00..213782a02c4f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -382,7 +382,7 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: desugared to { 0 = foo(); 1 = bar(); 0 != nil && 1 != nil } + // Condition: desugared to { ((0 = foo()) != nil) && ((1 = bar()) != nil) } val List(condBlock) = guardIf.condition.isBlock.l val List(andCheck) = condBlock.astChildren.isCall.l @@ -426,6 +426,66 @@ class GuardTests extends SwiftSrc2CpgSuite { elseReturn.code shouldBe "return" } + "testGuardLetThreeBindings" in { + val cpg = code(""" + |func test() { + | guard let a = foo(), let b = bar(), let c = baz() else { + | return + | } + | print(a, b, c) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1, 2 in method block + val List(tmp0Local, tmp1Local, tmp2Local) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + val tmp1Name = tmp1Local.name + tmp1Name shouldBe "1" + val tmp2Name = tmp2Local.name + tmp2Name shouldBe "2" + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { (((0 = foo()) != nil) && ((1 = bar()) != nil)) && ((2 = baz()) != nil) } + // The && operators are left-associated: (check0 && check1) && check2 + val List(condBlock) = guardIf.condition.isBlock.l + + val List(outerAndCheck) = condBlock.astChildren.isCall.l + outerAndCheck.name shouldBe Operators.logicalAnd + outerAndCheck.code shouldBe s"((($tmp0Name = foo()) != nil) && (($tmp1Name = bar()) != nil)) && (($tmp2Name = baz()) != nil)" + + val List(innerAndCheck, tmp2Check) = outerAndCheck.argument.isCall.l + innerAndCheck.name shouldBe Operators.logicalAnd + innerAndCheck.code shouldBe s"(($tmp0Name = foo()) != nil) && (($tmp1Name = bar()) != nil)" + + tmp2Check.name shouldBe Operators.notEquals + tmp2Check.code shouldBe s"($tmp2Name = baz()) != nil" + + val List(tmp0Check, tmp1Check) = innerAndCheck.argument.isCall.l + tmp0Check.name shouldBe Operators.notEquals + tmp0Check.code shouldBe s"($tmp0Name = foo()) != nil" + tmp1Check.name shouldBe Operators.notEquals + tmp1Check.code shouldBe s"($tmp1Name = bar()) != nil" + + // Then block: { let a = 0; let b = 1; let c = 2; print(a, b, c) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal, bLocal, cLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + cLocal.name shouldBe "c" + + val List(aAssignment, bAssignment, cAssignment, printCall) = thenBlock.astChildren.isCall.l + aAssignment.name shouldBe Operators.assignment + aAssignment.code shouldBe s"a = $tmp0Name" + bAssignment.name shouldBe Operators.assignment + bAssignment.code shouldBe s"b = $tmp1Name" + cAssignment.name shouldBe Operators.assignment + cAssignment.code shouldBe s"c = $tmp2Name" + printCall.code shouldBe "print(a, b, c)" + } + "testGuardLetMixedWithAndWithoutInitializer" in { val cpg = code(""" |func test(existing: Int?) { @@ -501,7 +561,7 @@ class GuardTests extends SwiftSrc2CpgSuite { val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - // Condition: { 0 = foo(); 0 != nil && flag } + // Condition: { ((0 = foo()) != nil) && flag } val List(condBlock) = guardIf.condition.isBlock.l val List(andCheck) = condBlock.astChildren.isCall.l andCheck.name shouldBe Operators.logicalAnd